faiss 0.1.7 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +18 -0
  3. data/README.md +7 -7
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +8 -2
  6. data/ext/faiss/index.cpp +102 -69
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +0 -5
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +26 -12
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -9,8 +9,6 @@
9
9
 
10
10
  #pragma once
11
11
 
12
-
13
-
14
12
  #include <faiss/Index.h>
15
13
  #include <faiss/VectorTransform.h>
16
14
 
@@ -18,21 +16,20 @@ namespace faiss {
18
16
 
19
17
  /** Index that applies a LinearTransform transform on vectors before
20
18
  * handing them over to a sub-index */
21
- struct IndexPreTransform: Index {
19
+ struct IndexPreTransform : Index {
20
+ std::vector<VectorTransform*> chain; ///! chain of tranforms
21
+ Index* index; ///! the sub-index
22
22
 
23
- std::vector<VectorTransform *> chain; ///! chain of tranforms
24
- Index * index; ///! the sub-index
23
+ bool own_fields; ///! whether pointers are deleted in destructor
25
24
 
26
- bool own_fields; ///! whether pointers are deleted in destructor
25
+ explicit IndexPreTransform(Index* index);
27
26
 
28
- explicit IndexPreTransform (Index *index);
29
-
30
- IndexPreTransform ();
27
+ IndexPreTransform();
31
28
 
32
29
  /// ltrans is the last transform before the index
33
- IndexPreTransform (VectorTransform * ltrans, Index * index);
30
+ IndexPreTransform(VectorTransform* ltrans, Index* index);
34
31
 
35
- void prepend_transform (VectorTransform * ltrans);
32
+ void prepend_transform(VectorTransform* ltrans);
36
33
 
37
34
  void train(idx_t n, const float* x) override;
38
35
 
@@ -47,47 +44,47 @@ struct IndexPreTransform: Index {
47
44
  size_t remove_ids(const IDSelector& sel) override;
48
45
 
49
46
  void search(
50
- idx_t n,
51
- const float* x,
52
- idx_t k,
53
- float* distances,
54
- idx_t* labels) const override;
55
-
47
+ idx_t n,
48
+ const float* x,
49
+ idx_t k,
50
+ float* distances,
51
+ idx_t* labels) const override;
56
52
 
57
53
  /* range search, no attempt is done to change the radius */
58
- void range_search (idx_t n, const float* x, float radius,
59
- RangeSearchResult* result) const override;
54
+ void range_search(
55
+ idx_t n,
56
+ const float* x,
57
+ float radius,
58
+ RangeSearchResult* result) const override;
60
59
 
60
+ void reconstruct(idx_t key, float* recons) const override;
61
61
 
62
- void reconstruct (idx_t key, float * recons) const override;
62
+ void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
63
63
 
64
- void reconstruct_n (idx_t i0, idx_t ni, float *recons)
65
- const override;
66
-
67
- void search_and_reconstruct (idx_t n, const float *x, idx_t k,
68
- float *distances, idx_t *labels,
69
- float *recons) const override;
64
+ void search_and_reconstruct(
65
+ idx_t n,
66
+ const float* x,
67
+ idx_t k,
68
+ float* distances,
69
+ idx_t* labels,
70
+ float* recons) const override;
70
71
 
71
72
  /// apply the transforms in the chain. The returned float * may be
72
73
  /// equal to x, otherwise it should be deallocated.
73
- const float * apply_chain (idx_t n, const float *x) const;
74
+ const float* apply_chain(idx_t n, const float* x) const;
74
75
 
75
76
  /// Reverse the transforms in the chain. May not be implemented for
76
77
  /// all transforms in the chain or may return approximate results.
77
- void reverse_chain (idx_t n, const float* xt, float* x) const;
78
-
78
+ void reverse_chain(idx_t n, const float* xt, float* x) const;
79
79
 
80
- DistanceComputer * get_distance_computer() const override;
80
+ DistanceComputer* get_distance_computer() const override;
81
81
 
82
82
  /* standalone codec interface */
83
- size_t sa_code_size () const override;
84
- void sa_encode (idx_t n, const float *x,
85
- uint8_t *bytes) const override;
86
- void sa_decode (idx_t n, const uint8_t *bytes,
87
- float *x) const override;
83
+ size_t sa_code_size() const override;
84
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
85
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
88
86
 
89
87
  ~IndexPreTransform() override;
90
88
  };
91
89
 
92
-
93
90
  } // namespace faiss
@@ -5,63 +5,58 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
8
  #include <faiss/IndexRefine.h>
10
9
 
10
+ #include <faiss/IndexFlat.h>
11
+ #include <faiss/impl/AuxIndexStructures.h>
12
+ #include <faiss/impl/FaissAssert.h>
13
+ #include <faiss/utils/Heap.h>
11
14
  #include <faiss/utils/distances.h>
12
15
  #include <faiss/utils/utils.h>
13
- #include <faiss/utils/Heap.h>
14
- #include <faiss/impl/FaissAssert.h>
15
- #include <faiss/impl/AuxIndexStructures.h>
16
- #include <faiss/IndexFlat.h>
17
16
 
18
17
  namespace faiss {
19
18
 
20
-
21
-
22
19
  /***************************************************
23
20
  * IndexRefine
24
21
  ***************************************************/
25
22
 
26
- IndexRefine::IndexRefine (Index *base_index, Index *refine_index):
27
- Index (base_index->d, base_index->metric_type),
28
- base_index (base_index),
29
- refine_index (refine_index)
30
- {
23
+ IndexRefine::IndexRefine(Index* base_index, Index* refine_index)
24
+ : Index(base_index->d, base_index->metric_type),
25
+ base_index(base_index),
26
+ refine_index(refine_index) {
31
27
  own_fields = own_refine_index = false;
32
28
  if (refine_index != nullptr) {
33
- FAISS_THROW_IF_NOT (base_index->d == refine_index->d);
34
- FAISS_THROW_IF_NOT (base_index->metric_type == refine_index->metric_type);
29
+ FAISS_THROW_IF_NOT(base_index->d == refine_index->d);
30
+ FAISS_THROW_IF_NOT(
31
+ base_index->metric_type == refine_index->metric_type);
35
32
  is_trained = base_index->is_trained && refine_index->is_trained;
36
- FAISS_THROW_IF_NOT (base_index->ntotal == refine_index->ntotal);
33
+ FAISS_THROW_IF_NOT(base_index->ntotal == refine_index->ntotal);
37
34
  } // other case is useful only to construct an IndexRefineFlat
38
35
  ntotal = base_index->ntotal;
39
36
  }
40
37
 
41
- IndexRefine::IndexRefine ():
42
- base_index(nullptr), refine_index(nullptr),
43
- own_fields(false), own_refine_index(false)
44
- {
45
- }
38
+ IndexRefine::IndexRefine()
39
+ : base_index(nullptr),
40
+ refine_index(nullptr),
41
+ own_fields(false),
42
+ own_refine_index(false) {}
46
43
 
47
- void IndexRefine::train (idx_t n, const float *x)
48
- {
49
- base_index->train (n, x);
50
- refine_index->train (n, x);
44
+ void IndexRefine::train(idx_t n, const float* x) {
45
+ base_index->train(n, x);
46
+ refine_index->train(n, x);
51
47
  is_trained = true;
52
48
  }
53
49
 
54
- void IndexRefine::add (idx_t n, const float *x) {
55
- FAISS_THROW_IF_NOT (is_trained);
56
- base_index->add (n, x);
57
- refine_index->add (n, x);
50
+ void IndexRefine::add(idx_t n, const float* x) {
51
+ FAISS_THROW_IF_NOT(is_trained);
52
+ base_index->add(n, x);
53
+ refine_index->add(n, x);
58
54
  ntotal = refine_index->ntotal;
59
55
  }
60
56
 
61
- void IndexRefine::reset ()
62
- {
63
- base_index->reset ();
64
- refine_index->reset ();
57
+ void IndexRefine::reset() {
58
+ base_index->reset();
59
+ refine_index->reset();
65
60
  ntotal = 0;
66
61
  }
67
62
 
@@ -69,69 +64,72 @@ namespace {
69
64
 
70
65
  typedef faiss::Index::idx_t idx_t;
71
66
 
72
- template<class C>
73
- static void reorder_2_heaps (
74
- idx_t n,
75
- idx_t k, idx_t *labels, float *distances,
76
- idx_t k_base, const idx_t *base_labels, const float *base_distances)
77
- {
67
+ template <class C>
68
+ static void reorder_2_heaps(
69
+ idx_t n,
70
+ idx_t k,
71
+ idx_t* labels,
72
+ float* distances,
73
+ idx_t k_base,
74
+ const idx_t* base_labels,
75
+ const float* base_distances) {
78
76
  #pragma omp parallel for
79
77
  for (idx_t i = 0; i < n; i++) {
80
- idx_t *idxo = labels + i * k;
81
- float *diso = distances + i * k;
82
- const idx_t *idxi = base_labels + i * k_base;
83
- const float *disi = base_distances + i * k_base;
78
+ idx_t* idxo = labels + i * k;
79
+ float* diso = distances + i * k;
80
+ const idx_t* idxi = base_labels + i * k_base;
81
+ const float* disi = base_distances + i * k_base;
84
82
 
85
- heap_heapify<C> (k, diso, idxo, disi, idxi, k);
83
+ heap_heapify<C>(k, diso, idxo, disi, idxi, k);
86
84
  if (k_base != k) { // add remaining elements
87
- heap_addn<C> (k, diso, idxo, disi + k, idxi + k, k_base - k);
85
+ heap_addn<C>(k, diso, idxo, disi + k, idxi + k, k_base - k);
88
86
  }
89
- heap_reorder<C> (k, diso, idxo);
87
+ heap_reorder<C>(k, diso, idxo);
90
88
  }
91
89
  }
92
90
 
93
-
94
91
  } // anonymous namespace
95
92
 
96
-
97
-
98
- void IndexRefine::search (
99
- idx_t n, const float *x, idx_t k,
100
- float *distances, idx_t *labels) const
101
- {
102
- FAISS_THROW_IF_NOT (is_trained);
103
- idx_t k_base = idx_t (k * k_factor);
104
- idx_t * base_labels = labels;
105
- float * base_distances = distances;
93
+ void IndexRefine::search(
94
+ idx_t n,
95
+ const float* x,
96
+ idx_t k,
97
+ float* distances,
98
+ idx_t* labels) const {
99
+ FAISS_THROW_IF_NOT(k > 0);
100
+
101
+ FAISS_THROW_IF_NOT(is_trained);
102
+ idx_t k_base = idx_t(k * k_factor);
103
+ idx_t* base_labels = labels;
104
+ float* base_distances = distances;
106
105
  ScopeDeleter<idx_t> del1;
107
106
  ScopeDeleter<float> del2;
108
107
 
109
108
  if (k != k_base) {
110
- base_labels = new idx_t [n * k_base];
111
- del1.set (base_labels);
112
- base_distances = new float [n * k_base];
113
- del2.set (base_distances);
109
+ base_labels = new idx_t[n * k_base];
110
+ del1.set(base_labels);
111
+ base_distances = new float[n * k_base];
112
+ del2.set(base_distances);
114
113
  }
115
114
 
116
- base_index->search (n, x, k_base, base_distances, base_labels);
115
+ base_index->search(n, x, k_base, base_distances, base_labels);
117
116
 
118
117
  for (int i = 0; i < n * k_base; i++)
119
- assert (base_labels[i] >= -1 &&
120
- base_labels[i] < ntotal);
118
+ assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
121
119
 
122
- // parallelize over queries
120
+ // parallelize over queries
123
121
  #pragma omp parallel if (n > 1)
124
122
  {
125
123
  std::unique_ptr<DistanceComputer> dc(
126
- refine_index->get_distance_computer()
127
- );
124
+ refine_index->get_distance_computer());
128
125
  #pragma omp for
129
126
  for (idx_t i = 0; i < n; i++) {
130
127
  dc->set_query(x + i * d);
131
128
  idx_t ij = i * k_base;
132
129
  for (idx_t j = 0; j < k_base; j++) {
133
130
  idx_t idx = base_labels[ij];
134
- if (idx < 0) break;
131
+ if (idx < 0)
132
+ break;
135
133
  base_distances[ij] = (*dc)(idx);
136
134
  ij++;
137
135
  }
@@ -140,117 +138,103 @@ void IndexRefine::search (
140
138
 
141
139
  // sort and store result
142
140
  if (metric_type == METRIC_L2) {
143
- typedef CMax <float, idx_t> C;
144
- reorder_2_heaps<C> (
145
- n, k, labels, distances,
146
- k_base, base_labels, base_distances);
141
+ typedef CMax<float, idx_t> C;
142
+ reorder_2_heaps<C>(
143
+ n, k, labels, distances, k_base, base_labels, base_distances);
147
144
 
148
145
  } else if (metric_type == METRIC_INNER_PRODUCT) {
149
- typedef CMin <float, idx_t> C;
150
- reorder_2_heaps<C> (
151
- n, k, labels, distances,
152
- k_base, base_labels, base_distances);
146
+ typedef CMin<float, idx_t> C;
147
+ reorder_2_heaps<C>(
148
+ n, k, labels, distances, k_base, base_labels, base_distances);
153
149
  } else {
154
150
  FAISS_THROW_MSG("Metric type not supported");
155
151
  }
156
-
157
152
  }
158
153
 
159
- void IndexRefine::reconstruct (idx_t key, float * recons) const {
160
- refine_index->reconstruct (key, recons);
154
+ void IndexRefine::reconstruct(idx_t key, float* recons) const {
155
+ refine_index->reconstruct(key, recons);
161
156
  }
162
157
 
163
-
164
-
165
-
166
- IndexRefine::~IndexRefine ()
167
- {
168
- if (own_fields) delete base_index;
169
- if (own_refine_index) delete refine_index;
158
+ IndexRefine::~IndexRefine() {
159
+ if (own_fields)
160
+ delete base_index;
161
+ if (own_refine_index)
162
+ delete refine_index;
170
163
  }
171
164
 
172
-
173
165
  /***************************************************
174
166
  * IndexRefineFlat
175
167
  ***************************************************/
176
168
 
177
- IndexRefineFlat::IndexRefineFlat (Index *base_index):
178
- IndexRefine(base_index, new IndexFlat(base_index->d, base_index->metric_type))
179
- {
169
+ IndexRefineFlat::IndexRefineFlat(Index* base_index)
170
+ : IndexRefine(
171
+ base_index,
172
+ new IndexFlat(base_index->d, base_index->metric_type)) {
180
173
  is_trained = base_index->is_trained;
181
174
  own_refine_index = true;
182
- FAISS_THROW_IF_NOT_MSG (base_index->ntotal == 0,
183
- "base_index should be empty in the beginning");
175
+ FAISS_THROW_IF_NOT_MSG(
176
+ base_index->ntotal == 0,
177
+ "base_index should be empty in the beginning");
184
178
  }
185
179
 
186
-
187
- IndexRefineFlat::IndexRefineFlat (Index *base_index, const float *xb):
188
- IndexRefine (base_index, nullptr)
189
- {
180
+ IndexRefineFlat::IndexRefineFlat(Index* base_index, const float* xb)
181
+ : IndexRefine(base_index, nullptr) {
190
182
  is_trained = base_index->is_trained;
191
183
  refine_index = new IndexFlat(base_index->d, base_index->metric_type);
192
184
  own_refine_index = true;
193
- refine_index->add (base_index->ntotal, xb);
194
-
185
+ refine_index->add(base_index->ntotal, xb);
195
186
  }
196
187
 
197
- IndexRefineFlat::IndexRefineFlat():
198
- IndexRefine()
199
- {
188
+ IndexRefineFlat::IndexRefineFlat() : IndexRefine() {
200
189
  own_refine_index = true;
201
190
  }
202
191
 
203
-
204
- void IndexRefineFlat::search (
205
- idx_t n, const float *x, idx_t k,
206
- float *distances, idx_t *labels) const
207
- {
208
- FAISS_THROW_IF_NOT (is_trained);
209
- idx_t k_base = idx_t (k * k_factor);
210
- idx_t * base_labels = labels;
211
- float * base_distances = distances;
192
+ void IndexRefineFlat::search(
193
+ idx_t n,
194
+ const float* x,
195
+ idx_t k,
196
+ float* distances,
197
+ idx_t* labels) const {
198
+ FAISS_THROW_IF_NOT(k > 0);
199
+
200
+ FAISS_THROW_IF_NOT(is_trained);
201
+ idx_t k_base = idx_t(k * k_factor);
202
+ idx_t* base_labels = labels;
203
+ float* base_distances = distances;
212
204
  ScopeDeleter<idx_t> del1;
213
205
  ScopeDeleter<float> del2;
214
206
 
215
207
  if (k != k_base) {
216
- base_labels = new idx_t [n * k_base];
217
- del1.set (base_labels);
218
- base_distances = new float [n * k_base];
219
- del2.set (base_distances);
208
+ base_labels = new idx_t[n * k_base];
209
+ del1.set(base_labels);
210
+ base_distances = new float[n * k_base];
211
+ del2.set(base_distances);
220
212
  }
221
213
 
222
- base_index->search (n, x, k_base, base_distances, base_labels);
214
+ base_index->search(n, x, k_base, base_distances, base_labels);
223
215
 
224
216
  for (int i = 0; i < n * k_base; i++)
225
- assert (base_labels[i] >= -1 &&
226
- base_labels[i] < ntotal);
217
+ assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
227
218
 
228
219
  // compute refined distances
229
- auto rf = dynamic_cast<const IndexFlat *>(refine_index);
220
+ auto rf = dynamic_cast<const IndexFlat*>(refine_index);
230
221
  FAISS_THROW_IF_NOT(rf);
231
222
 
232
- rf->compute_distance_subset (
233
- n, x, k_base, base_distances, base_labels);
223
+ rf->compute_distance_subset(n, x, k_base, base_distances, base_labels);
234
224
 
235
225
  // sort and store result
236
226
  if (metric_type == METRIC_L2) {
237
- typedef CMax <float, idx_t> C;
238
- reorder_2_heaps<C> (
239
- n, k, labels, distances,
240
- k_base, base_labels, base_distances);
227
+ typedef CMax<float, idx_t> C;
228
+ reorder_2_heaps<C>(
229
+ n, k, labels, distances, k_base, base_labels, base_distances);
241
230
 
242
231
  } else if (metric_type == METRIC_INNER_PRODUCT) {
243
- typedef CMin <float, idx_t> C;
244
- reorder_2_heaps<C> (
245
- n, k, labels, distances,
246
- k_base, base_labels, base_distances);
232
+ typedef CMin<float, idx_t> C;
233
+ reorder_2_heaps<C>(
234
+ n, k, labels, distances, k_base, base_labels, base_distances);
247
235
  } else {
248
236
  FAISS_THROW_MSG("Metric type not supported");
249
237
  }
250
-
251
238
  }
252
239
 
253
-
254
-
255
-
256
240
  } // namespace faiss
@@ -9,32 +9,29 @@
9
9
 
10
10
  #include <faiss/Index.h>
11
11
 
12
-
13
12
  namespace faiss {
14
13
 
15
-
16
14
  /** Index that queries in a base_index (a fast one) and refines the
17
15
  * results with an exact search, hopefully improving the results.
18
16
  */
19
- struct IndexRefine: Index {
20
-
17
+ struct IndexRefine : Index {
21
18
  /// faster index to pre-select the vectors that should be filtered
22
- Index *base_index;
19
+ Index* base_index;
23
20
 
24
21
  /// refinement index
25
- Index *refine_index;
22
+ Index* refine_index;
26
23
 
27
- bool own_fields; ///< should the base index be deallocated?
28
- bool own_refine_index; ///< same with the refinement index
24
+ bool own_fields; ///< should the base index be deallocated?
25
+ bool own_refine_index; ///< same with the refinement index
29
26
 
30
27
  /// factor between k requested in search and the k requested from
31
28
  /// the base_index (should be >= 1)
32
29
  float k_factor = 1;
33
30
 
34
- /// intitialize from empty index
35
- IndexRefine (Index *base_index, Index *refine_index);
31
+ /// initialize from empty index
32
+ IndexRefine(Index* base_index, Index* refine_index);
36
33
 
37
- IndexRefine ();
34
+ IndexRefine();
38
35
 
39
36
  void train(idx_t n, const float* x) override;
40
37
 
@@ -43,31 +40,33 @@ struct IndexRefine: Index {
43
40
  void reset() override;
44
41
 
45
42
  void search(
46
- idx_t n, const float* x, idx_t k,
47
- float* distances, idx_t* labels) const override;
43
+ idx_t n,
44
+ const float* x,
45
+ idx_t k,
46
+ float* distances,
47
+ idx_t* labels) const override;
48
48
 
49
49
  // reconstruct is routed to the refine_index
50
- void reconstruct (idx_t key, float * recons) const override;
50
+ void reconstruct(idx_t key, float* recons) const override;
51
51
 
52
52
  ~IndexRefine() override;
53
53
  };
54
54
 
55
-
56
55
  /** Version where the refinement index is an IndexFlat. It has one additional
57
56
  * constructor that takes a table of elements to add to the flat refinement
58
57
  * index */
59
- struct IndexRefineFlat: IndexRefine {
60
- explicit IndexRefineFlat (Index *base_index);
61
- IndexRefineFlat(Index *base_index, const float *xb);
58
+ struct IndexRefineFlat : IndexRefine {
59
+ explicit IndexRefineFlat(Index* base_index);
60
+ IndexRefineFlat(Index* base_index, const float* xb);
62
61
 
63
62
  IndexRefineFlat();
64
63
 
65
64
  void search(
66
- idx_t n, const float* x, idx_t k,
67
- float* distances, idx_t* labels) const override;
68
-
65
+ idx_t n,
66
+ const float* x,
67
+ idx_t k,
68
+ float* distances,
69
+ idx_t* labels) const override;
69
70
  };
70
71
 
71
-
72
-
73
72
  } // namespace faiss