faiss 0.1.7 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +18 -0
  3. data/README.md +7 -7
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +8 -2
  6. data/ext/faiss/index.cpp +102 -69
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +0 -5
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +26 -12
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -9,318 +9,344 @@
9
9
 
10
10
  #include <faiss/IndexBinaryIVF.h>
11
11
 
12
+ #include <omp.h>
12
13
  #include <cinttypes>
13
14
  #include <cstdio>
14
- #include <omp.h>
15
15
 
16
+ #include <algorithm>
16
17
  #include <memory>
17
18
 
18
-
19
- #include <faiss/utils/hamming.h>
20
- #include <faiss/utils/utils.h>
21
- #include <faiss/impl/AuxIndexStructures.h>
22
- #include <faiss/impl/FaissAssert.h>
23
19
  #include <faiss/IndexFlat.h>
24
20
  #include <faiss/IndexLSH.h>
25
-
21
+ #include <faiss/impl/AuxIndexStructures.h>
22
+ #include <faiss/impl/FaissAssert.h>
23
+ #include <faiss/utils/hamming.h>
24
+ #include <faiss/utils/utils.h>
26
25
 
27
26
  namespace faiss {
28
27
 
29
- IndexBinaryIVF::IndexBinaryIVF(IndexBinary *quantizer, size_t d, size_t nlist)
30
- : IndexBinary(d),
31
- invlists(new ArrayInvertedLists(nlist, code_size)),
32
- own_invlists(true),
33
- nprobe(1),
34
- max_codes(0),
35
- quantizer(quantizer),
36
- nlist(nlist),
37
- own_fields(false),
38
- clustering_index(nullptr)
39
- {
40
- FAISS_THROW_IF_NOT (d == quantizer->d);
41
- is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
42
-
43
- cp.niter = 10;
28
+ IndexBinaryIVF::IndexBinaryIVF(IndexBinary* quantizer, size_t d, size_t nlist)
29
+ : IndexBinary(d),
30
+ invlists(new ArrayInvertedLists(nlist, code_size)),
31
+ own_invlists(true),
32
+ nprobe(1),
33
+ max_codes(0),
34
+ quantizer(quantizer),
35
+ nlist(nlist),
36
+ own_fields(false),
37
+ clustering_index(nullptr) {
38
+ FAISS_THROW_IF_NOT(d == quantizer->d);
39
+ is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
40
+
41
+ cp.niter = 10;
44
42
  }
45
43
 
46
44
  IndexBinaryIVF::IndexBinaryIVF()
47
- : invlists(nullptr),
48
- own_invlists(false),
49
- nprobe(1),
50
- max_codes(0),
51
- quantizer(nullptr),
52
- nlist(0),
53
- own_fields(false),
54
- clustering_index(nullptr)
55
- {}
56
-
57
- void IndexBinaryIVF::add(idx_t n, const uint8_t *x) {
58
- add_with_ids(n, x, nullptr);
45
+ : invlists(nullptr),
46
+ own_invlists(false),
47
+ nprobe(1),
48
+ max_codes(0),
49
+ quantizer(nullptr),
50
+ nlist(0),
51
+ own_fields(false),
52
+ clustering_index(nullptr) {}
53
+
54
+ void IndexBinaryIVF::add(idx_t n, const uint8_t* x) {
55
+ add_with_ids(n, x, nullptr);
59
56
  }
60
57
 
61
- void IndexBinaryIVF::add_with_ids(idx_t n, const uint8_t *x, const idx_t *xids) {
62
- add_core(n, x, xids, nullptr);
58
+ void IndexBinaryIVF::add_with_ids(
59
+ idx_t n,
60
+ const uint8_t* x,
61
+ const idx_t* xids) {
62
+ add_core(n, x, xids, nullptr);
63
63
  }
64
64
 
65
- void IndexBinaryIVF::add_core(idx_t n, const uint8_t *x, const idx_t *xids,
66
- const idx_t *precomputed_idx) {
67
- FAISS_THROW_IF_NOT(is_trained);
68
- assert(invlists);
69
- direct_map.check_can_add (xids);
65
+ void IndexBinaryIVF::add_core(
66
+ idx_t n,
67
+ const uint8_t* x,
68
+ const idx_t* xids,
69
+ const idx_t* precomputed_idx) {
70
+ FAISS_THROW_IF_NOT(is_trained);
71
+ assert(invlists);
72
+ direct_map.check_can_add(xids);
70
73
 
71
- const idx_t * idx;
74
+ const idx_t* idx;
72
75
 
73
- std::unique_ptr<idx_t[]> scoped_idx;
76
+ std::unique_ptr<idx_t[]> scoped_idx;
74
77
 
75
- if (precomputed_idx) {
76
- idx = precomputed_idx;
77
- } else {
78
- scoped_idx.reset(new idx_t[n]);
79
- quantizer->assign(n, x, scoped_idx.get());
80
- idx = scoped_idx.get();
81
- }
78
+ if (precomputed_idx) {
79
+ idx = precomputed_idx;
80
+ } else {
81
+ scoped_idx.reset(new idx_t[n]);
82
+ quantizer->assign(n, x, scoped_idx.get());
83
+ idx = scoped_idx.get();
84
+ }
82
85
 
83
- long n_add = 0;
84
- for (size_t i = 0; i < n; i++) {
85
- idx_t id = xids ? xids[i] : ntotal + i;
86
- idx_t list_no = idx[i];
86
+ idx_t n_add = 0;
87
+ for (size_t i = 0; i < n; i++) {
88
+ idx_t id = xids ? xids[i] : ntotal + i;
89
+ idx_t list_no = idx[i];
87
90
 
88
- if (list_no < 0) {
89
- direct_map.add_single_id (id, -1, 0);
90
- } else {
91
- const uint8_t *xi = x + i * code_size;
92
- size_t offset = invlists->add_entry(list_no, id, xi);
91
+ if (list_no < 0) {
92
+ direct_map.add_single_id(id, -1, 0);
93
+ } else {
94
+ const uint8_t* xi = x + i * code_size;
95
+ size_t offset = invlists->add_entry(list_no, id, xi);
93
96
 
94
- direct_map.add_single_id (id, list_no, offset);
95
- }
97
+ direct_map.add_single_id(id, list_no, offset);
98
+ }
96
99
 
97
- n_add++;
98
- }
99
- if (verbose) {
100
- printf("IndexBinaryIVF::add_with_ids: added %ld / %" PRId64 " vectors\n",
101
- n_add, n);
102
- }
103
- ntotal += n_add;
100
+ n_add++;
101
+ }
102
+ if (verbose) {
103
+ printf("IndexBinaryIVF::add_with_ids: added "
104
+ "%" PRId64 " / %" PRId64 " vectors\n",
105
+ n_add,
106
+ n);
107
+ }
108
+ ntotal += n_add;
104
109
  }
105
110
 
106
- void IndexBinaryIVF::make_direct_map (bool b)
107
- {
111
+ void IndexBinaryIVF::make_direct_map(bool b) {
108
112
  if (b) {
109
- direct_map.set_type (DirectMap::Array, invlists, ntotal);
113
+ direct_map.set_type(DirectMap::Array, invlists, ntotal);
110
114
  } else {
111
- direct_map.set_type (DirectMap::NoMap, invlists, ntotal);
115
+ direct_map.set_type(DirectMap::NoMap, invlists, ntotal);
112
116
  }
113
117
  }
114
118
 
115
- void IndexBinaryIVF::set_direct_map_type (DirectMap::Type type)
116
- {
117
- direct_map.set_type (type, invlists, ntotal);
119
+ void IndexBinaryIVF::set_direct_map_type(DirectMap::Type type) {
120
+ direct_map.set_type(type, invlists, ntotal);
118
121
  }
119
122
 
123
+ void IndexBinaryIVF::search(
124
+ idx_t n,
125
+ const uint8_t* x,
126
+ idx_t k,
127
+ int32_t* distances,
128
+ idx_t* labels) const {
129
+ FAISS_THROW_IF_NOT(k > 0);
130
+ FAISS_THROW_IF_NOT(nprobe > 0);
120
131
 
121
- void IndexBinaryIVF::search(idx_t n, const uint8_t *x, idx_t k,
122
- int32_t *distances, idx_t *labels) const {
123
- std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
124
- std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
132
+ const size_t nprobe = std::min(nlist, this->nprobe);
133
+ std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
134
+ std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
125
135
 
126
- double t0 = getmillisecs();
127
- quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
128
- indexIVF_stats.quantization_time += getmillisecs() - t0;
136
+ double t0 = getmillisecs();
137
+ quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
138
+ indexIVF_stats.quantization_time += getmillisecs() - t0;
129
139
 
130
- t0 = getmillisecs();
131
- invlists->prefetch_lists(idx.get(), n * nprobe);
140
+ t0 = getmillisecs();
141
+ invlists->prefetch_lists(idx.get(), n * nprobe);
132
142
 
133
- search_preassigned(n, x, k, idx.get(), coarse_dis.get(),
134
- distances, labels, false);
135
- indexIVF_stats.search_time += getmillisecs() - t0;
143
+ search_preassigned(
144
+ n, x, k, idx.get(), coarse_dis.get(), distances, labels, false);
145
+ indexIVF_stats.search_time += getmillisecs() - t0;
136
146
  }
137
147
 
138
- void IndexBinaryIVF::reconstruct(idx_t key, uint8_t *recons) const {
139
- idx_t lo = direct_map.get (key);
140
- reconstruct_from_offset (lo_listno(lo), lo_offset(lo), recons);
148
+ void IndexBinaryIVF::reconstruct(idx_t key, uint8_t* recons) const {
149
+ idx_t lo = direct_map.get(key);
150
+ reconstruct_from_offset(lo_listno(lo), lo_offset(lo), recons);
141
151
  }
142
152
 
143
- void IndexBinaryIVF::reconstruct_n(idx_t i0, idx_t ni, uint8_t *recons) const {
144
- FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
153
+ void IndexBinaryIVF::reconstruct_n(idx_t i0, idx_t ni, uint8_t* recons) const {
154
+ FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
145
155
 
146
- for (idx_t list_no = 0; list_no < nlist; list_no++) {
147
- size_t list_size = invlists->list_size(list_no);
148
- const Index::idx_t *idlist = invlists->get_ids(list_no);
156
+ for (idx_t list_no = 0; list_no < nlist; list_no++) {
157
+ size_t list_size = invlists->list_size(list_no);
158
+ const Index::idx_t* idlist = invlists->get_ids(list_no);
149
159
 
150
- for (idx_t offset = 0; offset < list_size; offset++) {
151
- idx_t id = idlist[offset];
152
- if (!(id >= i0 && id < i0 + ni)) {
153
- continue;
154
- }
160
+ for (idx_t offset = 0; offset < list_size; offset++) {
161
+ idx_t id = idlist[offset];
162
+ if (!(id >= i0 && id < i0 + ni)) {
163
+ continue;
164
+ }
155
165
 
156
- uint8_t *reconstructed = recons + (id - i0) * d;
157
- reconstruct_from_offset(list_no, offset, reconstructed);
166
+ uint8_t* reconstructed = recons + (id - i0) * d;
167
+ reconstruct_from_offset(list_no, offset, reconstructed);
168
+ }
158
169
  }
159
- }
160
170
  }
161
171
 
162
- void IndexBinaryIVF::search_and_reconstruct(idx_t n, const uint8_t *x, idx_t k,
163
- int32_t *distances, idx_t *labels,
164
- uint8_t *recons) const {
165
- std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
166
- std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
167
-
168
- quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
169
-
170
- invlists->prefetch_lists(idx.get(), n * nprobe);
171
-
172
- // search_preassigned() with `store_pairs` enabled to obtain the list_no
173
- // and offset into `codes` for reconstruction
174
- search_preassigned(n, x, k, idx.get(), coarse_dis.get(),
175
- distances, labels, /* store_pairs */true);
176
- for (idx_t i = 0; i < n; ++i) {
177
- for (idx_t j = 0; j < k; ++j) {
178
- idx_t ij = i * k + j;
179
- idx_t key = labels[ij];
180
- uint8_t *reconstructed = recons + ij * d;
181
- if (key < 0) {
182
- // Fill with NaNs
183
- memset(reconstructed, -1, sizeof(*reconstructed) * d);
184
- } else {
185
- int list_no = key >> 32;
186
- int offset = key & 0xffffffff;
187
-
188
- // Update label to the actual id
189
- labels[ij] = invlists->get_single_id(list_no, offset);
190
-
191
- reconstruct_from_offset(list_no, offset, reconstructed);
192
- }
172
+ void IndexBinaryIVF::search_and_reconstruct(
173
+ idx_t n,
174
+ const uint8_t* x,
175
+ idx_t k,
176
+ int32_t* distances,
177
+ idx_t* labels,
178
+ uint8_t* recons) const {
179
+ const size_t nprobe = std::min(nlist, this->nprobe);
180
+ FAISS_THROW_IF_NOT(k > 0);
181
+ FAISS_THROW_IF_NOT(nprobe > 0);
182
+
183
+ std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
184
+ std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
185
+
186
+ quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
187
+
188
+ invlists->prefetch_lists(idx.get(), n * nprobe);
189
+
190
+ // search_preassigned() with `store_pairs` enabled to obtain the list_no
191
+ // and offset into `codes` for reconstruction
192
+ search_preassigned(
193
+ n,
194
+ x,
195
+ k,
196
+ idx.get(),
197
+ coarse_dis.get(),
198
+ distances,
199
+ labels,
200
+ /* store_pairs */ true);
201
+ for (idx_t i = 0; i < n; ++i) {
202
+ for (idx_t j = 0; j < k; ++j) {
203
+ idx_t ij = i * k + j;
204
+ idx_t key = labels[ij];
205
+ uint8_t* reconstructed = recons + ij * d;
206
+ if (key < 0) {
207
+ // Fill with NaNs
208
+ memset(reconstructed, -1, sizeof(*reconstructed) * d);
209
+ } else {
210
+ int list_no = key >> 32;
211
+ int offset = key & 0xffffffff;
212
+
213
+ // Update label to the actual id
214
+ labels[ij] = invlists->get_single_id(list_no, offset);
215
+
216
+ reconstruct_from_offset(list_no, offset, reconstructed);
217
+ }
218
+ }
193
219
  }
194
- }
195
220
  }
196
221
 
197
- void IndexBinaryIVF::reconstruct_from_offset(idx_t list_no, idx_t offset,
198
- uint8_t *recons) const {
199
- memcpy(recons, invlists->get_single_code(list_no, offset), code_size);
222
+ void IndexBinaryIVF::reconstruct_from_offset(
223
+ idx_t list_no,
224
+ idx_t offset,
225
+ uint8_t* recons) const {
226
+ memcpy(recons, invlists->get_single_code(list_no, offset), code_size);
200
227
  }
201
228
 
202
229
  void IndexBinaryIVF::reset() {
203
- direct_map.clear();
204
- invlists->reset();
205
- ntotal = 0;
230
+ direct_map.clear();
231
+ invlists->reset();
232
+ ntotal = 0;
206
233
  }
207
234
 
208
235
  size_t IndexBinaryIVF::remove_ids(const IDSelector& sel) {
209
- size_t nremove = direct_map.remove_ids (sel, invlists);
236
+ size_t nremove = direct_map.remove_ids(sel, invlists);
210
237
  ntotal -= nremove;
211
238
  return nremove;
212
239
  }
213
240
 
214
- void IndexBinaryIVF::train(idx_t n, const uint8_t *x) {
215
- if (verbose) {
216
- printf("Training quantizer\n");
217
- }
218
-
219
- if (quantizer->is_trained && (quantizer->ntotal == nlist)) {
220
- if (verbose) {
221
- printf("IVF quantizer does not need training.\n");
222
- }
223
- } else {
241
+ void IndexBinaryIVF::train(idx_t n, const uint8_t* x) {
224
242
  if (verbose) {
225
- printf("Training quantizer on %" PRId64 " vectors in %dD\n", n, d);
243
+ printf("Training quantizer\n");
226
244
  }
227
245
 
228
- Clustering clus(d, nlist, cp);
229
- quantizer->reset();
230
-
231
- IndexFlatL2 index_tmp(d);
232
-
233
- if (clustering_index && verbose) {
234
- printf("using clustering_index of dimension %d to do the clustering\n",
235
- clustering_index->d);
236
- }
246
+ if (quantizer->is_trained && (quantizer->ntotal == nlist)) {
247
+ if (verbose) {
248
+ printf("IVF quantizer does not need training.\n");
249
+ }
250
+ } else {
251
+ if (verbose) {
252
+ printf("Training quantizer on %" PRId64 " vectors in %dD\n", n, d);
253
+ }
237
254
 
238
- // LSH codec that is able to convert the binary vectors to floats.
239
- IndexLSH codec(d, d, false, false);
255
+ Clustering clus(d, nlist, cp);
256
+ quantizer->reset();
240
257
 
241
- clus.train_encoded (n, x, &codec, clustering_index ? *clustering_index : index_tmp);
258
+ IndexFlatL2 index_tmp(d);
242
259
 
243
- // convert clusters to binary
244
- std::unique_ptr<uint8_t[]> x_b(new uint8_t[clus.k * code_size]);
245
- real_to_binary(d * clus.k, clus.centroids.data(), x_b.get());
260
+ if (clustering_index && verbose) {
261
+ printf("using clustering_index of dimension %d to do the clustering\n",
262
+ clustering_index->d);
263
+ }
246
264
 
247
- quantizer->add(clus.k, x_b.get());
248
- quantizer->is_trained = true;
249
- }
265
+ // LSH codec that is able to convert the binary vectors to floats.
266
+ IndexLSH codec(d, d, false, false);
250
267
 
251
- is_trained = true;
252
- }
268
+ clus.train_encoded(
269
+ n, x, &codec, clustering_index ? *clustering_index : index_tmp);
253
270
 
254
- void IndexBinaryIVF::merge_from(IndexBinaryIVF &other, idx_t add_id) {
255
- // minimal sanity checks
256
- FAISS_THROW_IF_NOT(other.d == d);
257
- FAISS_THROW_IF_NOT(other.nlist == nlist);
258
- FAISS_THROW_IF_NOT(other.code_size == code_size);
259
- FAISS_THROW_IF_NOT_MSG(direct_map.no() && other.direct_map.no(),
260
- "direct map copy not implemented");
261
- FAISS_THROW_IF_NOT_MSG(typeid (*this) == typeid (other),
262
- "can only merge indexes of the same type");
271
+ // convert clusters to binary
272
+ std::unique_ptr<uint8_t[]> x_b(new uint8_t[clus.k * code_size]);
273
+ real_to_binary(d * clus.k, clus.centroids.data(), x_b.get());
263
274
 
264
- invlists->merge_from (other.invlists, add_id);
275
+ quantizer->add(clus.k, x_b.get());
276
+ quantizer->is_trained = true;
277
+ }
265
278
 
266
- ntotal += other.ntotal;
267
- other.ntotal = 0;
279
+ is_trained = true;
268
280
  }
269
281
 
270
- void IndexBinaryIVF::replace_invlists(InvertedLists *il, bool own) {
271
- FAISS_THROW_IF_NOT(il->nlist == nlist &&
272
- il->code_size == code_size);
273
- if (own_invlists) {
274
- delete invlists;
275
- }
276
- invlists = il;
277
- own_invlists = own;
282
+ void IndexBinaryIVF::merge_from(IndexBinaryIVF& other, idx_t add_id) {
283
+ // minimal sanity checks
284
+ FAISS_THROW_IF_NOT(other.d == d);
285
+ FAISS_THROW_IF_NOT(other.nlist == nlist);
286
+ FAISS_THROW_IF_NOT(other.code_size == code_size);
287
+ FAISS_THROW_IF_NOT_MSG(
288
+ direct_map.no() && other.direct_map.no(),
289
+ "direct map copy not implemented");
290
+ FAISS_THROW_IF_NOT_MSG(
291
+ typeid(*this) == typeid(other),
292
+ "can only merge indexes of the same type");
293
+
294
+ invlists->merge_from(other.invlists, add_id);
295
+
296
+ ntotal += other.ntotal;
297
+ other.ntotal = 0;
278
298
  }
279
299
 
300
+ void IndexBinaryIVF::replace_invlists(InvertedLists* il, bool own) {
301
+ FAISS_THROW_IF_NOT(il->nlist == nlist && il->code_size == code_size);
302
+ if (own_invlists) {
303
+ delete invlists;
304
+ }
305
+ invlists = il;
306
+ own_invlists = own;
307
+ }
280
308
 
281
309
  namespace {
282
310
 
283
311
  using idx_t = Index::idx_t;
284
312
 
285
-
286
- template<class HammingComputer>
287
- struct IVFBinaryScannerL2: BinaryInvertedListScanner {
288
-
313
+ template <class HammingComputer>
314
+ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
289
315
  HammingComputer hc;
290
316
  size_t code_size;
291
317
  bool store_pairs;
292
318
 
293
- IVFBinaryScannerL2 (size_t code_size, bool store_pairs):
294
- code_size (code_size), store_pairs(store_pairs)
295
- {}
319
+ IVFBinaryScannerL2(size_t code_size, bool store_pairs)
320
+ : code_size(code_size), store_pairs(store_pairs) {}
296
321
 
297
- void set_query (const uint8_t *query_vector) override {
298
- hc.set (query_vector, code_size);
322
+ void set_query(const uint8_t* query_vector) override {
323
+ hc.set(query_vector, code_size);
299
324
  }
300
325
 
301
326
  idx_t list_no;
302
- void set_list (idx_t list_no, uint8_t /* coarse_dis */) override {
327
+ void set_list(idx_t list_no, uint8_t /* coarse_dis */) override {
303
328
  this->list_no = list_no;
304
329
  }
305
330
 
306
- uint32_t distance_to_code (const uint8_t *code) const override {
307
- return hc.hamming (code);
331
+ uint32_t distance_to_code(const uint8_t* code) const override {
332
+ return hc.hamming(code);
308
333
  }
309
334
 
310
- size_t scan_codes (size_t n,
311
- const uint8_t *codes,
312
- const idx_t *ids,
313
- int32_t *simi, idx_t *idxi,
314
- size_t k) const override
315
- {
335
+ size_t scan_codes(
336
+ size_t n,
337
+ const uint8_t* codes,
338
+ const idx_t* ids,
339
+ int32_t* simi,
340
+ idx_t* idxi,
341
+ size_t k) const override {
316
342
  using C = CMax<int32_t, idx_t>;
317
343
 
318
344
  size_t nup = 0;
319
345
  for (size_t j = 0; j < n; j++) {
320
- uint32_t dis = hc.hamming (codes);
346
+ uint32_t dis = hc.hamming(codes);
321
347
  if (dis < simi[0]) {
322
348
  idx_t id = store_pairs ? lo_build(list_no, j) : ids[j];
323
- heap_replace_top<C> (k, simi, idxi, dis, id);
349
+ heap_replace_top<C>(k, simi, idxi, dis, id);
324
350
  nup++;
325
351
  }
326
352
  codes += code_size;
@@ -328,40 +354,38 @@ struct IVFBinaryScannerL2: BinaryInvertedListScanner {
328
354
  return nup;
329
355
  }
330
356
 
331
- void scan_codes_range (size_t n,
332
- const uint8_t *codes,
333
- const idx_t *ids,
334
- int radius,
335
- RangeQueryResult &result) const override
336
- {
357
+ void scan_codes_range(
358
+ size_t n,
359
+ const uint8_t* codes,
360
+ const idx_t* ids,
361
+ int radius,
362
+ RangeQueryResult& result) const override {
337
363
  size_t nup = 0;
338
364
  for (size_t j = 0; j < n; j++) {
339
- uint32_t dis = hc.hamming (codes);
365
+ uint32_t dis = hc.hamming(codes);
340
366
  if (dis < radius) {
341
- int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
342
- result.add (dis, id);
367
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
368
+ result.add(dis, id);
343
369
  }
344
370
  codes += code_size;
345
371
  }
346
-
347
372
  }
348
-
349
-
350
373
  };
351
374
 
352
-
353
- void search_knn_hamming_heap(const IndexBinaryIVF& ivf,
354
- size_t n,
355
- const uint8_t *x,
356
- idx_t k,
357
- const idx_t *keys,
358
- const int32_t * coarse_dis,
359
- int32_t *distances, idx_t *labels,
360
- bool store_pairs,
361
- const IVFSearchParameters *params)
362
- {
363
- long nprobe = params ? params->nprobe : ivf.nprobe;
364
- long max_codes = params ? params->max_codes : ivf.max_codes;
375
+ void search_knn_hamming_heap(
376
+ const IndexBinaryIVF& ivf,
377
+ size_t n,
378
+ const uint8_t* x,
379
+ idx_t k,
380
+ const idx_t* keys,
381
+ const int32_t* coarse_dis,
382
+ int32_t* distances,
383
+ idx_t* labels,
384
+ bool store_pairs,
385
+ const IVFSearchParameters* params) {
386
+ idx_t nprobe = params ? params->nprobe : ivf.nprobe;
387
+ nprobe = std::min((idx_t)ivf.nlist, nprobe);
388
+ idx_t max_codes = params ? params->max_codes : ivf.max_codes;
365
389
  MetricType metric_type = ivf.metric_type;
366
390
 
367
391
  // almost verbatim copy from IndexIVF::search_preassigned
@@ -370,57 +394,57 @@ void search_knn_hamming_heap(const IndexBinaryIVF& ivf,
370
394
  using HeapForIP = CMin<int32_t, idx_t>;
371
395
  using HeapForL2 = CMax<int32_t, idx_t>;
372
396
 
373
- #pragma omp parallel if(n > 1) reduction(+: nlistv, ndis, nheap)
397
+ #pragma omp parallel if (n > 1) reduction(+ : nlistv, ndis, nheap)
374
398
  {
375
- std::unique_ptr<BinaryInvertedListScanner> scanner
376
- (ivf.get_InvertedListScanner (store_pairs));
399
+ std::unique_ptr<BinaryInvertedListScanner> scanner(
400
+ ivf.get_InvertedListScanner(store_pairs));
377
401
 
378
402
  #pragma omp for
379
403
  for (idx_t i = 0; i < n; i++) {
380
- const uint8_t *xi = x + i * ivf.code_size;
404
+ const uint8_t* xi = x + i * ivf.code_size;
381
405
  scanner->set_query(xi);
382
406
 
383
- const idx_t * keysi = keys + i * nprobe;
384
- int32_t * simi = distances + k * i;
385
- idx_t * idxi = labels + k * i;
407
+ const idx_t* keysi = keys + i * nprobe;
408
+ int32_t* simi = distances + k * i;
409
+ idx_t* idxi = labels + k * i;
386
410
 
387
411
  if (metric_type == METRIC_INNER_PRODUCT) {
388
- heap_heapify<HeapForIP> (k, simi, idxi);
412
+ heap_heapify<HeapForIP>(k, simi, idxi);
389
413
  } else {
390
- heap_heapify<HeapForL2> (k, simi, idxi);
414
+ heap_heapify<HeapForL2>(k, simi, idxi);
391
415
  }
392
416
 
393
417
  size_t nscan = 0;
394
418
 
395
419
  for (size_t ik = 0; ik < nprobe; ik++) {
396
- idx_t key = keysi[ik]; /* select the list */
420
+ idx_t key = keysi[ik]; /* select the list */
397
421
  if (key < 0) {
398
422
  // not enough centroids for multiprobe
399
423
  continue;
400
424
  }
401
- FAISS_THROW_IF_NOT_FMT
402
- (key < (idx_t) ivf.nlist,
403
- "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
404
- key, ik, ivf.nlist);
425
+ FAISS_THROW_IF_NOT_FMT(
426
+ key < (idx_t)ivf.nlist,
427
+ "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
428
+ key,
429
+ ik,
430
+ ivf.nlist);
405
431
 
406
- scanner->set_list (key, coarse_dis[i * nprobe + ik]);
432
+ scanner->set_list(key, coarse_dis[i * nprobe + ik]);
407
433
 
408
434
  nlistv++;
409
435
 
410
436
  size_t list_size = ivf.invlists->list_size(key);
411
- InvertedLists::ScopedCodes scodes (ivf.invlists, key);
437
+ InvertedLists::ScopedCodes scodes(ivf.invlists, key);
412
438
  std::unique_ptr<InvertedLists::ScopedIds> sids;
413
- const Index::idx_t * ids = nullptr;
439
+ const Index::idx_t* ids = nullptr;
414
440
 
415
441
  if (!store_pairs) {
416
- sids.reset (new InvertedLists::ScopedIds (ivf.invlists, key));
442
+ sids.reset(new InvertedLists::ScopedIds(ivf.invlists, key));
417
443
  ids = sids->get();
418
444
  }
419
445
 
420
- nheap += scanner->scan_codes (
421
- list_size, scodes.get(),
422
- ids, simi, idxi, k
423
- );
446
+ nheap += scanner->scan_codes(
447
+ list_size, scodes.get(), ids, simi, idxi, k);
424
448
 
425
449
  nscan += list_size;
426
450
  if (max_codes && nscan >= max_codes)
@@ -429,208 +453,205 @@ void search_knn_hamming_heap(const IndexBinaryIVF& ivf,
429
453
 
430
454
  ndis += nscan;
431
455
  if (metric_type == METRIC_INNER_PRODUCT) {
432
- heap_reorder<HeapForIP> (k, simi, idxi);
456
+ heap_reorder<HeapForIP>(k, simi, idxi);
433
457
  } else {
434
- heap_reorder<HeapForL2> (k, simi, idxi);
458
+ heap_reorder<HeapForL2>(k, simi, idxi);
435
459
  }
436
460
 
437
461
  } // parallel for
438
- } // parallel
462
+ } // parallel
439
463
 
440
464
  indexIVF_stats.nq += n;
441
465
  indexIVF_stats.nlist += nlistv;
442
466
  indexIVF_stats.ndis += ndis;
443
467
  indexIVF_stats.nheap_updates += nheap;
444
-
445
468
  }
446
469
 
447
- template<class HammingComputer, bool store_pairs>
448
- void search_knn_hamming_count(const IndexBinaryIVF& ivf,
449
- size_t nx,
450
- const uint8_t *x,
451
- const idx_t *keys,
452
- int k,
453
- int32_t *distances,
454
- idx_t *labels,
455
- const IVFSearchParameters *params) {
456
- const int nBuckets = ivf.d + 1;
457
- std::vector<int> all_counters(nx * nBuckets, 0);
458
- std::unique_ptr<idx_t[]> all_ids_per_dis(new idx_t[nx * nBuckets * k]);
459
-
460
- long nprobe = params ? params->nprobe : ivf.nprobe;
461
- long max_codes = params ? params->max_codes : ivf.max_codes;
462
-
463
- std::vector<HCounterState<HammingComputer>> cs;
464
- for (size_t i = 0; i < nx; ++i) {
465
- cs.push_back(HCounterState<HammingComputer>(
466
- all_counters.data() + i * nBuckets,
467
- all_ids_per_dis.get() + i * nBuckets * k,
468
- x + i * ivf.code_size,
469
- ivf.d,
470
- k
471
- ));
472
- }
473
-
474
- size_t nlistv = 0, ndis = 0;
475
-
476
- #pragma omp parallel for reduction(+: nlistv, ndis)
477
- for (int64_t i = 0; i < nx; i++) {
478
- const idx_t * keysi = keys + i * nprobe;
479
- HCounterState<HammingComputer>& csi = cs[i];
480
-
481
- size_t nscan = 0;
482
-
483
- for (size_t ik = 0; ik < nprobe; ik++) {
484
- idx_t key = keysi[ik]; /* select the list */
485
- if (key < 0) {
486
- // not enough centroids for multiprobe
487
- continue;
488
- }
489
- FAISS_THROW_IF_NOT_FMT (
490
- key < (idx_t) ivf.nlist,
491
- "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
492
- key, ik, ivf.nlist);
493
-
494
- nlistv++;
495
- size_t list_size = ivf.invlists->list_size(key);
496
- InvertedLists::ScopedCodes scodes (ivf.invlists, key);
497
- const uint8_t *list_vecs = scodes.get();
498
- const Index::idx_t *ids = store_pairs
499
- ? nullptr
500
- : ivf.invlists->get_ids(key);
501
-
502
- for (size_t j = 0; j < list_size; j++) {
503
- const uint8_t * yj = list_vecs + ivf.code_size * j;
504
-
505
- idx_t id = store_pairs ? (key << 32 | j) : ids[j];
506
- csi.update_counter(yj, id);
507
- }
508
- if (ids)
509
- ivf.invlists->release_ids (key, ids);
510
-
511
- nscan += list_size;
512
- if (max_codes && nscan >= max_codes)
513
- break;
514
- }
515
- ndis += nscan;
516
-
517
- int nres = 0;
518
- for (int b = 0; b < nBuckets && nres < k; b++) {
519
- for (int l = 0; l < csi.counters[b] && nres < k; l++) {
520
- labels[i * k + nres] = csi.ids_per_dis[b * k + l];
521
- distances[i * k + nres] = b;
522
- nres++;
523
- }
524
- }
525
- while (nres < k) {
526
- labels[i * k + nres] = -1;
527
- distances[i * k + nres] = std::numeric_limits<int32_t>::max();
528
- ++nres;
470
+ template <class HammingComputer, bool store_pairs>
471
+ void search_knn_hamming_count(
472
+ const IndexBinaryIVF& ivf,
473
+ size_t nx,
474
+ const uint8_t* x,
475
+ const idx_t* keys,
476
+ int k,
477
+ int32_t* distances,
478
+ idx_t* labels,
479
+ const IVFSearchParameters* params) {
480
+ const int nBuckets = ivf.d + 1;
481
+ std::vector<int> all_counters(nx * nBuckets, 0);
482
+ std::unique_ptr<idx_t[]> all_ids_per_dis(new idx_t[nx * nBuckets * k]);
483
+
484
+ idx_t nprobe = params ? params->nprobe : ivf.nprobe;
485
+ nprobe = std::min((idx_t)ivf.nlist, nprobe);
486
+ idx_t max_codes = params ? params->max_codes : ivf.max_codes;
487
+
488
+ std::vector<HCounterState<HammingComputer>> cs;
489
+ for (size_t i = 0; i < nx; ++i) {
490
+ cs.push_back(HCounterState<HammingComputer>(
491
+ all_counters.data() + i * nBuckets,
492
+ all_ids_per_dis.get() + i * nBuckets * k,
493
+ x + i * ivf.code_size,
494
+ ivf.d,
495
+ k));
529
496
  }
530
- }
531
497
 
532
- indexIVF_stats.nq += nx;
533
- indexIVF_stats.nlist += nlistv;
534
- indexIVF_stats.ndis += ndis;
535
- }
498
+ size_t nlistv = 0, ndis = 0;
536
499
 
500
+ #pragma omp parallel for reduction(+ : nlistv, ndis)
501
+ for (int64_t i = 0; i < nx; i++) {
502
+ const idx_t* keysi = keys + i * nprobe;
503
+ HCounterState<HammingComputer>& csi = cs[i];
537
504
 
505
+ size_t nscan = 0;
538
506
 
539
- template<bool store_pairs>
540
- void search_knn_hamming_count_1 (
541
- const IndexBinaryIVF& ivf,
542
- size_t nx,
543
- const uint8_t *x,
544
- const idx_t *keys,
545
- int k,
546
- int32_t *distances,
547
- idx_t *labels,
548
- const IVFSearchParameters *params) {
549
- switch (ivf.code_size) {
550
- #define HANDLE_CS(cs) \
551
- case cs: \
552
- search_knn_hamming_count<HammingComputer ## cs, store_pairs>( \
553
- ivf, nx, x, keys, k, distances, labels, params); \
554
- break;
555
- HANDLE_CS(4);
556
- HANDLE_CS(8);
557
- HANDLE_CS(16);
558
- HANDLE_CS(20);
559
- HANDLE_CS(32);
560
- HANDLE_CS(64);
561
- #undef HANDLE_CS
562
- default:
563
- if (ivf.code_size % 8 == 0) {
564
- search_knn_hamming_count<HammingComputerM8, store_pairs>
565
- (ivf, nx, x, keys, k, distances, labels, params);
566
- } else if (ivf.code_size % 4 == 0) {
567
- search_knn_hamming_count<HammingComputerM4, store_pairs>
568
- (ivf, nx, x, keys, k, distances, labels, params);
569
- } else {
570
- search_knn_hamming_count<HammingComputerDefault, store_pairs>
571
- (ivf, nx, x, keys, k, distances, labels, params);
507
+ for (size_t ik = 0; ik < nprobe; ik++) {
508
+ idx_t key = keysi[ik]; /* select the list */
509
+ if (key < 0) {
510
+ // not enough centroids for multiprobe
511
+ continue;
512
+ }
513
+ FAISS_THROW_IF_NOT_FMT(
514
+ key < (idx_t)ivf.nlist,
515
+ "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
516
+ key,
517
+ ik,
518
+ ivf.nlist);
519
+
520
+ nlistv++;
521
+ size_t list_size = ivf.invlists->list_size(key);
522
+ InvertedLists::ScopedCodes scodes(ivf.invlists, key);
523
+ const uint8_t* list_vecs = scodes.get();
524
+ const Index::idx_t* ids =
525
+ store_pairs ? nullptr : ivf.invlists->get_ids(key);
526
+
527
+ for (size_t j = 0; j < list_size; j++) {
528
+ const uint8_t* yj = list_vecs + ivf.code_size * j;
529
+
530
+ idx_t id = store_pairs ? (key << 32 | j) : ids[j];
531
+ csi.update_counter(yj, id);
532
+ }
533
+ if (ids)
534
+ ivf.invlists->release_ids(key, ids);
535
+
536
+ nscan += list_size;
537
+ if (max_codes && nscan >= max_codes)
538
+ break;
539
+ }
540
+ ndis += nscan;
541
+
542
+ int nres = 0;
543
+ for (int b = 0; b < nBuckets && nres < k; b++) {
544
+ for (int l = 0; l < csi.counters[b] && nres < k; l++) {
545
+ labels[i * k + nres] = csi.ids_per_dis[b * k + l];
546
+ distances[i * k + nres] = b;
547
+ nres++;
548
+ }
549
+ }
550
+ while (nres < k) {
551
+ labels[i * k + nres] = -1;
552
+ distances[i * k + nres] = std::numeric_limits<int32_t>::max();
553
+ ++nres;
572
554
  }
573
- break;
574
555
  }
575
556
 
557
+ indexIVF_stats.nq += nx;
558
+ indexIVF_stats.nlist += nlistv;
559
+ indexIVF_stats.ndis += ndis;
576
560
  }
577
561
 
578
- } // namespace
562
+ template <bool store_pairs>
563
+ void search_knn_hamming_count_1(
564
+ const IndexBinaryIVF& ivf,
565
+ size_t nx,
566
+ const uint8_t* x,
567
+ const idx_t* keys,
568
+ int k,
569
+ int32_t* distances,
570
+ idx_t* labels,
571
+ const IVFSearchParameters* params) {
572
+ switch (ivf.code_size) {
573
+ #define HANDLE_CS(cs) \
574
+ case cs: \
575
+ search_knn_hamming_count<HammingComputer##cs, store_pairs>( \
576
+ ivf, nx, x, keys, k, distances, labels, params); \
577
+ break;
578
+ HANDLE_CS(4);
579
+ HANDLE_CS(8);
580
+ HANDLE_CS(16);
581
+ HANDLE_CS(20);
582
+ HANDLE_CS(32);
583
+ HANDLE_CS(64);
584
+ #undef HANDLE_CS
585
+ default:
586
+ search_knn_hamming_count<HammingComputerDefault, store_pairs>(
587
+ ivf, nx, x, keys, k, distances, labels, params);
588
+ break;
589
+ }
590
+ }
579
591
 
580
- BinaryInvertedListScanner *IndexBinaryIVF::get_InvertedListScanner
581
- (bool store_pairs) const
582
- {
592
+ } // namespace
583
593
 
584
- #define HC(name) return new IVFBinaryScannerL2<name> (code_size, store_pairs)
594
+ BinaryInvertedListScanner* IndexBinaryIVF::get_InvertedListScanner(
595
+ bool store_pairs) const {
596
+ #define HC(name) return new IVFBinaryScannerL2<name>(code_size, store_pairs)
585
597
  switch (code_size) {
586
- case 4: HC(HammingComputer4);
587
- case 8: HC(HammingComputer8);
588
- case 16: HC(HammingComputer16);
589
- case 20: HC(HammingComputer20);
590
- case 32: HC(HammingComputer32);
591
- case 64: HC(HammingComputer64);
592
- default:
593
- if (code_size % 8 == 0) {
594
- HC(HammingComputerM8);
595
- } else if (code_size % 4 == 0) {
596
- HC(HammingComputerM4);
597
- } else {
598
+ case 4:
599
+ HC(HammingComputer4);
600
+ case 8:
601
+ HC(HammingComputer8);
602
+ case 16:
603
+ HC(HammingComputer16);
604
+ case 20:
605
+ HC(HammingComputer20);
606
+ case 32:
607
+ HC(HammingComputer32);
608
+ case 64:
609
+ HC(HammingComputer64);
610
+ default:
598
611
  HC(HammingComputerDefault);
599
- }
600
612
  }
601
613
  #undef HC
602
-
603
614
  }
604
615
 
605
- void IndexBinaryIVF::search_preassigned(idx_t n, const uint8_t *x, idx_t k,
606
- const idx_t *idx,
607
- const int32_t * coarse_dis,
608
- int32_t *distances, idx_t *labels,
609
- bool store_pairs,
610
- const IVFSearchParameters *params
611
- ) const {
612
-
616
+ void IndexBinaryIVF::search_preassigned(
617
+ idx_t n,
618
+ const uint8_t* x,
619
+ idx_t k,
620
+ const idx_t* idx,
621
+ const int32_t* coarse_dis,
622
+ int32_t* distances,
623
+ idx_t* labels,
624
+ bool store_pairs,
625
+ const IVFSearchParameters* params) const {
613
626
  if (use_heap) {
614
- search_knn_hamming_heap (*this, n, x, k, idx, coarse_dis,
615
- distances, labels, store_pairs,
616
- params);
627
+ search_knn_hamming_heap(
628
+ *this,
629
+ n,
630
+ x,
631
+ k,
632
+ idx,
633
+ coarse_dis,
634
+ distances,
635
+ labels,
636
+ store_pairs,
637
+ params);
617
638
  } else {
618
639
  if (store_pairs) {
619
- search_knn_hamming_count_1<true>
620
- (*this, n, x, idx, k, distances, labels, params);
640
+ search_knn_hamming_count_1<true>(
641
+ *this, n, x, idx, k, distances, labels, params);
621
642
  } else {
622
- search_knn_hamming_count_1<false>
623
- (*this, n, x, idx, k, distances, labels, params);
643
+ search_knn_hamming_count_1<false>(
644
+ *this, n, x, idx, k, distances, labels, params);
624
645
  }
625
646
  }
626
647
  }
627
648
 
628
-
629
649
  void IndexBinaryIVF::range_search(
630
- idx_t n, const uint8_t *x, int radius,
631
- RangeSearchResult *res) const
632
- {
633
-
650
+ idx_t n,
651
+ const uint8_t* x,
652
+ int radius,
653
+ RangeSearchResult* res) const {
654
+ const size_t nprobe = std::min(nlist, this->nprobe);
634
655
  std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
635
656
  std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
636
657
 
@@ -641,77 +662,84 @@ void IndexBinaryIVF::range_search(
641
662
  t0 = getmillisecs();
642
663
  invlists->prefetch_lists(idx.get(), n * nprobe);
643
664
 
665
+ range_search_preassigned(n, x, radius, idx.get(), coarse_dis.get(), res);
666
+
667
+ indexIVF_stats.search_time += getmillisecs() - t0;
668
+ }
669
+
670
+ void IndexBinaryIVF::range_search_preassigned(
671
+ idx_t n,
672
+ const uint8_t* x,
673
+ int radius,
674
+ const idx_t* assign,
675
+ const int32_t* centroid_dis,
676
+ RangeSearchResult* res) const {
677
+ const size_t nprobe = std::min(nlist, this->nprobe);
644
678
  bool store_pairs = false;
645
679
  size_t nlistv = 0, ndis = 0;
646
680
 
647
- std::vector<RangeSearchPartialResult *> all_pres (omp_get_max_threads());
681
+ std::vector<RangeSearchPartialResult*> all_pres(omp_get_max_threads());
648
682
 
649
- #pragma omp parallel reduction(+: nlistv, ndis)
683
+ #pragma omp parallel reduction(+ : nlistv, ndis)
650
684
  {
651
685
  RangeSearchPartialResult pres(res);
652
- std::unique_ptr<BinaryInvertedListScanner> scanner
653
- (get_InvertedListScanner(store_pairs));
654
- FAISS_THROW_IF_NOT (scanner.get ());
686
+ std::unique_ptr<BinaryInvertedListScanner> scanner(
687
+ get_InvertedListScanner(store_pairs));
688
+ FAISS_THROW_IF_NOT(scanner.get());
655
689
 
656
690
  all_pres[omp_get_thread_num()] = &pres;
657
691
 
658
- auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult &qres)
659
- {
660
-
661
- idx_t key = idx[i * nprobe + ik]; /* select the list */
662
- if (key < 0) return;
663
- FAISS_THROW_IF_NOT_FMT (
664
- key < (idx_t) nlist,
692
+ auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult& qres) {
693
+ idx_t key = assign[i * nprobe + ik]; /* select the list */
694
+ if (key < 0)
695
+ return;
696
+ FAISS_THROW_IF_NOT_FMT(
697
+ key < (idx_t)nlist,
665
698
  "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
666
- key, ik, nlist);
699
+ key,
700
+ ik,
701
+ nlist);
667
702
  const size_t list_size = invlists->list_size(key);
668
703
 
669
- if (list_size == 0) return;
704
+ if (list_size == 0)
705
+ return;
670
706
 
671
- InvertedLists::ScopedCodes scodes (invlists, key);
672
- InvertedLists::ScopedIds ids (invlists, key);
707
+ InvertedLists::ScopedCodes scodes(invlists, key);
708
+ InvertedLists::ScopedIds ids(invlists, key);
673
709
 
674
- scanner->set_list (key, coarse_dis[i * nprobe + ik]);
710
+ scanner->set_list(key, assign[i * nprobe + ik]);
675
711
  nlistv++;
676
712
  ndis += list_size;
677
- scanner->scan_codes_range (list_size, scodes.get(),
678
- ids.get(), radius, qres);
713
+ scanner->scan_codes_range(
714
+ list_size, scodes.get(), ids.get(), radius, qres);
679
715
  };
680
716
 
681
717
  #pragma omp for
682
718
  for (idx_t i = 0; i < n; i++) {
683
- scanner->set_query (x + i * code_size);
719
+ scanner->set_query(x + i * code_size);
684
720
 
685
- RangeQueryResult & qres = pres.new_result (i);
721
+ RangeQueryResult& qres = pres.new_result(i);
686
722
 
687
723
  for (size_t ik = 0; ik < nprobe; ik++) {
688
- scan_list_func (i, ik, qres);
724
+ scan_list_func(i, ik, qres);
689
725
  }
690
-
691
726
  }
692
727
 
693
728
  pres.finalize();
694
-
695
729
  }
696
730
  indexIVF_stats.nq += n;
697
731
  indexIVF_stats.nlist += nlistv;
698
732
  indexIVF_stats.ndis += ndis;
699
- indexIVF_stats.search_time += getmillisecs() - t0;
700
-
701
733
  }
702
734
 
703
-
704
-
705
-
706
735
  IndexBinaryIVF::~IndexBinaryIVF() {
707
- if (own_invlists) {
708
- delete invlists;
709
- }
736
+ if (own_invlists) {
737
+ delete invlists;
738
+ }
710
739
 
711
- if (own_fields) {
712
- delete quantizer;
713
- }
740
+ if (own_fields) {
741
+ delete quantizer;
742
+ }
714
743
  }
715
744
 
716
-
717
- } // namespace faiss
745
+ } // namespace faiss