faiss 0.2.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -9,8 +9,6 @@
9
9
 
10
10
  #pragma once
11
11
 
12
-
13
-
14
12
  #include <faiss/Index.h>
15
13
  #include <faiss/VectorTransform.h>
16
14
 
@@ -18,21 +16,20 @@ namespace faiss {
18
16
 
19
17
  /** Index that applies a LinearTransform transform on vectors before
20
18
  * handing them over to a sub-index */
21
- struct IndexPreTransform: Index {
19
+ struct IndexPreTransform : Index {
20
+ std::vector<VectorTransform*> chain; ///! chain of tranforms
21
+ Index* index; ///! the sub-index
22
22
 
23
- std::vector<VectorTransform *> chain; ///! chain of tranforms
24
- Index * index; ///! the sub-index
23
+ bool own_fields; ///! whether pointers are deleted in destructor
25
24
 
26
- bool own_fields; ///! whether pointers are deleted in destructor
25
+ explicit IndexPreTransform(Index* index);
27
26
 
28
- explicit IndexPreTransform (Index *index);
29
-
30
- IndexPreTransform ();
27
+ IndexPreTransform();
31
28
 
32
29
  /// ltrans is the last transform before the index
33
- IndexPreTransform (VectorTransform * ltrans, Index * index);
30
+ IndexPreTransform(VectorTransform* ltrans, Index* index);
34
31
 
35
- void prepend_transform (VectorTransform * ltrans);
32
+ void prepend_transform(VectorTransform* ltrans);
36
33
 
37
34
  void train(idx_t n, const float* x) override;
38
35
 
@@ -47,47 +44,47 @@ struct IndexPreTransform: Index {
47
44
  size_t remove_ids(const IDSelector& sel) override;
48
45
 
49
46
  void search(
50
- idx_t n,
51
- const float* x,
52
- idx_t k,
53
- float* distances,
54
- idx_t* labels) const override;
55
-
47
+ idx_t n,
48
+ const float* x,
49
+ idx_t k,
50
+ float* distances,
51
+ idx_t* labels) const override;
56
52
 
57
53
  /* range search, no attempt is done to change the radius */
58
- void range_search (idx_t n, const float* x, float radius,
59
- RangeSearchResult* result) const override;
54
+ void range_search(
55
+ idx_t n,
56
+ const float* x,
57
+ float radius,
58
+ RangeSearchResult* result) const override;
60
59
 
60
+ void reconstruct(idx_t key, float* recons) const override;
61
61
 
62
- void reconstruct (idx_t key, float * recons) const override;
62
+ void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
63
63
 
64
- void reconstruct_n (idx_t i0, idx_t ni, float *recons)
65
- const override;
66
-
67
- void search_and_reconstruct (idx_t n, const float *x, idx_t k,
68
- float *distances, idx_t *labels,
69
- float *recons) const override;
64
+ void search_and_reconstruct(
65
+ idx_t n,
66
+ const float* x,
67
+ idx_t k,
68
+ float* distances,
69
+ idx_t* labels,
70
+ float* recons) const override;
70
71
 
71
72
  /// apply the transforms in the chain. The returned float * may be
72
73
  /// equal to x, otherwise it should be deallocated.
73
- const float * apply_chain (idx_t n, const float *x) const;
74
+ const float* apply_chain(idx_t n, const float* x) const;
74
75
 
75
76
  /// Reverse the transforms in the chain. May not be implemented for
76
77
  /// all transforms in the chain or may return approximate results.
77
- void reverse_chain (idx_t n, const float* xt, float* x) const;
78
-
78
+ void reverse_chain(idx_t n, const float* xt, float* x) const;
79
79
 
80
- DistanceComputer * get_distance_computer() const override;
80
+ DistanceComputer* get_distance_computer() const override;
81
81
 
82
82
  /* standalone codec interface */
83
- size_t sa_code_size () const override;
84
- void sa_encode (idx_t n, const float *x,
85
- uint8_t *bytes) const override;
86
- void sa_decode (idx_t n, const uint8_t *bytes,
87
- float *x) const override;
83
+ size_t sa_code_size() const override;
84
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
85
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
88
86
 
89
87
  ~IndexPreTransform() override;
90
88
  };
91
89
 
92
-
93
90
  } // namespace faiss
@@ -5,63 +5,58 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
8
  #include <faiss/IndexRefine.h>
10
9
 
10
+ #include <faiss/IndexFlat.h>
11
+ #include <faiss/impl/AuxIndexStructures.h>
12
+ #include <faiss/impl/FaissAssert.h>
13
+ #include <faiss/utils/Heap.h>
11
14
  #include <faiss/utils/distances.h>
12
15
  #include <faiss/utils/utils.h>
13
- #include <faiss/utils/Heap.h>
14
- #include <faiss/impl/FaissAssert.h>
15
- #include <faiss/impl/AuxIndexStructures.h>
16
- #include <faiss/IndexFlat.h>
17
16
 
18
17
  namespace faiss {
19
18
 
20
-
21
-
22
19
  /***************************************************
23
20
  * IndexRefine
24
21
  ***************************************************/
25
22
 
26
- IndexRefine::IndexRefine (Index *base_index, Index *refine_index):
27
- Index (base_index->d, base_index->metric_type),
28
- base_index (base_index),
29
- refine_index (refine_index)
30
- {
23
+ IndexRefine::IndexRefine(Index* base_index, Index* refine_index)
24
+ : Index(base_index->d, base_index->metric_type),
25
+ base_index(base_index),
26
+ refine_index(refine_index) {
31
27
  own_fields = own_refine_index = false;
32
28
  if (refine_index != nullptr) {
33
- FAISS_THROW_IF_NOT (base_index->d == refine_index->d);
34
- FAISS_THROW_IF_NOT (base_index->metric_type == refine_index->metric_type);
29
+ FAISS_THROW_IF_NOT(base_index->d == refine_index->d);
30
+ FAISS_THROW_IF_NOT(
31
+ base_index->metric_type == refine_index->metric_type);
35
32
  is_trained = base_index->is_trained && refine_index->is_trained;
36
- FAISS_THROW_IF_NOT (base_index->ntotal == refine_index->ntotal);
33
+ FAISS_THROW_IF_NOT(base_index->ntotal == refine_index->ntotal);
37
34
  } // other case is useful only to construct an IndexRefineFlat
38
35
  ntotal = base_index->ntotal;
39
36
  }
40
37
 
41
- IndexRefine::IndexRefine ():
42
- base_index(nullptr), refine_index(nullptr),
43
- own_fields(false), own_refine_index(false)
44
- {
45
- }
38
+ IndexRefine::IndexRefine()
39
+ : base_index(nullptr),
40
+ refine_index(nullptr),
41
+ own_fields(false),
42
+ own_refine_index(false) {}
46
43
 
47
- void IndexRefine::train (idx_t n, const float *x)
48
- {
49
- base_index->train (n, x);
50
- refine_index->train (n, x);
44
+ void IndexRefine::train(idx_t n, const float* x) {
45
+ base_index->train(n, x);
46
+ refine_index->train(n, x);
51
47
  is_trained = true;
52
48
  }
53
49
 
54
- void IndexRefine::add (idx_t n, const float *x) {
55
- FAISS_THROW_IF_NOT (is_trained);
56
- base_index->add (n, x);
57
- refine_index->add (n, x);
50
+ void IndexRefine::add(idx_t n, const float* x) {
51
+ FAISS_THROW_IF_NOT(is_trained);
52
+ base_index->add(n, x);
53
+ refine_index->add(n, x);
58
54
  ntotal = refine_index->ntotal;
59
55
  }
60
56
 
61
- void IndexRefine::reset ()
62
- {
63
- base_index->reset ();
64
- refine_index->reset ();
57
+ void IndexRefine::reset() {
58
+ base_index->reset();
59
+ refine_index->reset();
65
60
  ntotal = 0;
66
61
  }
67
62
 
@@ -69,69 +64,72 @@ namespace {
69
64
 
70
65
  typedef faiss::Index::idx_t idx_t;
71
66
 
72
- template<class C>
73
- static void reorder_2_heaps (
74
- idx_t n,
75
- idx_t k, idx_t *labels, float *distances,
76
- idx_t k_base, const idx_t *base_labels, const float *base_distances)
77
- {
67
+ template <class C>
68
+ static void reorder_2_heaps(
69
+ idx_t n,
70
+ idx_t k,
71
+ idx_t* labels,
72
+ float* distances,
73
+ idx_t k_base,
74
+ const idx_t* base_labels,
75
+ const float* base_distances) {
78
76
  #pragma omp parallel for
79
77
  for (idx_t i = 0; i < n; i++) {
80
- idx_t *idxo = labels + i * k;
81
- float *diso = distances + i * k;
82
- const idx_t *idxi = base_labels + i * k_base;
83
- const float *disi = base_distances + i * k_base;
78
+ idx_t* idxo = labels + i * k;
79
+ float* diso = distances + i * k;
80
+ const idx_t* idxi = base_labels + i * k_base;
81
+ const float* disi = base_distances + i * k_base;
84
82
 
85
- heap_heapify<C> (k, diso, idxo, disi, idxi, k);
83
+ heap_heapify<C>(k, diso, idxo, disi, idxi, k);
86
84
  if (k_base != k) { // add remaining elements
87
- heap_addn<C> (k, diso, idxo, disi + k, idxi + k, k_base - k);
85
+ heap_addn<C>(k, diso, idxo, disi + k, idxi + k, k_base - k);
88
86
  }
89
- heap_reorder<C> (k, diso, idxo);
87
+ heap_reorder<C>(k, diso, idxo);
90
88
  }
91
89
  }
92
90
 
93
-
94
91
  } // anonymous namespace
95
92
 
96
-
97
-
98
- void IndexRefine::search (
99
- idx_t n, const float *x, idx_t k,
100
- float *distances, idx_t *labels) const
101
- {
102
- FAISS_THROW_IF_NOT (is_trained);
103
- idx_t k_base = idx_t (k * k_factor);
104
- idx_t * base_labels = labels;
105
- float * base_distances = distances;
93
+ void IndexRefine::search(
94
+ idx_t n,
95
+ const float* x,
96
+ idx_t k,
97
+ float* distances,
98
+ idx_t* labels) const {
99
+ FAISS_THROW_IF_NOT(k > 0);
100
+
101
+ FAISS_THROW_IF_NOT(is_trained);
102
+ idx_t k_base = idx_t(k * k_factor);
103
+ idx_t* base_labels = labels;
104
+ float* base_distances = distances;
106
105
  ScopeDeleter<idx_t> del1;
107
106
  ScopeDeleter<float> del2;
108
107
 
109
108
  if (k != k_base) {
110
- base_labels = new idx_t [n * k_base];
111
- del1.set (base_labels);
112
- base_distances = new float [n * k_base];
113
- del2.set (base_distances);
109
+ base_labels = new idx_t[n * k_base];
110
+ del1.set(base_labels);
111
+ base_distances = new float[n * k_base];
112
+ del2.set(base_distances);
114
113
  }
115
114
 
116
- base_index->search (n, x, k_base, base_distances, base_labels);
115
+ base_index->search(n, x, k_base, base_distances, base_labels);
117
116
 
118
117
  for (int i = 0; i < n * k_base; i++)
119
- assert (base_labels[i] >= -1 &&
120
- base_labels[i] < ntotal);
118
+ assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
121
119
 
122
- // parallelize over queries
120
+ // parallelize over queries
123
121
  #pragma omp parallel if (n > 1)
124
122
  {
125
123
  std::unique_ptr<DistanceComputer> dc(
126
- refine_index->get_distance_computer()
127
- );
124
+ refine_index->get_distance_computer());
128
125
  #pragma omp for
129
126
  for (idx_t i = 0; i < n; i++) {
130
127
  dc->set_query(x + i * d);
131
128
  idx_t ij = i * k_base;
132
129
  for (idx_t j = 0; j < k_base; j++) {
133
130
  idx_t idx = base_labels[ij];
134
- if (idx < 0) break;
131
+ if (idx < 0)
132
+ break;
135
133
  base_distances[ij] = (*dc)(idx);
136
134
  ij++;
137
135
  }
@@ -140,117 +138,103 @@ void IndexRefine::search (
140
138
 
141
139
  // sort and store result
142
140
  if (metric_type == METRIC_L2) {
143
- typedef CMax <float, idx_t> C;
144
- reorder_2_heaps<C> (
145
- n, k, labels, distances,
146
- k_base, base_labels, base_distances);
141
+ typedef CMax<float, idx_t> C;
142
+ reorder_2_heaps<C>(
143
+ n, k, labels, distances, k_base, base_labels, base_distances);
147
144
 
148
145
  } else if (metric_type == METRIC_INNER_PRODUCT) {
149
- typedef CMin <float, idx_t> C;
150
- reorder_2_heaps<C> (
151
- n, k, labels, distances,
152
- k_base, base_labels, base_distances);
146
+ typedef CMin<float, idx_t> C;
147
+ reorder_2_heaps<C>(
148
+ n, k, labels, distances, k_base, base_labels, base_distances);
153
149
  } else {
154
150
  FAISS_THROW_MSG("Metric type not supported");
155
151
  }
156
-
157
152
  }
158
153
 
159
- void IndexRefine::reconstruct (idx_t key, float * recons) const {
160
- refine_index->reconstruct (key, recons);
154
+ void IndexRefine::reconstruct(idx_t key, float* recons) const {
155
+ refine_index->reconstruct(key, recons);
161
156
  }
162
157
 
163
-
164
-
165
-
166
- IndexRefine::~IndexRefine ()
167
- {
168
- if (own_fields) delete base_index;
169
- if (own_refine_index) delete refine_index;
158
+ IndexRefine::~IndexRefine() {
159
+ if (own_fields)
160
+ delete base_index;
161
+ if (own_refine_index)
162
+ delete refine_index;
170
163
  }
171
164
 
172
-
173
165
  /***************************************************
174
166
  * IndexRefineFlat
175
167
  ***************************************************/
176
168
 
177
- IndexRefineFlat::IndexRefineFlat (Index *base_index):
178
- IndexRefine(base_index, new IndexFlat(base_index->d, base_index->metric_type))
179
- {
169
+ IndexRefineFlat::IndexRefineFlat(Index* base_index)
170
+ : IndexRefine(
171
+ base_index,
172
+ new IndexFlat(base_index->d, base_index->metric_type)) {
180
173
  is_trained = base_index->is_trained;
181
174
  own_refine_index = true;
182
- FAISS_THROW_IF_NOT_MSG (base_index->ntotal == 0,
183
- "base_index should be empty in the beginning");
175
+ FAISS_THROW_IF_NOT_MSG(
176
+ base_index->ntotal == 0,
177
+ "base_index should be empty in the beginning");
184
178
  }
185
179
 
186
-
187
- IndexRefineFlat::IndexRefineFlat (Index *base_index, const float *xb):
188
- IndexRefine (base_index, nullptr)
189
- {
180
+ IndexRefineFlat::IndexRefineFlat(Index* base_index, const float* xb)
181
+ : IndexRefine(base_index, nullptr) {
190
182
  is_trained = base_index->is_trained;
191
183
  refine_index = new IndexFlat(base_index->d, base_index->metric_type);
192
184
  own_refine_index = true;
193
- refine_index->add (base_index->ntotal, xb);
194
-
185
+ refine_index->add(base_index->ntotal, xb);
195
186
  }
196
187
 
197
- IndexRefineFlat::IndexRefineFlat():
198
- IndexRefine()
199
- {
188
+ IndexRefineFlat::IndexRefineFlat() : IndexRefine() {
200
189
  own_refine_index = true;
201
190
  }
202
191
 
203
-
204
- void IndexRefineFlat::search (
205
- idx_t n, const float *x, idx_t k,
206
- float *distances, idx_t *labels) const
207
- {
208
- FAISS_THROW_IF_NOT (is_trained);
209
- idx_t k_base = idx_t (k * k_factor);
210
- idx_t * base_labels = labels;
211
- float * base_distances = distances;
192
+ void IndexRefineFlat::search(
193
+ idx_t n,
194
+ const float* x,
195
+ idx_t k,
196
+ float* distances,
197
+ idx_t* labels) const {
198
+ FAISS_THROW_IF_NOT(k > 0);
199
+
200
+ FAISS_THROW_IF_NOT(is_trained);
201
+ idx_t k_base = idx_t(k * k_factor);
202
+ idx_t* base_labels = labels;
203
+ float* base_distances = distances;
212
204
  ScopeDeleter<idx_t> del1;
213
205
  ScopeDeleter<float> del2;
214
206
 
215
207
  if (k != k_base) {
216
- base_labels = new idx_t [n * k_base];
217
- del1.set (base_labels);
218
- base_distances = new float [n * k_base];
219
- del2.set (base_distances);
208
+ base_labels = new idx_t[n * k_base];
209
+ del1.set(base_labels);
210
+ base_distances = new float[n * k_base];
211
+ del2.set(base_distances);
220
212
  }
221
213
 
222
- base_index->search (n, x, k_base, base_distances, base_labels);
214
+ base_index->search(n, x, k_base, base_distances, base_labels);
223
215
 
224
216
  for (int i = 0; i < n * k_base; i++)
225
- assert (base_labels[i] >= -1 &&
226
- base_labels[i] < ntotal);
217
+ assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
227
218
 
228
219
  // compute refined distances
229
- auto rf = dynamic_cast<const IndexFlat *>(refine_index);
220
+ auto rf = dynamic_cast<const IndexFlat*>(refine_index);
230
221
  FAISS_THROW_IF_NOT(rf);
231
222
 
232
- rf->compute_distance_subset (
233
- n, x, k_base, base_distances, base_labels);
223
+ rf->compute_distance_subset(n, x, k_base, base_distances, base_labels);
234
224
 
235
225
  // sort and store result
236
226
  if (metric_type == METRIC_L2) {
237
- typedef CMax <float, idx_t> C;
238
- reorder_2_heaps<C> (
239
- n, k, labels, distances,
240
- k_base, base_labels, base_distances);
227
+ typedef CMax<float, idx_t> C;
228
+ reorder_2_heaps<C>(
229
+ n, k, labels, distances, k_base, base_labels, base_distances);
241
230
 
242
231
  } else if (metric_type == METRIC_INNER_PRODUCT) {
243
- typedef CMin <float, idx_t> C;
244
- reorder_2_heaps<C> (
245
- n, k, labels, distances,
246
- k_base, base_labels, base_distances);
232
+ typedef CMin<float, idx_t> C;
233
+ reorder_2_heaps<C>(
234
+ n, k, labels, distances, k_base, base_labels, base_distances);
247
235
  } else {
248
236
  FAISS_THROW_MSG("Metric type not supported");
249
237
  }
250
-
251
238
  }
252
239
 
253
-
254
-
255
-
256
240
  } // namespace faiss