faiss 0.2.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -10,30 +10,25 @@
10
10
  #ifndef FAISS_BINARY_HASH_H
11
11
  #define FAISS_BINARY_HASH_H
12
12
 
13
-
14
-
15
- #include <vector>
16
13
  #include <unordered_map>
14
+ #include <vector>
17
15
 
18
16
  #include <faiss/IndexBinary.h>
19
17
  #include <faiss/IndexBinaryFlat.h>
20
18
  #include <faiss/impl/platform_macros.h>
21
19
  #include <faiss/utils/Heap.h>
22
20
 
23
-
24
21
  namespace faiss {
25
22
 
26
23
  struct RangeSearchResult;
27
24
 
28
-
29
25
  /** just uses the b first bits as a hash value */
30
26
  struct IndexBinaryHash : IndexBinary {
31
-
32
27
  struct InvertedList {
33
28
  std::vector<idx_t> ids;
34
29
  std::vector<uint8_t> vecs;
35
30
 
36
- void add (idx_t id, size_t code_size, const uint8_t *code);
31
+ void add(idx_t id, size_t code_size, const uint8_t* code);
37
32
  };
38
33
 
39
34
  using InvertedListMap = std::unordered_map<idx_t, InvertedList>;
@@ -47,49 +42,55 @@ struct IndexBinaryHash : IndexBinary {
47
42
 
48
43
  void reset() override;
49
44
 
50
- void add(idx_t n, const uint8_t *x) override;
45
+ void add(idx_t n, const uint8_t* x) override;
51
46
 
52
- void add_with_ids(idx_t n, const uint8_t *x, const idx_t *xids) override;
47
+ void add_with_ids(idx_t n, const uint8_t* x, const idx_t* xids) override;
53
48
 
54
- void range_search(idx_t n, const uint8_t *x, int radius,
55
- RangeSearchResult *result) const override;
49
+ void range_search(
50
+ idx_t n,
51
+ const uint8_t* x,
52
+ int radius,
53
+ RangeSearchResult* result) const override;
56
54
 
57
- void search(idx_t n, const uint8_t *x, idx_t k,
58
- int32_t *distances, idx_t *labels) const override;
55
+ void search(
56
+ idx_t n,
57
+ const uint8_t* x,
58
+ idx_t k,
59
+ int32_t* distances,
60
+ idx_t* labels) const override;
59
61
 
60
62
  void display() const;
61
63
  size_t hashtable_size() const;
62
-
63
64
  };
64
65
 
65
66
  struct IndexBinaryHashStats {
66
- size_t nq; // nb of queries run
67
- size_t n0; // nb of empty lists
68
- size_t nlist; // nb of non-empty inverted lists scanned
69
- size_t ndis; // nb of distancs computed
70
-
71
- IndexBinaryHashStats () {reset (); }
72
- void reset ();
67
+ size_t nq; // nb of queries run
68
+ size_t n0; // nb of empty lists
69
+ size_t nlist; // nb of non-empty inverted lists scanned
70
+ size_t ndis; // nb of distances computed
71
+
72
+ IndexBinaryHashStats() {
73
+ reset();
74
+ }
75
+ void reset();
73
76
  };
74
77
 
75
78
  FAISS_API extern IndexBinaryHashStats indexBinaryHash_stats;
76
79
 
77
-
78
80
  /** uses nhash hash maps with b bits per map; the vectors themselves are stored in a separate IndexBinaryFlat */
79
- struct IndexBinaryMultiHash: IndexBinary {
80
-
81
+ struct IndexBinaryMultiHash : IndexBinary {
81
82
  // where the vectors are actually stored
82
- IndexBinaryFlat *storage;
83
+ IndexBinaryFlat* storage;
83
84
  bool own_fields;
84
85
 
85
86
  // maps hash values to the ids that hash to them
86
- using Map = std::unordered_map<idx_t, std::vector<idx_t> >;
87
+ using Map = std::unordered_map<idx_t, std::vector<idx_t>>;
87
88
 
88
89
  // the different hashes, size nhash
89
90
  std::vector<Map> maps;
90
91
 
91
92
  int nhash; ///< nb of hash maps
92
- int b; ///< nb bits per hash map
93
+ int b; ///< nb bits per hash map
93
94
  int nflip; ///< nb bit flips to use at search time
94
95
 
95
96
  IndexBinaryMultiHash(int d, int nhash, int b);
@@ -100,18 +101,24 @@ struct IndexBinaryMultiHash: IndexBinary {
100
101
 
101
102
  void reset() override;
102
103
 
103
- void add(idx_t n, const uint8_t *x) override;
104
+ void add(idx_t n, const uint8_t* x) override;
104
105
 
105
- void range_search(idx_t n, const uint8_t *x, int radius,
106
- RangeSearchResult *result) const override;
106
+ void range_search(
107
+ idx_t n,
108
+ const uint8_t* x,
109
+ int radius,
110
+ RangeSearchResult* result) const override;
107
111
 
108
- void search(idx_t n, const uint8_t *x, idx_t k,
109
- int32_t *distances, idx_t *labels) const override;
112
+ void search(
113
+ idx_t n,
114
+ const uint8_t* x,
115
+ idx_t k,
116
+ int32_t* distances,
117
+ idx_t* labels) const override;
110
118
 
111
119
  size_t hashtable_size() const;
112
-
113
120
  };
114
121
 
115
- }
122
+ } // namespace faiss
116
123
 
117
124
  #endif
@@ -9,318 +9,344 @@
9
9
 
10
10
  #include <faiss/IndexBinaryIVF.h>
11
11
 
12
+ #include <omp.h>
12
13
  #include <cinttypes>
13
14
  #include <cstdio>
14
- #include <omp.h>
15
15
 
16
+ #include <algorithm>
16
17
  #include <memory>
17
18
 
18
-
19
- #include <faiss/utils/hamming.h>
20
- #include <faiss/utils/utils.h>
21
- #include <faiss/impl/AuxIndexStructures.h>
22
- #include <faiss/impl/FaissAssert.h>
23
19
  #include <faiss/IndexFlat.h>
24
20
  #include <faiss/IndexLSH.h>
25
-
21
+ #include <faiss/impl/AuxIndexStructures.h>
22
+ #include <faiss/impl/FaissAssert.h>
23
+ #include <faiss/utils/hamming.h>
24
+ #include <faiss/utils/utils.h>
26
25
 
27
26
  namespace faiss {
28
27
 
29
- IndexBinaryIVF::IndexBinaryIVF(IndexBinary *quantizer, size_t d, size_t nlist)
30
- : IndexBinary(d),
31
- invlists(new ArrayInvertedLists(nlist, code_size)),
32
- own_invlists(true),
33
- nprobe(1),
34
- max_codes(0),
35
- quantizer(quantizer),
36
- nlist(nlist),
37
- own_fields(false),
38
- clustering_index(nullptr)
39
- {
40
- FAISS_THROW_IF_NOT (d == quantizer->d);
41
- is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
42
-
43
- cp.niter = 10;
28
+ IndexBinaryIVF::IndexBinaryIVF(IndexBinary* quantizer, size_t d, size_t nlist)
29
+ : IndexBinary(d),
30
+ invlists(new ArrayInvertedLists(nlist, code_size)),
31
+ own_invlists(true),
32
+ nprobe(1),
33
+ max_codes(0),
34
+ quantizer(quantizer),
35
+ nlist(nlist),
36
+ own_fields(false),
37
+ clustering_index(nullptr) {
38
+ FAISS_THROW_IF_NOT(d == quantizer->d);
39
+ is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
40
+
41
+ cp.niter = 10;
44
42
  }
45
43
 
46
44
  IndexBinaryIVF::IndexBinaryIVF()
47
- : invlists(nullptr),
48
- own_invlists(false),
49
- nprobe(1),
50
- max_codes(0),
51
- quantizer(nullptr),
52
- nlist(0),
53
- own_fields(false),
54
- clustering_index(nullptr)
55
- {}
56
-
57
- void IndexBinaryIVF::add(idx_t n, const uint8_t *x) {
58
- add_with_ids(n, x, nullptr);
45
+ : invlists(nullptr),
46
+ own_invlists(false),
47
+ nprobe(1),
48
+ max_codes(0),
49
+ quantizer(nullptr),
50
+ nlist(0),
51
+ own_fields(false),
52
+ clustering_index(nullptr) {}
53
+
54
+ void IndexBinaryIVF::add(idx_t n, const uint8_t* x) {
55
+ add_with_ids(n, x, nullptr);
59
56
  }
60
57
 
61
- void IndexBinaryIVF::add_with_ids(idx_t n, const uint8_t *x, const idx_t *xids) {
62
- add_core(n, x, xids, nullptr);
58
+ void IndexBinaryIVF::add_with_ids(
59
+ idx_t n,
60
+ const uint8_t* x,
61
+ const idx_t* xids) {
62
+ add_core(n, x, xids, nullptr);
63
63
  }
64
64
 
65
- void IndexBinaryIVF::add_core(idx_t n, const uint8_t *x, const idx_t *xids,
66
- const idx_t *precomputed_idx) {
67
- FAISS_THROW_IF_NOT(is_trained);
68
- assert(invlists);
69
- direct_map.check_can_add (xids);
65
+ void IndexBinaryIVF::add_core(
66
+ idx_t n,
67
+ const uint8_t* x,
68
+ const idx_t* xids,
69
+ const idx_t* precomputed_idx) {
70
+ FAISS_THROW_IF_NOT(is_trained);
71
+ assert(invlists);
72
+ direct_map.check_can_add(xids);
70
73
 
71
- const idx_t * idx;
74
+ const idx_t* idx;
72
75
 
73
- std::unique_ptr<idx_t[]> scoped_idx;
76
+ std::unique_ptr<idx_t[]> scoped_idx;
74
77
 
75
- if (precomputed_idx) {
76
- idx = precomputed_idx;
77
- } else {
78
- scoped_idx.reset(new idx_t[n]);
79
- quantizer->assign(n, x, scoped_idx.get());
80
- idx = scoped_idx.get();
81
- }
78
+ if (precomputed_idx) {
79
+ idx = precomputed_idx;
80
+ } else {
81
+ scoped_idx.reset(new idx_t[n]);
82
+ quantizer->assign(n, x, scoped_idx.get());
83
+ idx = scoped_idx.get();
84
+ }
82
85
 
83
- long n_add = 0;
84
- for (size_t i = 0; i < n; i++) {
85
- idx_t id = xids ? xids[i] : ntotal + i;
86
- idx_t list_no = idx[i];
86
+ idx_t n_add = 0;
87
+ for (size_t i = 0; i < n; i++) {
88
+ idx_t id = xids ? xids[i] : ntotal + i;
89
+ idx_t list_no = idx[i];
87
90
 
88
- if (list_no < 0) {
89
- direct_map.add_single_id (id, -1, 0);
90
- } else {
91
- const uint8_t *xi = x + i * code_size;
92
- size_t offset = invlists->add_entry(list_no, id, xi);
91
+ if (list_no < 0) {
92
+ direct_map.add_single_id(id, -1, 0);
93
+ } else {
94
+ const uint8_t* xi = x + i * code_size;
95
+ size_t offset = invlists->add_entry(list_no, id, xi);
93
96
 
94
- direct_map.add_single_id (id, list_no, offset);
95
- }
97
+ direct_map.add_single_id(id, list_no, offset);
98
+ }
96
99
 
97
- n_add++;
98
- }
99
- if (verbose) {
100
- printf("IndexBinaryIVF::add_with_ids: added %ld / %" PRId64 " vectors\n",
101
- n_add, n);
102
- }
103
- ntotal += n_add;
100
+ n_add++;
101
+ }
102
+ if (verbose) {
103
+ printf("IndexBinaryIVF::add_with_ids: added "
104
+ "%" PRId64 " / %" PRId64 " vectors\n",
105
+ n_add,
106
+ n);
107
+ }
108
+ ntotal += n_add;
104
109
  }
105
110
 
106
- void IndexBinaryIVF::make_direct_map (bool b)
107
- {
111
+ void IndexBinaryIVF::make_direct_map(bool b) {
108
112
  if (b) {
109
- direct_map.set_type (DirectMap::Array, invlists, ntotal);
113
+ direct_map.set_type(DirectMap::Array, invlists, ntotal);
110
114
  } else {
111
- direct_map.set_type (DirectMap::NoMap, invlists, ntotal);
115
+ direct_map.set_type(DirectMap::NoMap, invlists, ntotal);
112
116
  }
113
117
  }
114
118
 
115
- void IndexBinaryIVF::set_direct_map_type (DirectMap::Type type)
116
- {
117
- direct_map.set_type (type, invlists, ntotal);
119
+ void IndexBinaryIVF::set_direct_map_type(DirectMap::Type type) {
120
+ direct_map.set_type(type, invlists, ntotal);
118
121
  }
119
122
 
123
+ void IndexBinaryIVF::search(
124
+ idx_t n,
125
+ const uint8_t* x,
126
+ idx_t k,
127
+ int32_t* distances,
128
+ idx_t* labels) const {
129
+ FAISS_THROW_IF_NOT(k > 0);
130
+ FAISS_THROW_IF_NOT(nprobe > 0);
120
131
 
121
- void IndexBinaryIVF::search(idx_t n, const uint8_t *x, idx_t k,
122
- int32_t *distances, idx_t *labels) const {
123
- std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
124
- std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
132
+ const size_t nprobe = std::min(nlist, this->nprobe);
133
+ std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
134
+ std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
125
135
 
126
- double t0 = getmillisecs();
127
- quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
128
- indexIVF_stats.quantization_time += getmillisecs() - t0;
136
+ double t0 = getmillisecs();
137
+ quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
138
+ indexIVF_stats.quantization_time += getmillisecs() - t0;
129
139
 
130
- t0 = getmillisecs();
131
- invlists->prefetch_lists(idx.get(), n * nprobe);
140
+ t0 = getmillisecs();
141
+ invlists->prefetch_lists(idx.get(), n * nprobe);
132
142
 
133
- search_preassigned(n, x, k, idx.get(), coarse_dis.get(),
134
- distances, labels, false);
135
- indexIVF_stats.search_time += getmillisecs() - t0;
143
+ search_preassigned(
144
+ n, x, k, idx.get(), coarse_dis.get(), distances, labels, false);
145
+ indexIVF_stats.search_time += getmillisecs() - t0;
136
146
  }
137
147
 
138
- void IndexBinaryIVF::reconstruct(idx_t key, uint8_t *recons) const {
139
- idx_t lo = direct_map.get (key);
140
- reconstruct_from_offset (lo_listno(lo), lo_offset(lo), recons);
148
+ void IndexBinaryIVF::reconstruct(idx_t key, uint8_t* recons) const {
149
+ idx_t lo = direct_map.get(key);
150
+ reconstruct_from_offset(lo_listno(lo), lo_offset(lo), recons);
141
151
  }
142
152
 
143
- void IndexBinaryIVF::reconstruct_n(idx_t i0, idx_t ni, uint8_t *recons) const {
144
- FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
153
+ void IndexBinaryIVF::reconstruct_n(idx_t i0, idx_t ni, uint8_t* recons) const {
154
+ FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
145
155
 
146
- for (idx_t list_no = 0; list_no < nlist; list_no++) {
147
- size_t list_size = invlists->list_size(list_no);
148
- const Index::idx_t *idlist = invlists->get_ids(list_no);
156
+ for (idx_t list_no = 0; list_no < nlist; list_no++) {
157
+ size_t list_size = invlists->list_size(list_no);
158
+ const Index::idx_t* idlist = invlists->get_ids(list_no);
149
159
 
150
- for (idx_t offset = 0; offset < list_size; offset++) {
151
- idx_t id = idlist[offset];
152
- if (!(id >= i0 && id < i0 + ni)) {
153
- continue;
154
- }
160
+ for (idx_t offset = 0; offset < list_size; offset++) {
161
+ idx_t id = idlist[offset];
162
+ if (!(id >= i0 && id < i0 + ni)) {
163
+ continue;
164
+ }
155
165
 
156
- uint8_t *reconstructed = recons + (id - i0) * d;
157
- reconstruct_from_offset(list_no, offset, reconstructed);
166
+ uint8_t* reconstructed = recons + (id - i0) * d;
167
+ reconstruct_from_offset(list_no, offset, reconstructed);
168
+ }
158
169
  }
159
- }
160
170
  }
161
171
 
162
- void IndexBinaryIVF::search_and_reconstruct(idx_t n, const uint8_t *x, idx_t k,
163
- int32_t *distances, idx_t *labels,
164
- uint8_t *recons) const {
165
- std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
166
- std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
167
-
168
- quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
169
-
170
- invlists->prefetch_lists(idx.get(), n * nprobe);
171
-
172
- // search_preassigned() with `store_pairs` enabled to obtain the list_no
173
- // and offset into `codes` for reconstruction
174
- search_preassigned(n, x, k, idx.get(), coarse_dis.get(),
175
- distances, labels, /* store_pairs */true);
176
- for (idx_t i = 0; i < n; ++i) {
177
- for (idx_t j = 0; j < k; ++j) {
178
- idx_t ij = i * k + j;
179
- idx_t key = labels[ij];
180
- uint8_t *reconstructed = recons + ij * d;
181
- if (key < 0) {
182
- // Fill with NaNs
183
- memset(reconstructed, -1, sizeof(*reconstructed) * d);
184
- } else {
185
- int list_no = key >> 32;
186
- int offset = key & 0xffffffff;
187
-
188
- // Update label to the actual id
189
- labels[ij] = invlists->get_single_id(list_no, offset);
190
-
191
- reconstruct_from_offset(list_no, offset, reconstructed);
192
- }
172
+ void IndexBinaryIVF::search_and_reconstruct(
173
+ idx_t n,
174
+ const uint8_t* x,
175
+ idx_t k,
176
+ int32_t* distances,
177
+ idx_t* labels,
178
+ uint8_t* recons) const {
179
+ const size_t nprobe = std::min(nlist, this->nprobe);
180
+ FAISS_THROW_IF_NOT(k > 0);
181
+ FAISS_THROW_IF_NOT(nprobe > 0);
182
+
183
+ std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
184
+ std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
185
+
186
+ quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
187
+
188
+ invlists->prefetch_lists(idx.get(), n * nprobe);
189
+
190
+ // search_preassigned() with `store_pairs` enabled to obtain the list_no
191
+ // and offset into `codes` for reconstruction
192
+ search_preassigned(
193
+ n,
194
+ x,
195
+ k,
196
+ idx.get(),
197
+ coarse_dis.get(),
198
+ distances,
199
+ labels,
200
+ /* store_pairs */ true);
201
+ for (idx_t i = 0; i < n; ++i) {
202
+ for (idx_t j = 0; j < k; ++j) {
203
+ idx_t ij = i * k + j;
204
+ idx_t key = labels[ij];
205
+ uint8_t* reconstructed = recons + ij * d;
206
+ if (key < 0) {
207
+ // No result for this slot (key < 0): fill the output with 0xff bytes
208
+ memset(reconstructed, -1, sizeof(*reconstructed) * d);
209
+ } else {
210
+ int list_no = key >> 32;
211
+ int offset = key & 0xffffffff;
212
+
213
+ // Update label to the actual id
214
+ labels[ij] = invlists->get_single_id(list_no, offset);
215
+
216
+ reconstruct_from_offset(list_no, offset, reconstructed);
217
+ }
218
+ }
193
219
  }
194
- }
195
220
  }
196
221
 
197
- void IndexBinaryIVF::reconstruct_from_offset(idx_t list_no, idx_t offset,
198
- uint8_t *recons) const {
199
- memcpy(recons, invlists->get_single_code(list_no, offset), code_size);
222
+ void IndexBinaryIVF::reconstruct_from_offset(
223
+ idx_t list_no,
224
+ idx_t offset,
225
+ uint8_t* recons) const {
226
+ memcpy(recons, invlists->get_single_code(list_no, offset), code_size);
200
227
  }
201
228
 
202
229
  void IndexBinaryIVF::reset() {
203
- direct_map.clear();
204
- invlists->reset();
205
- ntotal = 0;
230
+ direct_map.clear();
231
+ invlists->reset();
232
+ ntotal = 0;
206
233
  }
207
234
 
208
235
  size_t IndexBinaryIVF::remove_ids(const IDSelector& sel) {
209
- size_t nremove = direct_map.remove_ids (sel, invlists);
236
+ size_t nremove = direct_map.remove_ids(sel, invlists);
210
237
  ntotal -= nremove;
211
238
  return nremove;
212
239
  }
213
240
 
214
- void IndexBinaryIVF::train(idx_t n, const uint8_t *x) {
215
- if (verbose) {
216
- printf("Training quantizer\n");
217
- }
218
-
219
- if (quantizer->is_trained && (quantizer->ntotal == nlist)) {
220
- if (verbose) {
221
- printf("IVF quantizer does not need training.\n");
222
- }
223
- } else {
241
+ void IndexBinaryIVF::train(idx_t n, const uint8_t* x) {
224
242
  if (verbose) {
225
- printf("Training quantizer on %" PRId64 " vectors in %dD\n", n, d);
243
+ printf("Training quantizer\n");
226
244
  }
227
245
 
228
- Clustering clus(d, nlist, cp);
229
- quantizer->reset();
230
-
231
- IndexFlatL2 index_tmp(d);
232
-
233
- if (clustering_index && verbose) {
234
- printf("using clustering_index of dimension %d to do the clustering\n",
235
- clustering_index->d);
236
- }
246
+ if (quantizer->is_trained && (quantizer->ntotal == nlist)) {
247
+ if (verbose) {
248
+ printf("IVF quantizer does not need training.\n");
249
+ }
250
+ } else {
251
+ if (verbose) {
252
+ printf("Training quantizer on %" PRId64 " vectors in %dD\n", n, d);
253
+ }
237
254
 
238
- // LSH codec that is able to convert the binary vectors to floats.
239
- IndexLSH codec(d, d, false, false);
255
+ Clustering clus(d, nlist, cp);
256
+ quantizer->reset();
240
257
 
241
- clus.train_encoded (n, x, &codec, clustering_index ? *clustering_index : index_tmp);
258
+ IndexFlatL2 index_tmp(d);
242
259
 
243
- // convert clusters to binary
244
- std::unique_ptr<uint8_t[]> x_b(new uint8_t[clus.k * code_size]);
245
- real_to_binary(d * clus.k, clus.centroids.data(), x_b.get());
260
+ if (clustering_index && verbose) {
261
+ printf("using clustering_index of dimension %d to do the clustering\n",
262
+ clustering_index->d);
263
+ }
246
264
 
247
- quantizer->add(clus.k, x_b.get());
248
- quantizer->is_trained = true;
249
- }
265
+ // LSH codec that is able to convert the binary vectors to floats.
266
+ IndexLSH codec(d, d, false, false);
250
267
 
251
- is_trained = true;
252
- }
268
+ clus.train_encoded(
269
+ n, x, &codec, clustering_index ? *clustering_index : index_tmp);
253
270
 
254
- void IndexBinaryIVF::merge_from(IndexBinaryIVF &other, idx_t add_id) {
255
- // minimal sanity checks
256
- FAISS_THROW_IF_NOT(other.d == d);
257
- FAISS_THROW_IF_NOT(other.nlist == nlist);
258
- FAISS_THROW_IF_NOT(other.code_size == code_size);
259
- FAISS_THROW_IF_NOT_MSG(direct_map.no() && other.direct_map.no(),
260
- "direct map copy not implemented");
261
- FAISS_THROW_IF_NOT_MSG(typeid (*this) == typeid (other),
262
- "can only merge indexes of the same type");
271
+ // convert clusters to binary
272
+ std::unique_ptr<uint8_t[]> x_b(new uint8_t[clus.k * code_size]);
273
+ real_to_binary(d * clus.k, clus.centroids.data(), x_b.get());
263
274
 
264
- invlists->merge_from (other.invlists, add_id);
275
+ quantizer->add(clus.k, x_b.get());
276
+ quantizer->is_trained = true;
277
+ }
265
278
 
266
- ntotal += other.ntotal;
267
- other.ntotal = 0;
279
+ is_trained = true;
268
280
  }
269
281
 
270
- void IndexBinaryIVF::replace_invlists(InvertedLists *il, bool own) {
271
- FAISS_THROW_IF_NOT(il->nlist == nlist &&
272
- il->code_size == code_size);
273
- if (own_invlists) {
274
- delete invlists;
275
- }
276
- invlists = il;
277
- own_invlists = own;
282
+ void IndexBinaryIVF::merge_from(IndexBinaryIVF& other, idx_t add_id) {
283
+ // minimal sanity checks
284
+ FAISS_THROW_IF_NOT(other.d == d);
285
+ FAISS_THROW_IF_NOT(other.nlist == nlist);
286
+ FAISS_THROW_IF_NOT(other.code_size == code_size);
287
+ FAISS_THROW_IF_NOT_MSG(
288
+ direct_map.no() && other.direct_map.no(),
289
+ "direct map copy not implemented");
290
+ FAISS_THROW_IF_NOT_MSG(
291
+ typeid(*this) == typeid(other),
292
+ "can only merge indexes of the same type");
293
+
294
+ invlists->merge_from(other.invlists, add_id);
295
+
296
+ ntotal += other.ntotal;
297
+ other.ntotal = 0;
278
298
  }
279
299
 
300
+ void IndexBinaryIVF::replace_invlists(InvertedLists* il, bool own) {
301
+ FAISS_THROW_IF_NOT(il->nlist == nlist && il->code_size == code_size);
302
+ if (own_invlists) {
303
+ delete invlists;
304
+ }
305
+ invlists = il;
306
+ own_invlists = own;
307
+ }
280
308
 
281
309
  namespace {
282
310
 
283
311
  using idx_t = Index::idx_t;
284
312
 
285
-
286
- template<class HammingComputer>
287
- struct IVFBinaryScannerL2: BinaryInvertedListScanner {
288
-
313
+ template <class HammingComputer>
314
+ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
289
315
  HammingComputer hc;
290
316
  size_t code_size;
291
317
  bool store_pairs;
292
318
 
293
- IVFBinaryScannerL2 (size_t code_size, bool store_pairs):
294
- code_size (code_size), store_pairs(store_pairs)
295
- {}
319
+ IVFBinaryScannerL2(size_t code_size, bool store_pairs)
320
+ : code_size(code_size), store_pairs(store_pairs) {}
296
321
 
297
- void set_query (const uint8_t *query_vector) override {
298
- hc.set (query_vector, code_size);
322
+ void set_query(const uint8_t* query_vector) override {
323
+ hc.set(query_vector, code_size);
299
324
  }
300
325
 
301
326
  idx_t list_no;
302
- void set_list (idx_t list_no, uint8_t /* coarse_dis */) override {
327
+ void set_list(idx_t list_no, uint8_t /* coarse_dis */) override {
303
328
  this->list_no = list_no;
304
329
  }
305
330
 
306
- uint32_t distance_to_code (const uint8_t *code) const override {
307
- return hc.hamming (code);
331
+ uint32_t distance_to_code(const uint8_t* code) const override {
332
+ return hc.hamming(code);
308
333
  }
309
334
 
310
- size_t scan_codes (size_t n,
311
- const uint8_t *codes,
312
- const idx_t *ids,
313
- int32_t *simi, idx_t *idxi,
314
- size_t k) const override
315
- {
335
+ size_t scan_codes(
336
+ size_t n,
337
+ const uint8_t* codes,
338
+ const idx_t* ids,
339
+ int32_t* simi,
340
+ idx_t* idxi,
341
+ size_t k) const override {
316
342
  using C = CMax<int32_t, idx_t>;
317
343
 
318
344
  size_t nup = 0;
319
345
  for (size_t j = 0; j < n; j++) {
320
- uint32_t dis = hc.hamming (codes);
346
+ uint32_t dis = hc.hamming(codes);
321
347
  if (dis < simi[0]) {
322
348
  idx_t id = store_pairs ? lo_build(list_no, j) : ids[j];
323
- heap_replace_top<C> (k, simi, idxi, dis, id);
349
+ heap_replace_top<C>(k, simi, idxi, dis, id);
324
350
  nup++;
325
351
  }
326
352
  codes += code_size;
@@ -328,40 +354,38 @@ struct IVFBinaryScannerL2: BinaryInvertedListScanner {
328
354
  return nup;
329
355
  }
330
356
 
331
- void scan_codes_range (size_t n,
332
- const uint8_t *codes,
333
- const idx_t *ids,
334
- int radius,
335
- RangeQueryResult &result) const override
336
- {
357
+ void scan_codes_range(
358
+ size_t n,
359
+ const uint8_t* codes,
360
+ const idx_t* ids,
361
+ int radius,
362
+ RangeQueryResult& result) const override {
337
363
  size_t nup = 0;
338
364
  for (size_t j = 0; j < n; j++) {
339
- uint32_t dis = hc.hamming (codes);
365
+ uint32_t dis = hc.hamming(codes);
340
366
  if (dis < radius) {
341
- int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
342
- result.add (dis, id);
367
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
368
+ result.add(dis, id);
343
369
  }
344
370
  codes += code_size;
345
371
  }
346
-
347
372
  }
348
-
349
-
350
373
  };
351
374
 
352
-
353
- void search_knn_hamming_heap(const IndexBinaryIVF& ivf,
354
- size_t n,
355
- const uint8_t *x,
356
- idx_t k,
357
- const idx_t *keys,
358
- const int32_t * coarse_dis,
359
- int32_t *distances, idx_t *labels,
360
- bool store_pairs,
361
- const IVFSearchParameters *params)
362
- {
363
- long nprobe = params ? params->nprobe : ivf.nprobe;
364
- long max_codes = params ? params->max_codes : ivf.max_codes;
375
+ void search_knn_hamming_heap(
376
+ const IndexBinaryIVF& ivf,
377
+ size_t n,
378
+ const uint8_t* x,
379
+ idx_t k,
380
+ const idx_t* keys,
381
+ const int32_t* coarse_dis,
382
+ int32_t* distances,
383
+ idx_t* labels,
384
+ bool store_pairs,
385
+ const IVFSearchParameters* params) {
386
+ idx_t nprobe = params ? params->nprobe : ivf.nprobe;
387
+ nprobe = std::min((idx_t)ivf.nlist, nprobe);
388
+ idx_t max_codes = params ? params->max_codes : ivf.max_codes;
365
389
  MetricType metric_type = ivf.metric_type;
366
390
 
367
391
  // almost verbatim copy from IndexIVF::search_preassigned
@@ -370,57 +394,57 @@ void search_knn_hamming_heap(const IndexBinaryIVF& ivf,
370
394
  using HeapForIP = CMin<int32_t, idx_t>;
371
395
  using HeapForL2 = CMax<int32_t, idx_t>;
372
396
 
373
- #pragma omp parallel if(n > 1) reduction(+: nlistv, ndis, nheap)
397
+ #pragma omp parallel if (n > 1) reduction(+ : nlistv, ndis, nheap)
374
398
  {
375
- std::unique_ptr<BinaryInvertedListScanner> scanner
376
- (ivf.get_InvertedListScanner (store_pairs));
399
+ std::unique_ptr<BinaryInvertedListScanner> scanner(
400
+ ivf.get_InvertedListScanner(store_pairs));
377
401
 
378
402
  #pragma omp for
379
403
  for (idx_t i = 0; i < n; i++) {
380
- const uint8_t *xi = x + i * ivf.code_size;
404
+ const uint8_t* xi = x + i * ivf.code_size;
381
405
  scanner->set_query(xi);
382
406
 
383
- const idx_t * keysi = keys + i * nprobe;
384
- int32_t * simi = distances + k * i;
385
- idx_t * idxi = labels + k * i;
407
+ const idx_t* keysi = keys + i * nprobe;
408
+ int32_t* simi = distances + k * i;
409
+ idx_t* idxi = labels + k * i;
386
410
 
387
411
  if (metric_type == METRIC_INNER_PRODUCT) {
388
- heap_heapify<HeapForIP> (k, simi, idxi);
412
+ heap_heapify<HeapForIP>(k, simi, idxi);
389
413
  } else {
390
- heap_heapify<HeapForL2> (k, simi, idxi);
414
+ heap_heapify<HeapForL2>(k, simi, idxi);
391
415
  }
392
416
 
393
417
  size_t nscan = 0;
394
418
 
395
419
  for (size_t ik = 0; ik < nprobe; ik++) {
396
- idx_t key = keysi[ik]; /* select the list */
420
+ idx_t key = keysi[ik]; /* select the list */
397
421
  if (key < 0) {
398
422
  // not enough centroids for multiprobe
399
423
  continue;
400
424
  }
401
- FAISS_THROW_IF_NOT_FMT
402
- (key < (idx_t) ivf.nlist,
403
- "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
404
- key, ik, ivf.nlist);
425
+ FAISS_THROW_IF_NOT_FMT(
426
+ key < (idx_t)ivf.nlist,
427
+ "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
428
+ key,
429
+ ik,
430
+ ivf.nlist);
405
431
 
406
- scanner->set_list (key, coarse_dis[i * nprobe + ik]);
432
+ scanner->set_list(key, coarse_dis[i * nprobe + ik]);
407
433
 
408
434
  nlistv++;
409
435
 
410
436
  size_t list_size = ivf.invlists->list_size(key);
411
- InvertedLists::ScopedCodes scodes (ivf.invlists, key);
437
+ InvertedLists::ScopedCodes scodes(ivf.invlists, key);
412
438
  std::unique_ptr<InvertedLists::ScopedIds> sids;
413
- const Index::idx_t * ids = nullptr;
439
+ const Index::idx_t* ids = nullptr;
414
440
 
415
441
  if (!store_pairs) {
416
- sids.reset (new InvertedLists::ScopedIds (ivf.invlists, key));
442
+ sids.reset(new InvertedLists::ScopedIds(ivf.invlists, key));
417
443
  ids = sids->get();
418
444
  }
419
445
 
420
- nheap += scanner->scan_codes (
421
- list_size, scodes.get(),
422
- ids, simi, idxi, k
423
- );
446
+ nheap += scanner->scan_codes(
447
+ list_size, scodes.get(), ids, simi, idxi, k);
424
448
 
425
449
  nscan += list_size;
426
450
  if (max_codes && nscan >= max_codes)
@@ -429,208 +453,205 @@ void search_knn_hamming_heap(const IndexBinaryIVF& ivf,
429
453
 
430
454
  ndis += nscan;
431
455
  if (metric_type == METRIC_INNER_PRODUCT) {
432
- heap_reorder<HeapForIP> (k, simi, idxi);
456
+ heap_reorder<HeapForIP>(k, simi, idxi);
433
457
  } else {
434
- heap_reorder<HeapForL2> (k, simi, idxi);
458
+ heap_reorder<HeapForL2>(k, simi, idxi);
435
459
  }
436
460
 
437
461
  } // parallel for
438
- } // parallel
462
+ } // parallel
439
463
 
440
464
  indexIVF_stats.nq += n;
441
465
  indexIVF_stats.nlist += nlistv;
442
466
  indexIVF_stats.ndis += ndis;
443
467
  indexIVF_stats.nheap_updates += nheap;
444
-
445
468
  }
446
469
 
447
- template<class HammingComputer, bool store_pairs>
448
- void search_knn_hamming_count(const IndexBinaryIVF& ivf,
449
- size_t nx,
450
- const uint8_t *x,
451
- const idx_t *keys,
452
- int k,
453
- int32_t *distances,
454
- idx_t *labels,
455
- const IVFSearchParameters *params) {
456
- const int nBuckets = ivf.d + 1;
457
- std::vector<int> all_counters(nx * nBuckets, 0);
458
- std::unique_ptr<idx_t[]> all_ids_per_dis(new idx_t[nx * nBuckets * k]);
459
-
460
- long nprobe = params ? params->nprobe : ivf.nprobe;
461
- long max_codes = params ? params->max_codes : ivf.max_codes;
462
-
463
- std::vector<HCounterState<HammingComputer>> cs;
464
- for (size_t i = 0; i < nx; ++i) {
465
- cs.push_back(HCounterState<HammingComputer>(
466
- all_counters.data() + i * nBuckets,
467
- all_ids_per_dis.get() + i * nBuckets * k,
468
- x + i * ivf.code_size,
469
- ivf.d,
470
- k
471
- ));
472
- }
473
-
474
- size_t nlistv = 0, ndis = 0;
475
-
476
- #pragma omp parallel for reduction(+: nlistv, ndis)
477
- for (int64_t i = 0; i < nx; i++) {
478
- const idx_t * keysi = keys + i * nprobe;
479
- HCounterState<HammingComputer>& csi = cs[i];
480
-
481
- size_t nscan = 0;
482
-
483
- for (size_t ik = 0; ik < nprobe; ik++) {
484
- idx_t key = keysi[ik]; /* select the list */
485
- if (key < 0) {
486
- // not enough centroids for multiprobe
487
- continue;
488
- }
489
- FAISS_THROW_IF_NOT_FMT (
490
- key < (idx_t) ivf.nlist,
491
- "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
492
- key, ik, ivf.nlist);
493
-
494
- nlistv++;
495
- size_t list_size = ivf.invlists->list_size(key);
496
- InvertedLists::ScopedCodes scodes (ivf.invlists, key);
497
- const uint8_t *list_vecs = scodes.get();
498
- const Index::idx_t *ids = store_pairs
499
- ? nullptr
500
- : ivf.invlists->get_ids(key);
501
-
502
- for (size_t j = 0; j < list_size; j++) {
503
- const uint8_t * yj = list_vecs + ivf.code_size * j;
504
-
505
- idx_t id = store_pairs ? (key << 32 | j) : ids[j];
506
- csi.update_counter(yj, id);
507
- }
508
- if (ids)
509
- ivf.invlists->release_ids (key, ids);
510
-
511
- nscan += list_size;
512
- if (max_codes && nscan >= max_codes)
513
- break;
514
- }
515
- ndis += nscan;
516
-
517
- int nres = 0;
518
- for (int b = 0; b < nBuckets && nres < k; b++) {
519
- for (int l = 0; l < csi.counters[b] && nres < k; l++) {
520
- labels[i * k + nres] = csi.ids_per_dis[b * k + l];
521
- distances[i * k + nres] = b;
522
- nres++;
523
- }
524
- }
525
- while (nres < k) {
526
- labels[i * k + nres] = -1;
527
- distances[i * k + nres] = std::numeric_limits<int32_t>::max();
528
- ++nres;
470
+ template <class HammingComputer, bool store_pairs>
471
+ void search_knn_hamming_count(
472
+ const IndexBinaryIVF& ivf,
473
+ size_t nx,
474
+ const uint8_t* x,
475
+ const idx_t* keys,
476
+ int k,
477
+ int32_t* distances,
478
+ idx_t* labels,
479
+ const IVFSearchParameters* params) {
480
+ const int nBuckets = ivf.d + 1;
481
+ std::vector<int> all_counters(nx * nBuckets, 0);
482
+ std::unique_ptr<idx_t[]> all_ids_per_dis(new idx_t[nx * nBuckets * k]);
483
+
484
+ idx_t nprobe = params ? params->nprobe : ivf.nprobe;
485
+ nprobe = std::min((idx_t)ivf.nlist, nprobe);
486
+ idx_t max_codes = params ? params->max_codes : ivf.max_codes;
487
+
488
+ std::vector<HCounterState<HammingComputer>> cs;
489
+ for (size_t i = 0; i < nx; ++i) {
490
+ cs.push_back(HCounterState<HammingComputer>(
491
+ all_counters.data() + i * nBuckets,
492
+ all_ids_per_dis.get() + i * nBuckets * k,
493
+ x + i * ivf.code_size,
494
+ ivf.d,
495
+ k));
529
496
  }
530
- }
531
497
 
532
- indexIVF_stats.nq += nx;
533
- indexIVF_stats.nlist += nlistv;
534
- indexIVF_stats.ndis += ndis;
535
- }
498
+ size_t nlistv = 0, ndis = 0;
536
499
 
500
+ #pragma omp parallel for reduction(+ : nlistv, ndis)
501
+ for (int64_t i = 0; i < nx; i++) {
502
+ const idx_t* keysi = keys + i * nprobe;
503
+ HCounterState<HammingComputer>& csi = cs[i];
537
504
 
505
+ size_t nscan = 0;
538
506
 
539
- template<bool store_pairs>
540
- void search_knn_hamming_count_1 (
541
- const IndexBinaryIVF& ivf,
542
- size_t nx,
543
- const uint8_t *x,
544
- const idx_t *keys,
545
- int k,
546
- int32_t *distances,
547
- idx_t *labels,
548
- const IVFSearchParameters *params) {
549
- switch (ivf.code_size) {
550
- #define HANDLE_CS(cs) \
551
- case cs: \
552
- search_knn_hamming_count<HammingComputer ## cs, store_pairs>( \
553
- ivf, nx, x, keys, k, distances, labels, params); \
554
- break;
555
- HANDLE_CS(4);
556
- HANDLE_CS(8);
557
- HANDLE_CS(16);
558
- HANDLE_CS(20);
559
- HANDLE_CS(32);
560
- HANDLE_CS(64);
561
- #undef HANDLE_CS
562
- default:
563
- if (ivf.code_size % 8 == 0) {
564
- search_knn_hamming_count<HammingComputerM8, store_pairs>
565
- (ivf, nx, x, keys, k, distances, labels, params);
566
- } else if (ivf.code_size % 4 == 0) {
567
- search_knn_hamming_count<HammingComputerM4, store_pairs>
568
- (ivf, nx, x, keys, k, distances, labels, params);
569
- } else {
570
- search_knn_hamming_count<HammingComputerDefault, store_pairs>
571
- (ivf, nx, x, keys, k, distances, labels, params);
507
+ for (size_t ik = 0; ik < nprobe; ik++) {
508
+ idx_t key = keysi[ik]; /* select the list */
509
+ if (key < 0) {
510
+ // not enough centroids for multiprobe
511
+ continue;
512
+ }
513
+ FAISS_THROW_IF_NOT_FMT(
514
+ key < (idx_t)ivf.nlist,
515
+ "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
516
+ key,
517
+ ik,
518
+ ivf.nlist);
519
+
520
+ nlistv++;
521
+ size_t list_size = ivf.invlists->list_size(key);
522
+ InvertedLists::ScopedCodes scodes(ivf.invlists, key);
523
+ const uint8_t* list_vecs = scodes.get();
524
+ const Index::idx_t* ids =
525
+ store_pairs ? nullptr : ivf.invlists->get_ids(key);
526
+
527
+ for (size_t j = 0; j < list_size; j++) {
528
+ const uint8_t* yj = list_vecs + ivf.code_size * j;
529
+
530
+ idx_t id = store_pairs ? (key << 32 | j) : ids[j];
531
+ csi.update_counter(yj, id);
532
+ }
533
+ if (ids)
534
+ ivf.invlists->release_ids(key, ids);
535
+
536
+ nscan += list_size;
537
+ if (max_codes && nscan >= max_codes)
538
+ break;
539
+ }
540
+ ndis += nscan;
541
+
542
+ int nres = 0;
543
+ for (int b = 0; b < nBuckets && nres < k; b++) {
544
+ for (int l = 0; l < csi.counters[b] && nres < k; l++) {
545
+ labels[i * k + nres] = csi.ids_per_dis[b * k + l];
546
+ distances[i * k + nres] = b;
547
+ nres++;
548
+ }
549
+ }
550
+ while (nres < k) {
551
+ labels[i * k + nres] = -1;
552
+ distances[i * k + nres] = std::numeric_limits<int32_t>::max();
553
+ ++nres;
572
554
  }
573
- break;
574
555
  }
575
556
 
557
+ indexIVF_stats.nq += nx;
558
+ indexIVF_stats.nlist += nlistv;
559
+ indexIVF_stats.ndis += ndis;
576
560
  }
577
561
 
578
- } // namespace
562
+ template <bool store_pairs>
563
+ void search_knn_hamming_count_1(
564
+ const IndexBinaryIVF& ivf,
565
+ size_t nx,
566
+ const uint8_t* x,
567
+ const idx_t* keys,
568
+ int k,
569
+ int32_t* distances,
570
+ idx_t* labels,
571
+ const IVFSearchParameters* params) {
572
+ switch (ivf.code_size) {
573
+ #define HANDLE_CS(cs) \
574
+ case cs: \
575
+ search_knn_hamming_count<HammingComputer##cs, store_pairs>( \
576
+ ivf, nx, x, keys, k, distances, labels, params); \
577
+ break;
578
+ HANDLE_CS(4);
579
+ HANDLE_CS(8);
580
+ HANDLE_CS(16);
581
+ HANDLE_CS(20);
582
+ HANDLE_CS(32);
583
+ HANDLE_CS(64);
584
+ #undef HANDLE_CS
585
+ default:
586
+ search_knn_hamming_count<HammingComputerDefault, store_pairs>(
587
+ ivf, nx, x, keys, k, distances, labels, params);
588
+ break;
589
+ }
590
+ }
579
591
 
580
- BinaryInvertedListScanner *IndexBinaryIVF::get_InvertedListScanner
581
- (bool store_pairs) const
582
- {
592
+ } // namespace
583
593
 
584
- #define HC(name) return new IVFBinaryScannerL2<name> (code_size, store_pairs)
594
+ BinaryInvertedListScanner* IndexBinaryIVF::get_InvertedListScanner(
595
+ bool store_pairs) const {
596
+ #define HC(name) return new IVFBinaryScannerL2<name>(code_size, store_pairs)
585
597
  switch (code_size) {
586
- case 4: HC(HammingComputer4);
587
- case 8: HC(HammingComputer8);
588
- case 16: HC(HammingComputer16);
589
- case 20: HC(HammingComputer20);
590
- case 32: HC(HammingComputer32);
591
- case 64: HC(HammingComputer64);
592
- default:
593
- if (code_size % 8 == 0) {
594
- HC(HammingComputerM8);
595
- } else if (code_size % 4 == 0) {
596
- HC(HammingComputerM4);
597
- } else {
598
+ case 4:
599
+ HC(HammingComputer4);
600
+ case 8:
601
+ HC(HammingComputer8);
602
+ case 16:
603
+ HC(HammingComputer16);
604
+ case 20:
605
+ HC(HammingComputer20);
606
+ case 32:
607
+ HC(HammingComputer32);
608
+ case 64:
609
+ HC(HammingComputer64);
610
+ default:
598
611
  HC(HammingComputerDefault);
599
- }
600
612
  }
601
613
  #undef HC
602
-
603
614
  }
604
615
 
605
- void IndexBinaryIVF::search_preassigned(idx_t n, const uint8_t *x, idx_t k,
606
- const idx_t *idx,
607
- const int32_t * coarse_dis,
608
- int32_t *distances, idx_t *labels,
609
- bool store_pairs,
610
- const IVFSearchParameters *params
611
- ) const {
612
-
616
+ void IndexBinaryIVF::search_preassigned(
617
+ idx_t n,
618
+ const uint8_t* x,
619
+ idx_t k,
620
+ const idx_t* idx,
621
+ const int32_t* coarse_dis,
622
+ int32_t* distances,
623
+ idx_t* labels,
624
+ bool store_pairs,
625
+ const IVFSearchParameters* params) const {
613
626
  if (use_heap) {
614
- search_knn_hamming_heap (*this, n, x, k, idx, coarse_dis,
615
- distances, labels, store_pairs,
616
- params);
627
+ search_knn_hamming_heap(
628
+ *this,
629
+ n,
630
+ x,
631
+ k,
632
+ idx,
633
+ coarse_dis,
634
+ distances,
635
+ labels,
636
+ store_pairs,
637
+ params);
617
638
  } else {
618
639
  if (store_pairs) {
619
- search_knn_hamming_count_1<true>
620
- (*this, n, x, idx, k, distances, labels, params);
640
+ search_knn_hamming_count_1<true>(
641
+ *this, n, x, idx, k, distances, labels, params);
621
642
  } else {
622
- search_knn_hamming_count_1<false>
623
- (*this, n, x, idx, k, distances, labels, params);
643
+ search_knn_hamming_count_1<false>(
644
+ *this, n, x, idx, k, distances, labels, params);
624
645
  }
625
646
  }
626
647
  }
627
648
 
628
-
629
649
  void IndexBinaryIVF::range_search(
630
- idx_t n, const uint8_t *x, int radius,
631
- RangeSearchResult *res) const
632
- {
633
-
650
+ idx_t n,
651
+ const uint8_t* x,
652
+ int radius,
653
+ RangeSearchResult* res) const {
654
+ const size_t nprobe = std::min(nlist, this->nprobe);
634
655
  std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
635
656
  std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
636
657
 
@@ -641,77 +662,84 @@ void IndexBinaryIVF::range_search(
641
662
  t0 = getmillisecs();
642
663
  invlists->prefetch_lists(idx.get(), n * nprobe);
643
664
 
665
+ range_search_preassigned(n, x, radius, idx.get(), coarse_dis.get(), res);
666
+
667
+ indexIVF_stats.search_time += getmillisecs() - t0;
668
+ }
669
+
670
+ void IndexBinaryIVF::range_search_preassigned(
671
+ idx_t n,
672
+ const uint8_t* x,
673
+ int radius,
674
+ const idx_t* assign,
675
+ const int32_t* centroid_dis,
676
+ RangeSearchResult* res) const {
677
+ const size_t nprobe = std::min(nlist, this->nprobe);
644
678
  bool store_pairs = false;
645
679
  size_t nlistv = 0, ndis = 0;
646
680
 
647
- std::vector<RangeSearchPartialResult *> all_pres (omp_get_max_threads());
681
+ std::vector<RangeSearchPartialResult*> all_pres(omp_get_max_threads());
648
682
 
649
- #pragma omp parallel reduction(+: nlistv, ndis)
683
+ #pragma omp parallel reduction(+ : nlistv, ndis)
650
684
  {
651
685
  RangeSearchPartialResult pres(res);
652
- std::unique_ptr<BinaryInvertedListScanner> scanner
653
- (get_InvertedListScanner(store_pairs));
654
- FAISS_THROW_IF_NOT (scanner.get ());
686
+ std::unique_ptr<BinaryInvertedListScanner> scanner(
687
+ get_InvertedListScanner(store_pairs));
688
+ FAISS_THROW_IF_NOT(scanner.get());
655
689
 
656
690
  all_pres[omp_get_thread_num()] = &pres;
657
691
 
658
- auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult &qres)
659
- {
660
-
661
- idx_t key = idx[i * nprobe + ik]; /* select the list */
662
- if (key < 0) return;
663
- FAISS_THROW_IF_NOT_FMT (
664
- key < (idx_t) nlist,
692
+ auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult& qres) {
693
+ idx_t key = assign[i * nprobe + ik]; /* select the list */
694
+ if (key < 0)
695
+ return;
696
+ FAISS_THROW_IF_NOT_FMT(
697
+ key < (idx_t)nlist,
665
698
  "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
666
- key, ik, nlist);
699
+ key,
700
+ ik,
701
+ nlist);
667
702
  const size_t list_size = invlists->list_size(key);
668
703
 
669
- if (list_size == 0) return;
704
+ if (list_size == 0)
705
+ return;
670
706
 
671
- InvertedLists::ScopedCodes scodes (invlists, key);
672
- InvertedLists::ScopedIds ids (invlists, key);
707
+ InvertedLists::ScopedCodes scodes(invlists, key);
708
+ InvertedLists::ScopedIds ids(invlists, key);
673
709
 
674
- scanner->set_list (key, coarse_dis[i * nprobe + ik]);
710
+ scanner->set_list(key, assign[i * nprobe + ik]);
675
711
  nlistv++;
676
712
  ndis += list_size;
677
- scanner->scan_codes_range (list_size, scodes.get(),
678
- ids.get(), radius, qres);
713
+ scanner->scan_codes_range(
714
+ list_size, scodes.get(), ids.get(), radius, qres);
679
715
  };
680
716
 
681
717
  #pragma omp for
682
718
  for (idx_t i = 0; i < n; i++) {
683
- scanner->set_query (x + i * code_size);
719
+ scanner->set_query(x + i * code_size);
684
720
 
685
- RangeQueryResult & qres = pres.new_result (i);
721
+ RangeQueryResult& qres = pres.new_result(i);
686
722
 
687
723
  for (size_t ik = 0; ik < nprobe; ik++) {
688
- scan_list_func (i, ik, qres);
724
+ scan_list_func(i, ik, qres);
689
725
  }
690
-
691
726
  }
692
727
 
693
728
  pres.finalize();
694
-
695
729
  }
696
730
  indexIVF_stats.nq += n;
697
731
  indexIVF_stats.nlist += nlistv;
698
732
  indexIVF_stats.ndis += ndis;
699
- indexIVF_stats.search_time += getmillisecs() - t0;
700
-
701
733
  }
702
734
 
703
-
704
-
705
-
706
735
  IndexBinaryIVF::~IndexBinaryIVF() {
707
- if (own_invlists) {
708
- delete invlists;
709
- }
736
+ if (own_invlists) {
737
+ delete invlists;
738
+ }
710
739
 
711
- if (own_fields) {
712
- delete quantizer;
713
- }
740
+ if (own_fields) {
741
+ delete quantizer;
742
+ }
714
743
  }
715
744
 
716
-
717
- } // namespace faiss
745
+ } // namespace faiss