faiss 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -9,8 +9,6 @@
9
9
 
10
10
  #pragma once
11
11
 
12
-
13
-
14
12
  #include <faiss/Index.h>
15
13
  #include <faiss/VectorTransform.h>
16
14
 
@@ -18,21 +16,20 @@ namespace faiss {
18
16
 
19
17
  /** Index that applies a LinearTransform transform on vectors before
20
18
  * handing them over to a sub-index */
21
- struct IndexPreTransform: Index {
19
+ struct IndexPreTransform : Index {
20
+ std::vector<VectorTransform*> chain; ///! chain of tranforms
21
+ Index* index; ///! the sub-index
22
22
 
23
- std::vector<VectorTransform *> chain; ///! chain of tranforms
24
- Index * index; ///! the sub-index
23
+ bool own_fields; ///! whether pointers are deleted in destructor
25
24
 
26
- bool own_fields; ///! whether pointers are deleted in destructor
25
+ explicit IndexPreTransform(Index* index);
27
26
 
28
- explicit IndexPreTransform (Index *index);
29
-
30
- IndexPreTransform ();
27
+ IndexPreTransform();
31
28
 
32
29
  /// ltrans is the last transform before the index
33
- IndexPreTransform (VectorTransform * ltrans, Index * index);
30
+ IndexPreTransform(VectorTransform* ltrans, Index* index);
34
31
 
35
- void prepend_transform (VectorTransform * ltrans);
32
+ void prepend_transform(VectorTransform* ltrans);
36
33
 
37
34
  void train(idx_t n, const float* x) override;
38
35
 
@@ -47,47 +44,47 @@ struct IndexPreTransform: Index {
47
44
  size_t remove_ids(const IDSelector& sel) override;
48
45
 
49
46
  void search(
50
- idx_t n,
51
- const float* x,
52
- idx_t k,
53
- float* distances,
54
- idx_t* labels) const override;
55
-
47
+ idx_t n,
48
+ const float* x,
49
+ idx_t k,
50
+ float* distances,
51
+ idx_t* labels) const override;
56
52
 
57
53
  /* range search, no attempt is done to change the radius */
58
- void range_search (idx_t n, const float* x, float radius,
59
- RangeSearchResult* result) const override;
54
+ void range_search(
55
+ idx_t n,
56
+ const float* x,
57
+ float radius,
58
+ RangeSearchResult* result) const override;
60
59
 
60
+ void reconstruct(idx_t key, float* recons) const override;
61
61
 
62
- void reconstruct (idx_t key, float * recons) const override;
62
+ void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
63
63
 
64
- void reconstruct_n (idx_t i0, idx_t ni, float *recons)
65
- const override;
66
-
67
- void search_and_reconstruct (idx_t n, const float *x, idx_t k,
68
- float *distances, idx_t *labels,
69
- float *recons) const override;
64
+ void search_and_reconstruct(
65
+ idx_t n,
66
+ const float* x,
67
+ idx_t k,
68
+ float* distances,
69
+ idx_t* labels,
70
+ float* recons) const override;
70
71
 
71
72
  /// apply the transforms in the chain. The returned float * may be
72
73
  /// equal to x, otherwise it should be deallocated.
73
- const float * apply_chain (idx_t n, const float *x) const;
74
+ const float* apply_chain(idx_t n, const float* x) const;
74
75
 
75
76
  /// Reverse the transforms in the chain. May not be implemented for
76
77
  /// all transforms in the chain or may return approximate results.
77
- void reverse_chain (idx_t n, const float* xt, float* x) const;
78
-
78
+ void reverse_chain(idx_t n, const float* xt, float* x) const;
79
79
 
80
- DistanceComputer * get_distance_computer() const override;
80
+ DistanceComputer* get_distance_computer() const override;
81
81
 
82
82
  /* standalone codec interface */
83
- size_t sa_code_size () const override;
84
- void sa_encode (idx_t n, const float *x,
85
- uint8_t *bytes) const override;
86
- void sa_decode (idx_t n, const uint8_t *bytes,
87
- float *x) const override;
83
+ size_t sa_code_size() const override;
84
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
85
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
88
86
 
89
87
  ~IndexPreTransform() override;
90
88
  };
91
89
 
92
-
93
90
  } // namespace faiss
@@ -5,63 +5,58 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
8
  #include <faiss/IndexRefine.h>
10
9
 
10
+ #include <faiss/IndexFlat.h>
11
+ #include <faiss/impl/AuxIndexStructures.h>
12
+ #include <faiss/impl/FaissAssert.h>
13
+ #include <faiss/utils/Heap.h>
11
14
  #include <faiss/utils/distances.h>
12
15
  #include <faiss/utils/utils.h>
13
- #include <faiss/utils/Heap.h>
14
- #include <faiss/impl/FaissAssert.h>
15
- #include <faiss/impl/AuxIndexStructures.h>
16
- #include <faiss/IndexFlat.h>
17
16
 
18
17
  namespace faiss {
19
18
 
20
-
21
-
22
19
  /***************************************************
23
20
  * IndexRefine
24
21
  ***************************************************/
25
22
 
26
- IndexRefine::IndexRefine (Index *base_index, Index *refine_index):
27
- Index (base_index->d, base_index->metric_type),
28
- base_index (base_index),
29
- refine_index (refine_index)
30
- {
23
+ IndexRefine::IndexRefine(Index* base_index, Index* refine_index)
24
+ : Index(base_index->d, base_index->metric_type),
25
+ base_index(base_index),
26
+ refine_index(refine_index) {
31
27
  own_fields = own_refine_index = false;
32
28
  if (refine_index != nullptr) {
33
- FAISS_THROW_IF_NOT (base_index->d == refine_index->d);
34
- FAISS_THROW_IF_NOT (base_index->metric_type == refine_index->metric_type);
29
+ FAISS_THROW_IF_NOT(base_index->d == refine_index->d);
30
+ FAISS_THROW_IF_NOT(
31
+ base_index->metric_type == refine_index->metric_type);
35
32
  is_trained = base_index->is_trained && refine_index->is_trained;
36
- FAISS_THROW_IF_NOT (base_index->ntotal == refine_index->ntotal);
33
+ FAISS_THROW_IF_NOT(base_index->ntotal == refine_index->ntotal);
37
34
  } // other case is useful only to construct an IndexRefineFlat
38
35
  ntotal = base_index->ntotal;
39
36
  }
40
37
 
41
- IndexRefine::IndexRefine ():
42
- base_index(nullptr), refine_index(nullptr),
43
- own_fields(false), own_refine_index(false)
44
- {
45
- }
38
+ IndexRefine::IndexRefine()
39
+ : base_index(nullptr),
40
+ refine_index(nullptr),
41
+ own_fields(false),
42
+ own_refine_index(false) {}
46
43
 
47
- void IndexRefine::train (idx_t n, const float *x)
48
- {
49
- base_index->train (n, x);
50
- refine_index->train (n, x);
44
+ void IndexRefine::train(idx_t n, const float* x) {
45
+ base_index->train(n, x);
46
+ refine_index->train(n, x);
51
47
  is_trained = true;
52
48
  }
53
49
 
54
- void IndexRefine::add (idx_t n, const float *x) {
55
- FAISS_THROW_IF_NOT (is_trained);
56
- base_index->add (n, x);
57
- refine_index->add (n, x);
50
+ void IndexRefine::add(idx_t n, const float* x) {
51
+ FAISS_THROW_IF_NOT(is_trained);
52
+ base_index->add(n, x);
53
+ refine_index->add(n, x);
58
54
  ntotal = refine_index->ntotal;
59
55
  }
60
56
 
61
- void IndexRefine::reset ()
62
- {
63
- base_index->reset ();
64
- refine_index->reset ();
57
+ void IndexRefine::reset() {
58
+ base_index->reset();
59
+ refine_index->reset();
65
60
  ntotal = 0;
66
61
  }
67
62
 
@@ -69,69 +64,72 @@ namespace {
69
64
 
70
65
  typedef faiss::Index::idx_t idx_t;
71
66
 
72
- template<class C>
73
- static void reorder_2_heaps (
74
- idx_t n,
75
- idx_t k, idx_t *labels, float *distances,
76
- idx_t k_base, const idx_t *base_labels, const float *base_distances)
77
- {
67
+ template <class C>
68
+ static void reorder_2_heaps(
69
+ idx_t n,
70
+ idx_t k,
71
+ idx_t* labels,
72
+ float* distances,
73
+ idx_t k_base,
74
+ const idx_t* base_labels,
75
+ const float* base_distances) {
78
76
  #pragma omp parallel for
79
77
  for (idx_t i = 0; i < n; i++) {
80
- idx_t *idxo = labels + i * k;
81
- float *diso = distances + i * k;
82
- const idx_t *idxi = base_labels + i * k_base;
83
- const float *disi = base_distances + i * k_base;
78
+ idx_t* idxo = labels + i * k;
79
+ float* diso = distances + i * k;
80
+ const idx_t* idxi = base_labels + i * k_base;
81
+ const float* disi = base_distances + i * k_base;
84
82
 
85
- heap_heapify<C> (k, diso, idxo, disi, idxi, k);
83
+ heap_heapify<C>(k, diso, idxo, disi, idxi, k);
86
84
  if (k_base != k) { // add remaining elements
87
- heap_addn<C> (k, diso, idxo, disi + k, idxi + k, k_base - k);
85
+ heap_addn<C>(k, diso, idxo, disi + k, idxi + k, k_base - k);
88
86
  }
89
- heap_reorder<C> (k, diso, idxo);
87
+ heap_reorder<C>(k, diso, idxo);
90
88
  }
91
89
  }
92
90
 
93
-
94
91
  } // anonymous namespace
95
92
 
96
-
97
-
98
- void IndexRefine::search (
99
- idx_t n, const float *x, idx_t k,
100
- float *distances, idx_t *labels) const
101
- {
102
- FAISS_THROW_IF_NOT (is_trained);
103
- idx_t k_base = idx_t (k * k_factor);
104
- idx_t * base_labels = labels;
105
- float * base_distances = distances;
93
+ void IndexRefine::search(
94
+ idx_t n,
95
+ const float* x,
96
+ idx_t k,
97
+ float* distances,
98
+ idx_t* labels) const {
99
+ FAISS_THROW_IF_NOT(k > 0);
100
+
101
+ FAISS_THROW_IF_NOT(is_trained);
102
+ idx_t k_base = idx_t(k * k_factor);
103
+ idx_t* base_labels = labels;
104
+ float* base_distances = distances;
106
105
  ScopeDeleter<idx_t> del1;
107
106
  ScopeDeleter<float> del2;
108
107
 
109
108
  if (k != k_base) {
110
- base_labels = new idx_t [n * k_base];
111
- del1.set (base_labels);
112
- base_distances = new float [n * k_base];
113
- del2.set (base_distances);
109
+ base_labels = new idx_t[n * k_base];
110
+ del1.set(base_labels);
111
+ base_distances = new float[n * k_base];
112
+ del2.set(base_distances);
114
113
  }
115
114
 
116
- base_index->search (n, x, k_base, base_distances, base_labels);
115
+ base_index->search(n, x, k_base, base_distances, base_labels);
117
116
 
118
117
  for (int i = 0; i < n * k_base; i++)
119
- assert (base_labels[i] >= -1 &&
120
- base_labels[i] < ntotal);
118
+ assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
121
119
 
122
- // parallelize over queries
120
+ // parallelize over queries
123
121
  #pragma omp parallel if (n > 1)
124
122
  {
125
123
  std::unique_ptr<DistanceComputer> dc(
126
- refine_index->get_distance_computer()
127
- );
124
+ refine_index->get_distance_computer());
128
125
  #pragma omp for
129
126
  for (idx_t i = 0; i < n; i++) {
130
127
  dc->set_query(x + i * d);
131
128
  idx_t ij = i * k_base;
132
129
  for (idx_t j = 0; j < k_base; j++) {
133
130
  idx_t idx = base_labels[ij];
134
- if (idx < 0) break;
131
+ if (idx < 0)
132
+ break;
135
133
  base_distances[ij] = (*dc)(idx);
136
134
  ij++;
137
135
  }
@@ -140,117 +138,103 @@ void IndexRefine::search (
140
138
 
141
139
  // sort and store result
142
140
  if (metric_type == METRIC_L2) {
143
- typedef CMax <float, idx_t> C;
144
- reorder_2_heaps<C> (
145
- n, k, labels, distances,
146
- k_base, base_labels, base_distances);
141
+ typedef CMax<float, idx_t> C;
142
+ reorder_2_heaps<C>(
143
+ n, k, labels, distances, k_base, base_labels, base_distances);
147
144
 
148
145
  } else if (metric_type == METRIC_INNER_PRODUCT) {
149
- typedef CMin <float, idx_t> C;
150
- reorder_2_heaps<C> (
151
- n, k, labels, distances,
152
- k_base, base_labels, base_distances);
146
+ typedef CMin<float, idx_t> C;
147
+ reorder_2_heaps<C>(
148
+ n, k, labels, distances, k_base, base_labels, base_distances);
153
149
  } else {
154
150
  FAISS_THROW_MSG("Metric type not supported");
155
151
  }
156
-
157
152
  }
158
153
 
159
- void IndexRefine::reconstruct (idx_t key, float * recons) const {
160
- refine_index->reconstruct (key, recons);
154
+ void IndexRefine::reconstruct(idx_t key, float* recons) const {
155
+ refine_index->reconstruct(key, recons);
161
156
  }
162
157
 
163
-
164
-
165
-
166
- IndexRefine::~IndexRefine ()
167
- {
168
- if (own_fields) delete base_index;
169
- if (own_refine_index) delete refine_index;
158
+ IndexRefine::~IndexRefine() {
159
+ if (own_fields)
160
+ delete base_index;
161
+ if (own_refine_index)
162
+ delete refine_index;
170
163
  }
171
164
 
172
-
173
165
  /***************************************************
174
166
  * IndexRefineFlat
175
167
  ***************************************************/
176
168
 
177
- IndexRefineFlat::IndexRefineFlat (Index *base_index):
178
- IndexRefine(base_index, new IndexFlat(base_index->d, base_index->metric_type))
179
- {
169
+ IndexRefineFlat::IndexRefineFlat(Index* base_index)
170
+ : IndexRefine(
171
+ base_index,
172
+ new IndexFlat(base_index->d, base_index->metric_type)) {
180
173
  is_trained = base_index->is_trained;
181
174
  own_refine_index = true;
182
- FAISS_THROW_IF_NOT_MSG (base_index->ntotal == 0,
183
- "base_index should be empty in the beginning");
175
+ FAISS_THROW_IF_NOT_MSG(
176
+ base_index->ntotal == 0,
177
+ "base_index should be empty in the beginning");
184
178
  }
185
179
 
186
-
187
- IndexRefineFlat::IndexRefineFlat (Index *base_index, const float *xb):
188
- IndexRefine (base_index, nullptr)
189
- {
180
+ IndexRefineFlat::IndexRefineFlat(Index* base_index, const float* xb)
181
+ : IndexRefine(base_index, nullptr) {
190
182
  is_trained = base_index->is_trained;
191
183
  refine_index = new IndexFlat(base_index->d, base_index->metric_type);
192
184
  own_refine_index = true;
193
- refine_index->add (base_index->ntotal, xb);
194
-
185
+ refine_index->add(base_index->ntotal, xb);
195
186
  }
196
187
 
197
- IndexRefineFlat::IndexRefineFlat():
198
- IndexRefine()
199
- {
188
+ IndexRefineFlat::IndexRefineFlat() : IndexRefine() {
200
189
  own_refine_index = true;
201
190
  }
202
191
 
203
-
204
- void IndexRefineFlat::search (
205
- idx_t n, const float *x, idx_t k,
206
- float *distances, idx_t *labels) const
207
- {
208
- FAISS_THROW_IF_NOT (is_trained);
209
- idx_t k_base = idx_t (k * k_factor);
210
- idx_t * base_labels = labels;
211
- float * base_distances = distances;
192
+ void IndexRefineFlat::search(
193
+ idx_t n,
194
+ const float* x,
195
+ idx_t k,
196
+ float* distances,
197
+ idx_t* labels) const {
198
+ FAISS_THROW_IF_NOT(k > 0);
199
+
200
+ FAISS_THROW_IF_NOT(is_trained);
201
+ idx_t k_base = idx_t(k * k_factor);
202
+ idx_t* base_labels = labels;
203
+ float* base_distances = distances;
212
204
  ScopeDeleter<idx_t> del1;
213
205
  ScopeDeleter<float> del2;
214
206
 
215
207
  if (k != k_base) {
216
- base_labels = new idx_t [n * k_base];
217
- del1.set (base_labels);
218
- base_distances = new float [n * k_base];
219
- del2.set (base_distances);
208
+ base_labels = new idx_t[n * k_base];
209
+ del1.set(base_labels);
210
+ base_distances = new float[n * k_base];
211
+ del2.set(base_distances);
220
212
  }
221
213
 
222
- base_index->search (n, x, k_base, base_distances, base_labels);
214
+ base_index->search(n, x, k_base, base_distances, base_labels);
223
215
 
224
216
  for (int i = 0; i < n * k_base; i++)
225
- assert (base_labels[i] >= -1 &&
226
- base_labels[i] < ntotal);
217
+ assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
227
218
 
228
219
  // compute refined distances
229
- auto rf = dynamic_cast<const IndexFlat *>(refine_index);
220
+ auto rf = dynamic_cast<const IndexFlat*>(refine_index);
230
221
  FAISS_THROW_IF_NOT(rf);
231
222
 
232
- rf->compute_distance_subset (
233
- n, x, k_base, base_distances, base_labels);
223
+ rf->compute_distance_subset(n, x, k_base, base_distances, base_labels);
234
224
 
235
225
  // sort and store result
236
226
  if (metric_type == METRIC_L2) {
237
- typedef CMax <float, idx_t> C;
238
- reorder_2_heaps<C> (
239
- n, k, labels, distances,
240
- k_base, base_labels, base_distances);
227
+ typedef CMax<float, idx_t> C;
228
+ reorder_2_heaps<C>(
229
+ n, k, labels, distances, k_base, base_labels, base_distances);
241
230
 
242
231
  } else if (metric_type == METRIC_INNER_PRODUCT) {
243
- typedef CMin <float, idx_t> C;
244
- reorder_2_heaps<C> (
245
- n, k, labels, distances,
246
- k_base, base_labels, base_distances);
232
+ typedef CMin<float, idx_t> C;
233
+ reorder_2_heaps<C>(
234
+ n, k, labels, distances, k_base, base_labels, base_distances);
247
235
  } else {
248
236
  FAISS_THROW_MSG("Metric type not supported");
249
237
  }
250
-
251
238
  }
252
239
 
253
-
254
-
255
-
256
240
  } // namespace faiss