/Tensorflow_Pandas_Numpy/source3.6/tensorflow/contrib/nearest_neighbor/ops/gen_nearest_neighbor_ops_pywrapper.py

https://github.com/ryfeus/lambda-packs · Python · 130 lines · 49 code · 9 blank · 72 comment · 2 complexity · 084a6d9d19399bff75fd408fdfde7634 MD5 · raw file

  1. """Python wrappers around TensorFlow ops.
  2. This file is MACHINE GENERATED! Do not edit.
  3. Original C++ source file: nearest_neighbor_ops_pywrapper.cc
  4. """
  5. import collections as _collections
  6. from tensorflow.python.eager import execute as _execute
  7. from tensorflow.python.eager import context as _context
  8. from tensorflow.python.eager import core as _core
  9. from tensorflow.python.framework import dtypes as _dtypes
  10. from tensorflow.python.framework import tensor_shape as _tensor_shape
  11. from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
  12. # Needed to trigger the call to _set_call_cpp_shape_fn.
  13. from tensorflow.python.framework import common_shapes as _common_shapes
  14. from tensorflow.python.framework import op_def_registry as _op_def_registry
  15. from tensorflow.python.framework import ops as _ops
  16. from tensorflow.python.framework import op_def_library as _op_def_library
  17. _hyperplane_lsh_probes_outputs = ["probes", "table_ids"]
  18. _HyperplaneLSHProbesOutput = _collections.namedtuple(
  19. "HyperplaneLSHProbes", _hyperplane_lsh_probes_outputs)
  20. def hyperplane_lsh_probes(point_hyperplane_product, num_tables, num_hyperplanes_per_table, num_probes, name=None):
  21. r"""Computes probes for the hyperplane hash.
  22. The op supports multiprobing, i.e., the number of requested probes can be
  23. larger than the number of tables. In that case, the same table can be probed
  24. multiple times.
  25. The first `num_tables` probes are always the primary hashes for each table.
  26. Args:
  27. point_hyperplane_product: A `Tensor`. Must be one of the following types: `float32`, `float64`.
  28. a matrix of inner products between the hyperplanes
  29. and the points to be hashed. These values should not be quantized so that we
  30. can correctly compute the probing sequence. The expected shape is
  31. `batch_size` times `num_tables * num_hyperplanes_per_table`, i.e., each
  32. element of the batch corresponds to one row of the matrix.
  33. num_tables: A `Tensor` of type `int32`.
  34. the number of tables to compute probes for.
  35. num_hyperplanes_per_table: A `Tensor` of type `int32`.
  36. the number of hyperplanes per table.
  37. num_probes: A `Tensor` of type `int32`.
  38. the requested number of probes per table.
  39. name: A name for the operation (optional).
  40. Returns:
  41. A tuple of `Tensor` objects (probes, table_ids).
  42. probes: A `Tensor` of type `int32`. the output matrix of probes. Size `batch_size` times `num_probes`.
  43. table_ids: A `Tensor` of type `int32`. the output matrix of tables ids. Size `batch_size` times `num_probes`.
  44. """
  45. _ctx = _context.context()
  46. if _ctx.in_graph_mode():
  47. _, _, _op = _op_def_lib._apply_op_helper(
  48. "HyperplaneLSHProbes",
  49. point_hyperplane_product=point_hyperplane_product,
  50. num_tables=num_tables,
  51. num_hyperplanes_per_table=num_hyperplanes_per_table,
  52. num_probes=num_probes, name=name)
  53. _result = _op.outputs[:]
  54. _inputs_flat = _op.inputs
  55. _attrs = ("CoordinateType", _op.get_attr("CoordinateType"))
  56. else:
  57. _attr_CoordinateType, (point_hyperplane_product,) = _execute.args_to_matching_eager([point_hyperplane_product], _ctx)
  58. _attr_CoordinateType = _attr_CoordinateType.as_datatype_enum
  59. num_tables = _ops.convert_to_tensor(num_tables, _dtypes.int32)
  60. num_hyperplanes_per_table = _ops.convert_to_tensor(num_hyperplanes_per_table, _dtypes.int32)
  61. num_probes = _ops.convert_to_tensor(num_probes, _dtypes.int32)
  62. _inputs_flat = [point_hyperplane_product, num_tables, num_hyperplanes_per_table, num_probes]
  63. _attrs = ("CoordinateType", _attr_CoordinateType)
  64. _result = _execute.execute(b"HyperplaneLSHProbes", 2, inputs=_inputs_flat,
  65. attrs=_attrs, ctx=_ctx, name=name)
  66. _execute.record_gradient(
  67. "HyperplaneLSHProbes", _inputs_flat, _attrs, _result, name)
  68. _result = _HyperplaneLSHProbesOutput._make(_result)
  69. return _result
  70. _ops.RegisterShape("HyperplaneLSHProbes")(None)
  71. def _InitOpDefLibrary(op_list_proto_bytes):
  72. op_list = _op_def_pb2.OpList()
  73. op_list.ParseFromString(op_list_proto_bytes)
  74. _op_def_registry.register_op_list(op_list)
  75. op_def_lib = _op_def_library.OpDefLibrary()
  76. op_def_lib.add_op_list(op_list)
  77. return op_def_lib
  78. # op {
  79. # name: "HyperplaneLSHProbes"
  80. # input_arg {
  81. # name: "point_hyperplane_product"
  82. # type_attr: "CoordinateType"
  83. # }
  84. # input_arg {
  85. # name: "num_tables"
  86. # type: DT_INT32
  87. # }
  88. # input_arg {
  89. # name: "num_hyperplanes_per_table"
  90. # type: DT_INT32
  91. # }
  92. # input_arg {
  93. # name: "num_probes"
  94. # type: DT_INT32
  95. # }
  96. # output_arg {
  97. # name: "probes"
  98. # type: DT_INT32
  99. # }
  100. # output_arg {
  101. # name: "table_ids"
  102. # type: DT_INT32
  103. # }
  104. # attr {
  105. # name: "CoordinateType"
  106. # type: "type"
  107. # allowed_values {
  108. # list {
  109. # type: DT_FLOAT
  110. # type: DT_DOUBLE
  111. # }
  112. # }
  113. # }
  114. # }
  115. _op_def_lib = _InitOpDefLibrary(b"\n\273\001\n\023HyperplaneLSHProbes\022*\n\030point_hyperplane_product\"\016CoordinateType\022\016\n\nnum_tables\030\003\022\035\n\031num_hyperplanes_per_table\030\003\022\016\n\nnum_probes\030\003\032\n\n\006probes\030\003\032\r\n\ttable_ids\030\003\"\036\n\016CoordinateType\022\004type:\006\n\0042\002\001\002")