faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -9,8 +9,6 @@
9
9
 
10
10
  #pragma once
11
11
 
12
-
13
-
14
12
  #include <faiss/Index.h>
15
13
  #include <faiss/VectorTransform.h>
16
14
 
@@ -18,21 +16,20 @@ namespace faiss {
18
16
 
19
17
  /** Index that applies a LinearTransform transform on vectors before
20
18
  * handing them over to a sub-index */
21
- struct IndexPreTransform: Index {
19
+ struct IndexPreTransform : Index {
20
+ std::vector<VectorTransform*> chain; ///! chain of tranforms
21
+ Index* index; ///! the sub-index
22
22
 
23
- std::vector<VectorTransform *> chain; ///! chain of tranforms
24
- Index * index; ///! the sub-index
23
+ bool own_fields; ///! whether pointers are deleted in destructor
25
24
 
26
- bool own_fields; ///! whether pointers are deleted in destructor
25
+ explicit IndexPreTransform(Index* index);
27
26
 
28
- explicit IndexPreTransform (Index *index);
29
-
30
- IndexPreTransform ();
27
+ IndexPreTransform();
31
28
 
32
29
  /// ltrans is the last transform before the index
33
- IndexPreTransform (VectorTransform * ltrans, Index * index);
30
+ IndexPreTransform(VectorTransform* ltrans, Index* index);
34
31
 
35
- void prepend_transform (VectorTransform * ltrans);
32
+ void prepend_transform(VectorTransform* ltrans);
36
33
 
37
34
  void train(idx_t n, const float* x) override;
38
35
 
@@ -47,47 +44,47 @@ struct IndexPreTransform: Index {
47
44
  size_t remove_ids(const IDSelector& sel) override;
48
45
 
49
46
  void search(
50
- idx_t n,
51
- const float* x,
52
- idx_t k,
53
- float* distances,
54
- idx_t* labels) const override;
55
-
47
+ idx_t n,
48
+ const float* x,
49
+ idx_t k,
50
+ float* distances,
51
+ idx_t* labels) const override;
56
52
 
57
53
  /* range search, no attempt is done to change the radius */
58
- void range_search (idx_t n, const float* x, float radius,
59
- RangeSearchResult* result) const override;
54
+ void range_search(
55
+ idx_t n,
56
+ const float* x,
57
+ float radius,
58
+ RangeSearchResult* result) const override;
60
59
 
60
+ void reconstruct(idx_t key, float* recons) const override;
61
61
 
62
- void reconstruct (idx_t key, float * recons) const override;
62
+ void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
63
63
 
64
- void reconstruct_n (idx_t i0, idx_t ni, float *recons)
65
- const override;
66
-
67
- void search_and_reconstruct (idx_t n, const float *x, idx_t k,
68
- float *distances, idx_t *labels,
69
- float *recons) const override;
64
+ void search_and_reconstruct(
65
+ idx_t n,
66
+ const float* x,
67
+ idx_t k,
68
+ float* distances,
69
+ idx_t* labels,
70
+ float* recons) const override;
70
71
 
71
72
  /// apply the transforms in the chain. The returned float * may be
72
73
  /// equal to x, otherwise it should be deallocated.
73
- const float * apply_chain (idx_t n, const float *x) const;
74
+ const float* apply_chain(idx_t n, const float* x) const;
74
75
 
75
76
  /// Reverse the transforms in the chain. May not be implemented for
76
77
  /// all transforms in the chain or may return approximate results.
77
- void reverse_chain (idx_t n, const float* xt, float* x) const;
78
-
78
+ void reverse_chain(idx_t n, const float* xt, float* x) const;
79
79
 
80
- DistanceComputer * get_distance_computer() const override;
80
+ DistanceComputer* get_distance_computer() const override;
81
81
 
82
82
  /* standalone codec interface */
83
- size_t sa_code_size () const override;
84
- void sa_encode (idx_t n, const float *x,
85
- uint8_t *bytes) const override;
86
- void sa_decode (idx_t n, const uint8_t *bytes,
87
- float *x) const override;
83
+ size_t sa_code_size() const override;
84
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
85
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
88
86
 
89
87
  ~IndexPreTransform() override;
90
88
  };
91
89
 
92
-
93
90
  } // namespace faiss
@@ -5,63 +5,58 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
8
  #include <faiss/IndexRefine.h>
10
9
 
10
+ #include <faiss/IndexFlat.h>
11
+ #include <faiss/impl/AuxIndexStructures.h>
12
+ #include <faiss/impl/FaissAssert.h>
13
+ #include <faiss/utils/Heap.h>
11
14
  #include <faiss/utils/distances.h>
12
15
  #include <faiss/utils/utils.h>
13
- #include <faiss/utils/Heap.h>
14
- #include <faiss/impl/FaissAssert.h>
15
- #include <faiss/impl/AuxIndexStructures.h>
16
- #include <faiss/IndexFlat.h>
17
16
 
18
17
  namespace faiss {
19
18
 
20
-
21
-
22
19
  /***************************************************
23
20
  * IndexRefine
24
21
  ***************************************************/
25
22
 
26
- IndexRefine::IndexRefine (Index *base_index, Index *refine_index):
27
- Index (base_index->d, base_index->metric_type),
28
- base_index (base_index),
29
- refine_index (refine_index)
30
- {
23
+ IndexRefine::IndexRefine(Index* base_index, Index* refine_index)
24
+ : Index(base_index->d, base_index->metric_type),
25
+ base_index(base_index),
26
+ refine_index(refine_index) {
31
27
  own_fields = own_refine_index = false;
32
28
  if (refine_index != nullptr) {
33
- FAISS_THROW_IF_NOT (base_index->d == refine_index->d);
34
- FAISS_THROW_IF_NOT (base_index->metric_type == refine_index->metric_type);
29
+ FAISS_THROW_IF_NOT(base_index->d == refine_index->d);
30
+ FAISS_THROW_IF_NOT(
31
+ base_index->metric_type == refine_index->metric_type);
35
32
  is_trained = base_index->is_trained && refine_index->is_trained;
36
- FAISS_THROW_IF_NOT (base_index->ntotal == refine_index->ntotal);
33
+ FAISS_THROW_IF_NOT(base_index->ntotal == refine_index->ntotal);
37
34
  } // other case is useful only to construct an IndexRefineFlat
38
35
  ntotal = base_index->ntotal;
39
36
  }
40
37
 
41
- IndexRefine::IndexRefine ():
42
- base_index(nullptr), refine_index(nullptr),
43
- own_fields(false), own_refine_index(false)
44
- {
45
- }
38
+ IndexRefine::IndexRefine()
39
+ : base_index(nullptr),
40
+ refine_index(nullptr),
41
+ own_fields(false),
42
+ own_refine_index(false) {}
46
43
 
47
- void IndexRefine::train (idx_t n, const float *x)
48
- {
49
- base_index->train (n, x);
50
- refine_index->train (n, x);
44
+ void IndexRefine::train(idx_t n, const float* x) {
45
+ base_index->train(n, x);
46
+ refine_index->train(n, x);
51
47
  is_trained = true;
52
48
  }
53
49
 
54
- void IndexRefine::add (idx_t n, const float *x) {
55
- FAISS_THROW_IF_NOT (is_trained);
56
- base_index->add (n, x);
57
- refine_index->add (n, x);
50
+ void IndexRefine::add(idx_t n, const float* x) {
51
+ FAISS_THROW_IF_NOT(is_trained);
52
+ base_index->add(n, x);
53
+ refine_index->add(n, x);
58
54
  ntotal = refine_index->ntotal;
59
55
  }
60
56
 
61
- void IndexRefine::reset ()
62
- {
63
- base_index->reset ();
64
- refine_index->reset ();
57
+ void IndexRefine::reset() {
58
+ base_index->reset();
59
+ refine_index->reset();
65
60
  ntotal = 0;
66
61
  }
67
62
 
@@ -69,69 +64,72 @@ namespace {
69
64
 
70
65
  typedef faiss::Index::idx_t idx_t;
71
66
 
72
- template<class C>
73
- static void reorder_2_heaps (
74
- idx_t n,
75
- idx_t k, idx_t *labels, float *distances,
76
- idx_t k_base, const idx_t *base_labels, const float *base_distances)
77
- {
67
+ template <class C>
68
+ static void reorder_2_heaps(
69
+ idx_t n,
70
+ idx_t k,
71
+ idx_t* labels,
72
+ float* distances,
73
+ idx_t k_base,
74
+ const idx_t* base_labels,
75
+ const float* base_distances) {
78
76
  #pragma omp parallel for
79
77
  for (idx_t i = 0; i < n; i++) {
80
- idx_t *idxo = labels + i * k;
81
- float *diso = distances + i * k;
82
- const idx_t *idxi = base_labels + i * k_base;
83
- const float *disi = base_distances + i * k_base;
78
+ idx_t* idxo = labels + i * k;
79
+ float* diso = distances + i * k;
80
+ const idx_t* idxi = base_labels + i * k_base;
81
+ const float* disi = base_distances + i * k_base;
84
82
 
85
- heap_heapify<C> (k, diso, idxo, disi, idxi, k);
83
+ heap_heapify<C>(k, diso, idxo, disi, idxi, k);
86
84
  if (k_base != k) { // add remaining elements
87
- heap_addn<C> (k, diso, idxo, disi + k, idxi + k, k_base - k);
85
+ heap_addn<C>(k, diso, idxo, disi + k, idxi + k, k_base - k);
88
86
  }
89
- heap_reorder<C> (k, diso, idxo);
87
+ heap_reorder<C>(k, diso, idxo);
90
88
  }
91
89
  }
92
90
 
93
-
94
91
  } // anonymous namespace
95
92
 
96
-
97
-
98
- void IndexRefine::search (
99
- idx_t n, const float *x, idx_t k,
100
- float *distances, idx_t *labels) const
101
- {
102
- FAISS_THROW_IF_NOT (is_trained);
103
- idx_t k_base = idx_t (k * k_factor);
104
- idx_t * base_labels = labels;
105
- float * base_distances = distances;
93
+ void IndexRefine::search(
94
+ idx_t n,
95
+ const float* x,
96
+ idx_t k,
97
+ float* distances,
98
+ idx_t* labels) const {
99
+ FAISS_THROW_IF_NOT(k > 0);
100
+
101
+ FAISS_THROW_IF_NOT(is_trained);
102
+ idx_t k_base = idx_t(k * k_factor);
103
+ idx_t* base_labels = labels;
104
+ float* base_distances = distances;
106
105
  ScopeDeleter<idx_t> del1;
107
106
  ScopeDeleter<float> del2;
108
107
 
109
108
  if (k != k_base) {
110
- base_labels = new idx_t [n * k_base];
111
- del1.set (base_labels);
112
- base_distances = new float [n * k_base];
113
- del2.set (base_distances);
109
+ base_labels = new idx_t[n * k_base];
110
+ del1.set(base_labels);
111
+ base_distances = new float[n * k_base];
112
+ del2.set(base_distances);
114
113
  }
115
114
 
116
- base_index->search (n, x, k_base, base_distances, base_labels);
115
+ base_index->search(n, x, k_base, base_distances, base_labels);
117
116
 
118
117
  for (int i = 0; i < n * k_base; i++)
119
- assert (base_labels[i] >= -1 &&
120
- base_labels[i] < ntotal);
118
+ assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
121
119
 
122
- // parallelize over queries
120
+ // parallelize over queries
123
121
  #pragma omp parallel if (n > 1)
124
122
  {
125
123
  std::unique_ptr<DistanceComputer> dc(
126
- refine_index->get_distance_computer()
127
- );
124
+ refine_index->get_distance_computer());
128
125
  #pragma omp for
129
126
  for (idx_t i = 0; i < n; i++) {
130
127
  dc->set_query(x + i * d);
131
128
  idx_t ij = i * k_base;
132
129
  for (idx_t j = 0; j < k_base; j++) {
133
130
  idx_t idx = base_labels[ij];
134
- if (idx < 0) break;
131
+ if (idx < 0)
132
+ break;
135
133
  base_distances[ij] = (*dc)(idx);
136
134
  ij++;
137
135
  }
@@ -140,117 +138,131 @@ void IndexRefine::search (
140
138
 
141
139
  // sort and store result
142
140
  if (metric_type == METRIC_L2) {
143
- typedef CMax <float, idx_t> C;
144
- reorder_2_heaps<C> (
145
- n, k, labels, distances,
146
- k_base, base_labels, base_distances);
141
+ typedef CMax<float, idx_t> C;
142
+ reorder_2_heaps<C>(
143
+ n, k, labels, distances, k_base, base_labels, base_distances);
147
144
 
148
145
  } else if (metric_type == METRIC_INNER_PRODUCT) {
149
- typedef CMin <float, idx_t> C;
150
- reorder_2_heaps<C> (
151
- n, k, labels, distances,
152
- k_base, base_labels, base_distances);
146
+ typedef CMin<float, idx_t> C;
147
+ reorder_2_heaps<C>(
148
+ n, k, labels, distances, k_base, base_labels, base_distances);
153
149
  } else {
154
150
  FAISS_THROW_MSG("Metric type not supported");
155
151
  }
156
-
157
152
  }
158
153
 
159
- void IndexRefine::reconstruct (idx_t key, float * recons) const {
160
- refine_index->reconstruct (key, recons);
154
+ void IndexRefine::reconstruct(idx_t key, float* recons) const {
155
+ refine_index->reconstruct(key, recons);
161
156
  }
162
157
 
158
+ size_t IndexRefine::sa_code_size() const {
159
+ return base_index->sa_code_size() + refine_index->sa_code_size();
160
+ }
163
161
 
162
+ void IndexRefine::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
163
+ size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size();
164
+ std::unique_ptr<uint8_t[]> tmp1(new uint8_t[n * cs1]);
165
+ base_index->sa_encode(n, x, tmp1.get());
166
+ std::unique_ptr<uint8_t[]> tmp2(new uint8_t[n * cs2]);
167
+ refine_index->sa_encode(n, x, tmp2.get());
168
+ for (size_t i = 0; i < n; i++) {
169
+ uint8_t* b = bytes + i * (cs1 + cs2);
170
+ memcpy(b, tmp1.get() + cs1 * i, cs1);
171
+ memcpy(b + cs1, tmp2.get() + cs2 * i, cs2);
172
+ }
173
+ }
164
174
 
175
+ void IndexRefine::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
176
+ size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size();
177
+ std::unique_ptr<uint8_t[]> tmp2(
178
+ new uint8_t[n * refine_index->sa_code_size()]);
179
+ for (size_t i = 0; i < n; i++) {
180
+ memcpy(tmp2.get() + i * cs2, bytes + i * (cs1 + cs2), cs2);
181
+ }
165
182
 
166
- IndexRefine::~IndexRefine ()
167
- {
168
- if (own_fields) delete base_index;
169
- if (own_refine_index) delete refine_index;
183
+ refine_index->sa_decode(n, tmp2.get(), x);
170
184
  }
171
185
 
186
+ IndexRefine::~IndexRefine() {
187
+ if (own_fields)
188
+ delete base_index;
189
+ if (own_refine_index)
190
+ delete refine_index;
191
+ }
172
192
 
173
193
  /***************************************************
174
194
  * IndexRefineFlat
175
195
  ***************************************************/
176
196
 
177
- IndexRefineFlat::IndexRefineFlat (Index *base_index):
178
- IndexRefine(base_index, new IndexFlat(base_index->d, base_index->metric_type))
179
- {
197
+ IndexRefineFlat::IndexRefineFlat(Index* base_index)
198
+ : IndexRefine(
199
+ base_index,
200
+ new IndexFlat(base_index->d, base_index->metric_type)) {
180
201
  is_trained = base_index->is_trained;
181
202
  own_refine_index = true;
182
- FAISS_THROW_IF_NOT_MSG (base_index->ntotal == 0,
183
- "base_index should be empty in the beginning");
203
+ FAISS_THROW_IF_NOT_MSG(
204
+ base_index->ntotal == 0,
205
+ "base_index should be empty in the beginning");
184
206
  }
185
207
 
186
-
187
- IndexRefineFlat::IndexRefineFlat (Index *base_index, const float *xb):
188
- IndexRefine (base_index, nullptr)
189
- {
208
+ IndexRefineFlat::IndexRefineFlat(Index* base_index, const float* xb)
209
+ : IndexRefine(base_index, nullptr) {
190
210
  is_trained = base_index->is_trained;
191
211
  refine_index = new IndexFlat(base_index->d, base_index->metric_type);
192
212
  own_refine_index = true;
193
- refine_index->add (base_index->ntotal, xb);
194
-
213
+ refine_index->add(base_index->ntotal, xb);
195
214
  }
196
215
 
197
- IndexRefineFlat::IndexRefineFlat():
198
- IndexRefine()
199
- {
216
+ IndexRefineFlat::IndexRefineFlat() : IndexRefine() {
200
217
  own_refine_index = true;
201
218
  }
202
219
 
203
-
204
- void IndexRefineFlat::search (
205
- idx_t n, const float *x, idx_t k,
206
- float *distances, idx_t *labels) const
207
- {
208
- FAISS_THROW_IF_NOT (is_trained);
209
- idx_t k_base = idx_t (k * k_factor);
210
- idx_t * base_labels = labels;
211
- float * base_distances = distances;
220
+ void IndexRefineFlat::search(
221
+ idx_t n,
222
+ const float* x,
223
+ idx_t k,
224
+ float* distances,
225
+ idx_t* labels) const {
226
+ FAISS_THROW_IF_NOT(k > 0);
227
+
228
+ FAISS_THROW_IF_NOT(is_trained);
229
+ idx_t k_base = idx_t(k * k_factor);
230
+ idx_t* base_labels = labels;
231
+ float* base_distances = distances;
212
232
  ScopeDeleter<idx_t> del1;
213
233
  ScopeDeleter<float> del2;
214
234
 
215
235
  if (k != k_base) {
216
- base_labels = new idx_t [n * k_base];
217
- del1.set (base_labels);
218
- base_distances = new float [n * k_base];
219
- del2.set (base_distances);
236
+ base_labels = new idx_t[n * k_base];
237
+ del1.set(base_labels);
238
+ base_distances = new float[n * k_base];
239
+ del2.set(base_distances);
220
240
  }
221
241
 
222
- base_index->search (n, x, k_base, base_distances, base_labels);
242
+ base_index->search(n, x, k_base, base_distances, base_labels);
223
243
 
224
244
  for (int i = 0; i < n * k_base; i++)
225
- assert (base_labels[i] >= -1 &&
226
- base_labels[i] < ntotal);
245
+ assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
227
246
 
228
247
  // compute refined distances
229
- auto rf = dynamic_cast<const IndexFlat *>(refine_index);
248
+ auto rf = dynamic_cast<const IndexFlat*>(refine_index);
230
249
  FAISS_THROW_IF_NOT(rf);
231
250
 
232
- rf->compute_distance_subset (
233
- n, x, k_base, base_distances, base_labels);
251
+ rf->compute_distance_subset(n, x, k_base, base_distances, base_labels);
234
252
 
235
253
  // sort and store result
236
254
  if (metric_type == METRIC_L2) {
237
- typedef CMax <float, idx_t> C;
238
- reorder_2_heaps<C> (
239
- n, k, labels, distances,
240
- k_base, base_labels, base_distances);
255
+ typedef CMax<float, idx_t> C;
256
+ reorder_2_heaps<C>(
257
+ n, k, labels, distances, k_base, base_labels, base_distances);
241
258
 
242
259
  } else if (metric_type == METRIC_INNER_PRODUCT) {
243
- typedef CMin <float, idx_t> C;
244
- reorder_2_heaps<C> (
245
- n, k, labels, distances,
246
- k_base, base_labels, base_distances);
260
+ typedef CMin<float, idx_t> C;
261
+ reorder_2_heaps<C>(
262
+ n, k, labels, distances, k_base, base_labels, base_distances);
247
263
  } else {
248
264
  FAISS_THROW_MSG("Metric type not supported");
249
265
  }
250
-
251
266
  }
252
267
 
253
-
254
-
255
-
256
268
  } // namespace faiss
@@ -9,32 +9,29 @@
9
9
 
10
10
  #include <faiss/Index.h>
11
11
 
12
-
13
12
  namespace faiss {
14
13
 
15
-
16
14
  /** Index that queries in a base_index (a fast one) and refines the
17
15
  * results with an exact search, hopefully improving the results.
18
16
  */
19
- struct IndexRefine: Index {
20
-
17
+ struct IndexRefine : Index {
21
18
  /// faster index to pre-select the vectors that should be filtered
22
- Index *base_index;
19
+ Index* base_index;
23
20
 
24
21
  /// refinement index
25
- Index *refine_index;
22
+ Index* refine_index;
26
23
 
27
- bool own_fields; ///< should the base index be deallocated?
28
- bool own_refine_index; ///< same with the refinement index
24
+ bool own_fields; ///< should the base index be deallocated?
25
+ bool own_refine_index; ///< same with the refinement index
29
26
 
30
27
  /// factor between k requested in search and the k requested from
31
28
  /// the base_index (should be >= 1)
32
29
  float k_factor = 1;
33
30
 
34
- /// intitialize from empty index
35
- IndexRefine (Index *base_index, Index *refine_index);
31
+ /// initialize from empty index
32
+ IndexRefine(Index* base_index, Index* refine_index);
36
33
 
37
- IndexRefine ();
34
+ IndexRefine();
38
35
 
39
36
  void train(idx_t n, const float* x) override;
40
37
 
@@ -43,31 +40,43 @@ struct IndexRefine: Index {
43
40
  void reset() override;
44
41
 
45
42
  void search(
46
- idx_t n, const float* x, idx_t k,
47
- float* distances, idx_t* labels) const override;
43
+ idx_t n,
44
+ const float* x,
45
+ idx_t k,
46
+ float* distances,
47
+ idx_t* labels) const override;
48
48
 
49
49
  // reconstruct is routed to the refine_index
50
- void reconstruct (idx_t key, float * recons) const override;
50
+ void reconstruct(idx_t key, float* recons) const override;
51
+
52
+ /* standalone codec interface: the base_index codes are interleaved with the
53
+ * refine_index ones */
54
+ size_t sa_code_size() const override;
55
+
56
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
57
+
58
+ /// The sa_decode decodes from the index_refine, which is assumed to be more
59
+ /// accurate
60
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
51
61
 
52
62
  ~IndexRefine() override;
53
63
  };
54
64
 
55
-
56
65
  /** Version where the refinement index is an IndexFlat. It has one additional
57
66
  * constructor that takes a table of elements to add to the flat refinement
58
67
  * index */
59
- struct IndexRefineFlat: IndexRefine {
60
- explicit IndexRefineFlat (Index *base_index);
61
- IndexRefineFlat(Index *base_index, const float *xb);
68
+ struct IndexRefineFlat : IndexRefine {
69
+ explicit IndexRefineFlat(Index* base_index);
70
+ IndexRefineFlat(Index* base_index, const float* xb);
62
71
 
63
72
  IndexRefineFlat();
64
73
 
65
74
  void search(
66
- idx_t n, const float* x, idx_t k,
67
- float* distances, idx_t* labels) const override;
68
-
75
+ idx_t n,
76
+ const float* x,
77
+ idx_t k,
78
+ float* distances,
79
+ idx_t* labels) const override;
69
80
  };
70
81
 
71
-
72
-
73
82
  } // namespace faiss