faiss 0.2.0 → 0.2.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -9,8 +9,6 @@
9
9
 
10
10
  #pragma once
11
11
 
12
-
13
-
14
12
  #include <faiss/Index.h>
15
13
  #include <faiss/VectorTransform.h>
16
14
 
@@ -18,21 +16,20 @@ namespace faiss {
18
16
 
19
17
  /** Index that applies a LinearTransform transform on vectors before
20
18
  * handing them over to a sub-index */
21
- struct IndexPreTransform: Index {
19
+ struct IndexPreTransform : Index {
20
+ std::vector<VectorTransform*> chain; ///! chain of tranforms
21
+ Index* index; ///! the sub-index
22
22
 
23
- std::vector<VectorTransform *> chain; ///! chain of tranforms
24
- Index * index; ///! the sub-index
23
+ bool own_fields; ///! whether pointers are deleted in destructor
25
24
 
26
- bool own_fields; ///! whether pointers are deleted in destructor
25
+ explicit IndexPreTransform(Index* index);
27
26
 
28
- explicit IndexPreTransform (Index *index);
29
-
30
- IndexPreTransform ();
27
+ IndexPreTransform();
31
28
 
32
29
  /// ltrans is the last transform before the index
33
- IndexPreTransform (VectorTransform * ltrans, Index * index);
30
+ IndexPreTransform(VectorTransform* ltrans, Index* index);
34
31
 
35
- void prepend_transform (VectorTransform * ltrans);
32
+ void prepend_transform(VectorTransform* ltrans);
36
33
 
37
34
  void train(idx_t n, const float* x) override;
38
35
 
@@ -47,47 +44,47 @@ struct IndexPreTransform: Index {
47
44
  size_t remove_ids(const IDSelector& sel) override;
48
45
 
49
46
  void search(
50
- idx_t n,
51
- const float* x,
52
- idx_t k,
53
- float* distances,
54
- idx_t* labels) const override;
55
-
47
+ idx_t n,
48
+ const float* x,
49
+ idx_t k,
50
+ float* distances,
51
+ idx_t* labels) const override;
56
52
 
57
53
  /* range search, no attempt is done to change the radius */
58
- void range_search (idx_t n, const float* x, float radius,
59
- RangeSearchResult* result) const override;
54
+ void range_search(
55
+ idx_t n,
56
+ const float* x,
57
+ float radius,
58
+ RangeSearchResult* result) const override;
60
59
 
60
+ void reconstruct(idx_t key, float* recons) const override;
61
61
 
62
- void reconstruct (idx_t key, float * recons) const override;
62
+ void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
63
63
 
64
- void reconstruct_n (idx_t i0, idx_t ni, float *recons)
65
- const override;
66
-
67
- void search_and_reconstruct (idx_t n, const float *x, idx_t k,
68
- float *distances, idx_t *labels,
69
- float *recons) const override;
64
+ void search_and_reconstruct(
65
+ idx_t n,
66
+ const float* x,
67
+ idx_t k,
68
+ float* distances,
69
+ idx_t* labels,
70
+ float* recons) const override;
70
71
 
71
72
  /// apply the transforms in the chain. The returned float * may be
72
73
  /// equal to x, otherwise it should be deallocated.
73
- const float * apply_chain (idx_t n, const float *x) const;
74
+ const float* apply_chain(idx_t n, const float* x) const;
74
75
 
75
76
  /// Reverse the transforms in the chain. May not be implemented for
76
77
  /// all transforms in the chain or may return approximate results.
77
- void reverse_chain (idx_t n, const float* xt, float* x) const;
78
-
78
+ void reverse_chain(idx_t n, const float* xt, float* x) const;
79
79
 
80
- DistanceComputer * get_distance_computer() const override;
80
+ DistanceComputer* get_distance_computer() const override;
81
81
 
82
82
  /* standalone codec interface */
83
- size_t sa_code_size () const override;
84
- void sa_encode (idx_t n, const float *x,
85
- uint8_t *bytes) const override;
86
- void sa_decode (idx_t n, const uint8_t *bytes,
87
- float *x) const override;
83
+ size_t sa_code_size() const override;
84
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
85
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
88
86
 
89
87
  ~IndexPreTransform() override;
90
88
  };
91
89
 
92
-
93
90
  } // namespace faiss
@@ -5,63 +5,58 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
8
  #include <faiss/IndexRefine.h>
10
9
 
10
+ #include <faiss/IndexFlat.h>
11
+ #include <faiss/impl/AuxIndexStructures.h>
12
+ #include <faiss/impl/FaissAssert.h>
13
+ #include <faiss/utils/Heap.h>
11
14
  #include <faiss/utils/distances.h>
12
15
  #include <faiss/utils/utils.h>
13
- #include <faiss/utils/Heap.h>
14
- #include <faiss/impl/FaissAssert.h>
15
- #include <faiss/impl/AuxIndexStructures.h>
16
- #include <faiss/IndexFlat.h>
17
16
 
18
17
  namespace faiss {
19
18
 
20
-
21
-
22
19
  /***************************************************
23
20
  * IndexRefine
24
21
  ***************************************************/
25
22
 
26
- IndexRefine::IndexRefine (Index *base_index, Index *refine_index):
27
- Index (base_index->d, base_index->metric_type),
28
- base_index (base_index),
29
- refine_index (refine_index)
30
- {
23
+ IndexRefine::IndexRefine(Index* base_index, Index* refine_index)
24
+ : Index(base_index->d, base_index->metric_type),
25
+ base_index(base_index),
26
+ refine_index(refine_index) {
31
27
  own_fields = own_refine_index = false;
32
28
  if (refine_index != nullptr) {
33
- FAISS_THROW_IF_NOT (base_index->d == refine_index->d);
34
- FAISS_THROW_IF_NOT (base_index->metric_type == refine_index->metric_type);
29
+ FAISS_THROW_IF_NOT(base_index->d == refine_index->d);
30
+ FAISS_THROW_IF_NOT(
31
+ base_index->metric_type == refine_index->metric_type);
35
32
  is_trained = base_index->is_trained && refine_index->is_trained;
36
- FAISS_THROW_IF_NOT (base_index->ntotal == refine_index->ntotal);
33
+ FAISS_THROW_IF_NOT(base_index->ntotal == refine_index->ntotal);
37
34
  } // other case is useful only to construct an IndexRefineFlat
38
35
  ntotal = base_index->ntotal;
39
36
  }
40
37
 
41
- IndexRefine::IndexRefine ():
42
- base_index(nullptr), refine_index(nullptr),
43
- own_fields(false), own_refine_index(false)
44
- {
45
- }
38
+ IndexRefine::IndexRefine()
39
+ : base_index(nullptr),
40
+ refine_index(nullptr),
41
+ own_fields(false),
42
+ own_refine_index(false) {}
46
43
 
47
- void IndexRefine::train (idx_t n, const float *x)
48
- {
49
- base_index->train (n, x);
50
- refine_index->train (n, x);
44
+ void IndexRefine::train(idx_t n, const float* x) {
45
+ base_index->train(n, x);
46
+ refine_index->train(n, x);
51
47
  is_trained = true;
52
48
  }
53
49
 
54
- void IndexRefine::add (idx_t n, const float *x) {
55
- FAISS_THROW_IF_NOT (is_trained);
56
- base_index->add (n, x);
57
- refine_index->add (n, x);
50
+ void IndexRefine::add(idx_t n, const float* x) {
51
+ FAISS_THROW_IF_NOT(is_trained);
52
+ base_index->add(n, x);
53
+ refine_index->add(n, x);
58
54
  ntotal = refine_index->ntotal;
59
55
  }
60
56
 
61
- void IndexRefine::reset ()
62
- {
63
- base_index->reset ();
64
- refine_index->reset ();
57
+ void IndexRefine::reset() {
58
+ base_index->reset();
59
+ refine_index->reset();
65
60
  ntotal = 0;
66
61
  }
67
62
 
@@ -69,69 +64,72 @@ namespace {
69
64
 
70
65
  typedef faiss::Index::idx_t idx_t;
71
66
 
72
- template<class C>
73
- static void reorder_2_heaps (
74
- idx_t n,
75
- idx_t k, idx_t *labels, float *distances,
76
- idx_t k_base, const idx_t *base_labels, const float *base_distances)
77
- {
67
+ template <class C>
68
+ static void reorder_2_heaps(
69
+ idx_t n,
70
+ idx_t k,
71
+ idx_t* labels,
72
+ float* distances,
73
+ idx_t k_base,
74
+ const idx_t* base_labels,
75
+ const float* base_distances) {
78
76
  #pragma omp parallel for
79
77
  for (idx_t i = 0; i < n; i++) {
80
- idx_t *idxo = labels + i * k;
81
- float *diso = distances + i * k;
82
- const idx_t *idxi = base_labels + i * k_base;
83
- const float *disi = base_distances + i * k_base;
78
+ idx_t* idxo = labels + i * k;
79
+ float* diso = distances + i * k;
80
+ const idx_t* idxi = base_labels + i * k_base;
81
+ const float* disi = base_distances + i * k_base;
84
82
 
85
- heap_heapify<C> (k, diso, idxo, disi, idxi, k);
83
+ heap_heapify<C>(k, diso, idxo, disi, idxi, k);
86
84
  if (k_base != k) { // add remaining elements
87
- heap_addn<C> (k, diso, idxo, disi + k, idxi + k, k_base - k);
85
+ heap_addn<C>(k, diso, idxo, disi + k, idxi + k, k_base - k);
88
86
  }
89
- heap_reorder<C> (k, diso, idxo);
87
+ heap_reorder<C>(k, diso, idxo);
90
88
  }
91
89
  }
92
90
 
93
-
94
91
  } // anonymous namespace
95
92
 
96
-
97
-
98
- void IndexRefine::search (
99
- idx_t n, const float *x, idx_t k,
100
- float *distances, idx_t *labels) const
101
- {
102
- FAISS_THROW_IF_NOT (is_trained);
103
- idx_t k_base = idx_t (k * k_factor);
104
- idx_t * base_labels = labels;
105
- float * base_distances = distances;
93
+ void IndexRefine::search(
94
+ idx_t n,
95
+ const float* x,
96
+ idx_t k,
97
+ float* distances,
98
+ idx_t* labels) const {
99
+ FAISS_THROW_IF_NOT(k > 0);
100
+
101
+ FAISS_THROW_IF_NOT(is_trained);
102
+ idx_t k_base = idx_t(k * k_factor);
103
+ idx_t* base_labels = labels;
104
+ float* base_distances = distances;
106
105
  ScopeDeleter<idx_t> del1;
107
106
  ScopeDeleter<float> del2;
108
107
 
109
108
  if (k != k_base) {
110
- base_labels = new idx_t [n * k_base];
111
- del1.set (base_labels);
112
- base_distances = new float [n * k_base];
113
- del2.set (base_distances);
109
+ base_labels = new idx_t[n * k_base];
110
+ del1.set(base_labels);
111
+ base_distances = new float[n * k_base];
112
+ del2.set(base_distances);
114
113
  }
115
114
 
116
- base_index->search (n, x, k_base, base_distances, base_labels);
115
+ base_index->search(n, x, k_base, base_distances, base_labels);
117
116
 
118
117
  for (int i = 0; i < n * k_base; i++)
119
- assert (base_labels[i] >= -1 &&
120
- base_labels[i] < ntotal);
118
+ assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
121
119
 
122
- // parallelize over queries
120
+ // parallelize over queries
123
121
  #pragma omp parallel if (n > 1)
124
122
  {
125
123
  std::unique_ptr<DistanceComputer> dc(
126
- refine_index->get_distance_computer()
127
- );
124
+ refine_index->get_distance_computer());
128
125
  #pragma omp for
129
126
  for (idx_t i = 0; i < n; i++) {
130
127
  dc->set_query(x + i * d);
131
128
  idx_t ij = i * k_base;
132
129
  for (idx_t j = 0; j < k_base; j++) {
133
130
  idx_t idx = base_labels[ij];
134
- if (idx < 0) break;
131
+ if (idx < 0)
132
+ break;
135
133
  base_distances[ij] = (*dc)(idx);
136
134
  ij++;
137
135
  }
@@ -140,117 +138,131 @@ void IndexRefine::search (
140
138
 
141
139
  // sort and store result
142
140
  if (metric_type == METRIC_L2) {
143
- typedef CMax <float, idx_t> C;
144
- reorder_2_heaps<C> (
145
- n, k, labels, distances,
146
- k_base, base_labels, base_distances);
141
+ typedef CMax<float, idx_t> C;
142
+ reorder_2_heaps<C>(
143
+ n, k, labels, distances, k_base, base_labels, base_distances);
147
144
 
148
145
  } else if (metric_type == METRIC_INNER_PRODUCT) {
149
- typedef CMin <float, idx_t> C;
150
- reorder_2_heaps<C> (
151
- n, k, labels, distances,
152
- k_base, base_labels, base_distances);
146
+ typedef CMin<float, idx_t> C;
147
+ reorder_2_heaps<C>(
148
+ n, k, labels, distances, k_base, base_labels, base_distances);
153
149
  } else {
154
150
  FAISS_THROW_MSG("Metric type not supported");
155
151
  }
156
-
157
152
  }
158
153
 
159
- void IndexRefine::reconstruct (idx_t key, float * recons) const {
160
- refine_index->reconstruct (key, recons);
154
+ void IndexRefine::reconstruct(idx_t key, float* recons) const {
155
+ refine_index->reconstruct(key, recons);
161
156
  }
162
157
 
158
+ size_t IndexRefine::sa_code_size() const {
159
+ return base_index->sa_code_size() + refine_index->sa_code_size();
160
+ }
163
161
 
162
+ void IndexRefine::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
163
+ size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size();
164
+ std::unique_ptr<uint8_t[]> tmp1(new uint8_t[n * cs1]);
165
+ base_index->sa_encode(n, x, tmp1.get());
166
+ std::unique_ptr<uint8_t[]> tmp2(new uint8_t[n * cs2]);
167
+ refine_index->sa_encode(n, x, tmp2.get());
168
+ for (size_t i = 0; i < n; i++) {
169
+ uint8_t* b = bytes + i * (cs1 + cs2);
170
+ memcpy(b, tmp1.get() + cs1 * i, cs1);
171
+ memcpy(b + cs1, tmp2.get() + cs2 * i, cs2);
172
+ }
173
+ }
164
174
 
175
+ void IndexRefine::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
176
+ size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size();
177
+ std::unique_ptr<uint8_t[]> tmp2(
178
+ new uint8_t[n * refine_index->sa_code_size()]);
179
+ for (size_t i = 0; i < n; i++) {
180
+ memcpy(tmp2.get() + i * cs2, bytes + i * (cs1 + cs2), cs2);
181
+ }
165
182
 
166
- IndexRefine::~IndexRefine ()
167
- {
168
- if (own_fields) delete base_index;
169
- if (own_refine_index) delete refine_index;
183
+ refine_index->sa_decode(n, tmp2.get(), x);
170
184
  }
171
185
 
186
+ IndexRefine::~IndexRefine() {
187
+ if (own_fields)
188
+ delete base_index;
189
+ if (own_refine_index)
190
+ delete refine_index;
191
+ }
172
192
 
173
193
  /***************************************************
174
194
  * IndexRefineFlat
175
195
  ***************************************************/
176
196
 
177
- IndexRefineFlat::IndexRefineFlat (Index *base_index):
178
- IndexRefine(base_index, new IndexFlat(base_index->d, base_index->metric_type))
179
- {
197
+ IndexRefineFlat::IndexRefineFlat(Index* base_index)
198
+ : IndexRefine(
199
+ base_index,
200
+ new IndexFlat(base_index->d, base_index->metric_type)) {
180
201
  is_trained = base_index->is_trained;
181
202
  own_refine_index = true;
182
- FAISS_THROW_IF_NOT_MSG (base_index->ntotal == 0,
183
- "base_index should be empty in the beginning");
203
+ FAISS_THROW_IF_NOT_MSG(
204
+ base_index->ntotal == 0,
205
+ "base_index should be empty in the beginning");
184
206
  }
185
207
 
186
-
187
- IndexRefineFlat::IndexRefineFlat (Index *base_index, const float *xb):
188
- IndexRefine (base_index, nullptr)
189
- {
208
+ IndexRefineFlat::IndexRefineFlat(Index* base_index, const float* xb)
209
+ : IndexRefine(base_index, nullptr) {
190
210
  is_trained = base_index->is_trained;
191
211
  refine_index = new IndexFlat(base_index->d, base_index->metric_type);
192
212
  own_refine_index = true;
193
- refine_index->add (base_index->ntotal, xb);
194
-
213
+ refine_index->add(base_index->ntotal, xb);
195
214
  }
196
215
 
197
- IndexRefineFlat::IndexRefineFlat():
198
- IndexRefine()
199
- {
216
+ IndexRefineFlat::IndexRefineFlat() : IndexRefine() {
200
217
  own_refine_index = true;
201
218
  }
202
219
 
203
-
204
- void IndexRefineFlat::search (
205
- idx_t n, const float *x, idx_t k,
206
- float *distances, idx_t *labels) const
207
- {
208
- FAISS_THROW_IF_NOT (is_trained);
209
- idx_t k_base = idx_t (k * k_factor);
210
- idx_t * base_labels = labels;
211
- float * base_distances = distances;
220
+ void IndexRefineFlat::search(
221
+ idx_t n,
222
+ const float* x,
223
+ idx_t k,
224
+ float* distances,
225
+ idx_t* labels) const {
226
+ FAISS_THROW_IF_NOT(k > 0);
227
+
228
+ FAISS_THROW_IF_NOT(is_trained);
229
+ idx_t k_base = idx_t(k * k_factor);
230
+ idx_t* base_labels = labels;
231
+ float* base_distances = distances;
212
232
  ScopeDeleter<idx_t> del1;
213
233
  ScopeDeleter<float> del2;
214
234
 
215
235
  if (k != k_base) {
216
- base_labels = new idx_t [n * k_base];
217
- del1.set (base_labels);
218
- base_distances = new float [n * k_base];
219
- del2.set (base_distances);
236
+ base_labels = new idx_t[n * k_base];
237
+ del1.set(base_labels);
238
+ base_distances = new float[n * k_base];
239
+ del2.set(base_distances);
220
240
  }
221
241
 
222
- base_index->search (n, x, k_base, base_distances, base_labels);
242
+ base_index->search(n, x, k_base, base_distances, base_labels);
223
243
 
224
244
  for (int i = 0; i < n * k_base; i++)
225
- assert (base_labels[i] >= -1 &&
226
- base_labels[i] < ntotal);
245
+ assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
227
246
 
228
247
  // compute refined distances
229
- auto rf = dynamic_cast<const IndexFlat *>(refine_index);
248
+ auto rf = dynamic_cast<const IndexFlat*>(refine_index);
230
249
  FAISS_THROW_IF_NOT(rf);
231
250
 
232
- rf->compute_distance_subset (
233
- n, x, k_base, base_distances, base_labels);
251
+ rf->compute_distance_subset(n, x, k_base, base_distances, base_labels);
234
252
 
235
253
  // sort and store result
236
254
  if (metric_type == METRIC_L2) {
237
- typedef CMax <float, idx_t> C;
238
- reorder_2_heaps<C> (
239
- n, k, labels, distances,
240
- k_base, base_labels, base_distances);
255
+ typedef CMax<float, idx_t> C;
256
+ reorder_2_heaps<C>(
257
+ n, k, labels, distances, k_base, base_labels, base_distances);
241
258
 
242
259
  } else if (metric_type == METRIC_INNER_PRODUCT) {
243
- typedef CMin <float, idx_t> C;
244
- reorder_2_heaps<C> (
245
- n, k, labels, distances,
246
- k_base, base_labels, base_distances);
260
+ typedef CMin<float, idx_t> C;
261
+ reorder_2_heaps<C>(
262
+ n, k, labels, distances, k_base, base_labels, base_distances);
247
263
  } else {
248
264
  FAISS_THROW_MSG("Metric type not supported");
249
265
  }
250
-
251
266
  }
252
267
 
253
-
254
-
255
-
256
268
  } // namespace faiss
@@ -9,32 +9,29 @@
9
9
 
10
10
  #include <faiss/Index.h>
11
11
 
12
-
13
12
  namespace faiss {
14
13
 
15
-
16
14
  /** Index that queries in a base_index (a fast one) and refines the
17
15
  * results with an exact search, hopefully improving the results.
18
16
  */
19
- struct IndexRefine: Index {
20
-
17
+ struct IndexRefine : Index {
21
18
  /// faster index to pre-select the vectors that should be filtered
22
- Index *base_index;
19
+ Index* base_index;
23
20
 
24
21
  /// refinement index
25
- Index *refine_index;
22
+ Index* refine_index;
26
23
 
27
- bool own_fields; ///< should the base index be deallocated?
28
- bool own_refine_index; ///< same with the refinement index
24
+ bool own_fields; ///< should the base index be deallocated?
25
+ bool own_refine_index; ///< same with the refinement index
29
26
 
30
27
  /// factor between k requested in search and the k requested from
31
28
  /// the base_index (should be >= 1)
32
29
  float k_factor = 1;
33
30
 
34
- /// intitialize from empty index
35
- IndexRefine (Index *base_index, Index *refine_index);
31
+ /// initialize from empty index
32
+ IndexRefine(Index* base_index, Index* refine_index);
36
33
 
37
- IndexRefine ();
34
+ IndexRefine();
38
35
 
39
36
  void train(idx_t n, const float* x) override;
40
37
 
@@ -43,31 +40,43 @@ struct IndexRefine: Index {
43
40
  void reset() override;
44
41
 
45
42
  void search(
46
- idx_t n, const float* x, idx_t k,
47
- float* distances, idx_t* labels) const override;
43
+ idx_t n,
44
+ const float* x,
45
+ idx_t k,
46
+ float* distances,
47
+ idx_t* labels) const override;
48
48
 
49
49
  // reconstruct is routed to the refine_index
50
- void reconstruct (idx_t key, float * recons) const override;
50
+ void reconstruct(idx_t key, float* recons) const override;
51
+
52
+ /* standalone codec interface: the base_index codes are interleaved with the
53
+ * refine_index ones */
54
+ size_t sa_code_size() const override;
55
+
56
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
57
+
58
+ /// The sa_decode decodes from the index_refine, which is assumed to be more
59
+ /// accurate
60
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
51
61
 
52
62
  ~IndexRefine() override;
53
63
  };
54
64
 
55
-
56
65
  /** Version where the refinement index is an IndexFlat. It has one additional
57
66
  * constructor that takes a table of elements to add to the flat refinement
58
67
  * index */
59
- struct IndexRefineFlat: IndexRefine {
60
- explicit IndexRefineFlat (Index *base_index);
61
- IndexRefineFlat(Index *base_index, const float *xb);
68
+ struct IndexRefineFlat : IndexRefine {
69
+ explicit IndexRefineFlat(Index* base_index);
70
+ IndexRefineFlat(Index* base_index, const float* xb);
62
71
 
63
72
  IndexRefineFlat();
64
73
 
65
74
  void search(
66
- idx_t n, const float* x, idx_t k,
67
- float* distances, idx_t* labels) const override;
68
-
75
+ idx_t n,
76
+ const float* x,
77
+ idx_t k,
78
+ float* distances,
79
+ idx_t* labels) const override;
69
80
  };
70
81
 
71
-
72
-
73
82
  } // namespace faiss