faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -9,7 +9,6 @@
9
9
 
10
10
  #include <faiss/IndexIVF.h>
11
11
 
12
-
13
12
  #include <omp.h>
14
13
  #include <mutex>
15
14
 
@@ -18,12 +17,12 @@
18
17
  #include <cstdio>
19
18
  #include <memory>
20
19
 
21
- #include <faiss/utils/utils.h>
22
20
  #include <faiss/utils/hamming.h>
21
+ #include <faiss/utils/utils.h>
23
22
 
24
- #include <faiss/impl/FaissAssert.h>
25
23
  #include <faiss/IndexFlat.h>
26
24
  #include <faiss/impl/AuxIndexStructures.h>
25
+ #include <faiss/impl/FaissAssert.h>
27
26
 
28
27
  namespace faiss {
29
28
 
@@ -34,99 +33,104 @@ using ScopedCodes = InvertedLists::ScopedCodes;
34
33
  * Level1Quantizer implementation
35
34
  ******************************************/
36
35
 
37
-
38
- Level1Quantizer::Level1Quantizer (Index * quantizer, size_t nlist):
39
- quantizer (quantizer),
40
- nlist (nlist),
41
- quantizer_trains_alone (0),
42
- own_fields (false),
43
- clustering_index (nullptr)
44
- {
36
+ Level1Quantizer::Level1Quantizer(Index* quantizer, size_t nlist)
37
+ : quantizer(quantizer),
38
+ nlist(nlist),
39
+ quantizer_trains_alone(0),
40
+ own_fields(false),
41
+ clustering_index(nullptr) {
45
42
  // here we set a low # iterations because this is typically used
46
43
  // for large clusterings (nb this is not used for the MultiIndex,
47
44
  // for which quantizer_trains_alone = true)
48
45
  cp.niter = 10;
49
46
  }
50
47
 
51
- Level1Quantizer::Level1Quantizer ():
52
- quantizer (nullptr),
53
- nlist (0),
54
- quantizer_trains_alone (0), own_fields (false),
55
- clustering_index (nullptr)
56
- {}
48
+ Level1Quantizer::Level1Quantizer()
49
+ : quantizer(nullptr),
50
+ nlist(0),
51
+ quantizer_trains_alone(0),
52
+ own_fields(false),
53
+ clustering_index(nullptr) {}
57
54
 
58
- Level1Quantizer::~Level1Quantizer ()
59
- {
60
- if (own_fields) delete quantizer;
55
+ Level1Quantizer::~Level1Quantizer() {
56
+ if (own_fields)
57
+ delete quantizer;
61
58
  }
62
59
 
63
- void Level1Quantizer::train_q1 (size_t n, const float *x, bool verbose, MetricType metric_type)
64
- {
60
+ void Level1Quantizer::train_q1(
61
+ size_t n,
62
+ const float* x,
63
+ bool verbose,
64
+ MetricType metric_type) {
65
65
  size_t d = quantizer->d;
66
66
  if (quantizer->is_trained && (quantizer->ntotal == nlist)) {
67
67
  if (verbose)
68
- printf ("IVF quantizer does not need training.\n");
68
+ printf("IVF quantizer does not need training.\n");
69
69
  } else if (quantizer_trains_alone == 1) {
70
70
  if (verbose)
71
- printf ("IVF quantizer trains alone...\n");
72
- quantizer->train (n, x);
71
+ printf("IVF quantizer trains alone...\n");
72
+ quantizer->train(n, x);
73
73
  quantizer->verbose = verbose;
74
- FAISS_THROW_IF_NOT_MSG (quantizer->ntotal == nlist,
75
- "nlist not consistent with quantizer size");
74
+ FAISS_THROW_IF_NOT_MSG(
75
+ quantizer->ntotal == nlist,
76
+ "nlist not consistent with quantizer size");
76
77
  } else if (quantizer_trains_alone == 0) {
77
78
  if (verbose)
78
- printf ("Training level-1 quantizer on %zd vectors in %zdD\n",
79
- n, d);
79
+ printf("Training level-1 quantizer on %zd vectors in %zdD\n", n, d);
80
80
 
81
- Clustering clus (d, nlist, cp);
81
+ Clustering clus(d, nlist, cp);
82
82
  quantizer->reset();
83
83
  if (clustering_index) {
84
- clus.train (n, x, *clustering_index);
85
- quantizer->add (nlist, clus.centroids.data());
84
+ clus.train(n, x, *clustering_index);
85
+ quantizer->add(nlist, clus.centroids.data());
86
86
  } else {
87
- clus.train (n, x, *quantizer);
87
+ clus.train(n, x, *quantizer);
88
88
  }
89
89
  quantizer->is_trained = true;
90
90
  } else if (quantizer_trains_alone == 2) {
91
91
  if (verbose) {
92
- printf (
93
- "Training L2 quantizer on %zd vectors in %zdD%s\n",
94
- n, d,
95
- clustering_index ? "(user provided index)" : "");
92
+ printf("Training L2 quantizer on %zd vectors in %zdD%s\n",
93
+ n,
94
+ d,
95
+ clustering_index ? "(user provided index)" : "");
96
96
  }
97
97
  // also accept spherical centroids because in that case
98
98
  // L2 and IP are equivalent
99
- FAISS_THROW_IF_NOT (
100
- metric_type == METRIC_L2 ||
101
- (metric_type == METRIC_INNER_PRODUCT && cp.spherical)
102
- );
99
+ FAISS_THROW_IF_NOT(
100
+ metric_type == METRIC_L2 ||
101
+ (metric_type == METRIC_INNER_PRODUCT && cp.spherical));
103
102
 
104
- Clustering clus (d, nlist, cp);
103
+ Clustering clus(d, nlist, cp);
105
104
  if (!clustering_index) {
106
- IndexFlatL2 assigner (d);
105
+ IndexFlatL2 assigner(d);
107
106
  clus.train(n, x, assigner);
108
107
  } else {
109
108
  clus.train(n, x, *clustering_index);
110
109
  }
111
- if (verbose)
112
- printf ("Adding centroids to quantizer\n");
113
- quantizer->add (nlist, clus.centroids.data());
110
+ if (verbose) {
111
+ printf("Adding centroids to quantizer\n");
112
+ }
113
+ if (!quantizer->is_trained) {
114
+ if (verbose) {
115
+ printf("But training it first on centroids table...\n");
116
+ }
117
+ quantizer->train(nlist, clus.centroids.data());
118
+ }
119
+ quantizer->add(nlist, clus.centroids.data());
114
120
  }
115
121
  }
116
122
 
117
- size_t Level1Quantizer::coarse_code_size () const
118
- {
123
+ size_t Level1Quantizer::coarse_code_size() const {
119
124
  size_t nl = nlist - 1;
120
125
  size_t nbyte = 0;
121
126
  while (nl > 0) {
122
- nbyte ++;
127
+ nbyte++;
123
128
  nl >>= 8;
124
129
  }
125
130
  return nbyte;
126
131
  }
127
132
 
128
- void Level1Quantizer::encode_listno (Index::idx_t list_no, uint8_t *code) const
129
- {
133
+ void Level1Quantizer::encode_listno(Index::idx_t list_no, uint8_t* code) const {
130
134
  // little endian
131
135
  size_t nl = nlist - 1;
132
136
  while (nl > 0) {
@@ -136,8 +140,7 @@ void Level1Quantizer::encode_listno (Index::idx_t list_no, uint8_t *code) const
136
140
  }
137
141
  }
138
142
 
139
- Index::idx_t Level1Quantizer::decode_listno (const uint8_t *code) const
140
- {
143
+ Index::idx_t Level1Quantizer::decode_listno(const uint8_t* code) const {
141
144
  size_t nl = nlist - 1;
142
145
  int64_t list_no = 0;
143
146
  int nbit = 0;
@@ -146,161 +149,198 @@ Index::idx_t Level1Quantizer::decode_listno (const uint8_t *code) const
146
149
  nbit += 8;
147
150
  nl >>= 8;
148
151
  }
149
- FAISS_THROW_IF_NOT (list_no >= 0 && list_no < nlist);
152
+ FAISS_THROW_IF_NOT(list_no >= 0 && list_no < nlist);
150
153
  return list_no;
151
154
  }
152
155
 
153
-
154
-
155
156
  /*****************************************
156
157
  * IndexIVF implementation
157
158
  ******************************************/
158
159
 
159
-
160
- IndexIVF::IndexIVF (Index * quantizer, size_t d,
161
- size_t nlist, size_t code_size,
162
- MetricType metric):
163
- Index (d, metric),
164
- Level1Quantizer (quantizer, nlist),
165
- invlists (new ArrayInvertedLists (nlist, code_size)),
166
- own_invlists (true),
167
- code_size (code_size),
168
- nprobe (1),
169
- max_codes (0),
170
- parallel_mode (0)
171
- {
172
- FAISS_THROW_IF_NOT (d == quantizer->d);
160
+ IndexIVF::IndexIVF(
161
+ Index* quantizer,
162
+ size_t d,
163
+ size_t nlist,
164
+ size_t code_size,
165
+ MetricType metric)
166
+ : Index(d, metric),
167
+ Level1Quantizer(quantizer, nlist),
168
+ invlists(new ArrayInvertedLists(nlist, code_size)),
169
+ own_invlists(true),
170
+ code_size(code_size),
171
+ nprobe(1),
172
+ max_codes(0),
173
+ parallel_mode(0) {
174
+ FAISS_THROW_IF_NOT(d == quantizer->d);
173
175
  is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
174
176
  // Spherical by default if the metric is inner_product
175
177
  if (metric_type == METRIC_INNER_PRODUCT) {
176
178
  cp.spherical = true;
177
179
  }
178
-
179
180
  }
180
181
 
181
- IndexIVF::IndexIVF ():
182
- invlists (nullptr), own_invlists (false),
183
- code_size (0),
184
- nprobe (1), max_codes (0), parallel_mode (0)
185
- {}
182
+ IndexIVF::IndexIVF()
183
+ : invlists(nullptr),
184
+ own_invlists(false),
185
+ code_size(0),
186
+ nprobe(1),
187
+ max_codes(0),
188
+ parallel_mode(0) {}
189
+
190
+ void IndexIVF::add(idx_t n, const float* x) {
191
+ add_with_ids(n, x, nullptr);
192
+ }
186
193
 
187
- void IndexIVF::add (idx_t n, const float * x)
188
- {
189
- add_with_ids (n, x, nullptr);
194
+ void IndexIVF::add_with_ids(idx_t n, const float* x, const idx_t* xids) {
195
+ std::unique_ptr<idx_t[]> coarse_idx(new idx_t[n]);
196
+ quantizer->assign(n, x, coarse_idx.get());
197
+ add_core(n, x, xids, coarse_idx.get());
190
198
  }
191
199
 
200
+ void IndexIVF::add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids) {
201
+ size_t coarse_size = coarse_code_size();
202
+ DirectMapAdd dm_adder(direct_map, n, xids);
203
+
204
+ for (idx_t i = 0; i < n; i++) {
205
+ const uint8_t* code = codes + (code_size + coarse_size) * i;
206
+ idx_t list_no = decode_listno(code);
207
+ idx_t id = xids ? xids[i] : ntotal + i;
208
+ size_t ofs = invlists->add_entry(list_no, id, code + coarse_size);
209
+ dm_adder.add(i, list_no, ofs);
210
+ }
211
+ ntotal += n;
212
+ }
192
213
 
193
- void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
194
- {
214
+ void IndexIVF::add_core(
215
+ idx_t n,
216
+ const float* x,
217
+ const idx_t* xids,
218
+ const idx_t* coarse_idx) {
195
219
  // do some blocking to avoid excessive allocs
196
220
  idx_t bs = 65536;
197
221
  if (n > bs) {
198
222
  for (idx_t i0 = 0; i0 < n; i0 += bs) {
199
- idx_t i1 = std::min (n, i0 + bs);
223
+ idx_t i1 = std::min(n, i0 + bs);
200
224
  if (verbose) {
201
- printf(" IndexIVF::add_with_ids %" PRId64 ":%" PRId64 "\n", i0, i1);
225
+ printf(" IndexIVF::add_with_ids %" PRId64 ":%" PRId64 "\n",
226
+ i0,
227
+ i1);
202
228
  }
203
- add_with_ids (i1 - i0, x + i0 * d,
204
- xids ? xids + i0 : nullptr);
229
+ add_core(
230
+ i1 - i0,
231
+ x + i0 * d,
232
+ xids ? xids + i0 : nullptr,
233
+ coarse_idx + i0);
205
234
  }
206
235
  return;
207
236
  }
237
+ FAISS_THROW_IF_NOT(coarse_idx);
238
+ FAISS_THROW_IF_NOT(is_trained);
239
+ direct_map.check_can_add(xids);
208
240
 
209
- FAISS_THROW_IF_NOT (is_trained);
210
- direct_map.check_can_add (xids);
211
-
212
- std::unique_ptr<idx_t []> idx(new idx_t[n]);
213
- quantizer->assign (n, x, idx.get());
214
241
  size_t nadd = 0, nminus1 = 0;
215
242
 
216
243
  for (size_t i = 0; i < n; i++) {
217
- if (idx[i] < 0) nminus1++;
244
+ if (coarse_idx[i] < 0)
245
+ nminus1++;
218
246
  }
219
247
 
220
- std::unique_ptr<uint8_t []> flat_codes(new uint8_t [n * code_size]);
221
- encode_vectors (n, x, idx.get(), flat_codes.get());
248
+ std::unique_ptr<uint8_t[]> flat_codes(new uint8_t[n * code_size]);
249
+ encode_vectors(n, x, coarse_idx, flat_codes.get());
222
250
 
223
251
  DirectMapAdd dm_adder(direct_map, n, xids);
224
252
 
225
- #pragma omp parallel reduction(+: nadd)
253
+ #pragma omp parallel reduction(+ : nadd)
226
254
  {
227
255
  int nt = omp_get_num_threads();
228
256
  int rank = omp_get_thread_num();
229
257
 
230
258
  // each thread takes care of a subset of lists
231
259
  for (size_t i = 0; i < n; i++) {
232
- idx_t list_no = idx [i];
260
+ idx_t list_no = coarse_idx[i];
233
261
  if (list_no >= 0 && list_no % nt == rank) {
234
262
  idx_t id = xids ? xids[i] : ntotal + i;
235
- size_t ofs = invlists->add_entry (
236
- list_no, id,
237
- flat_codes.get() + i * code_size
238
- );
263
+ size_t ofs = invlists->add_entry(
264
+ list_no, id, flat_codes.get() + i * code_size);
239
265
 
240
- dm_adder.add (i, list_no, ofs);
266
+ dm_adder.add(i, list_no, ofs);
241
267
 
242
268
  nadd++;
243
269
  } else if (rank == 0 && list_no == -1) {
244
- dm_adder.add (i, -1, 0);
270
+ dm_adder.add(i, -1, 0);
245
271
  }
246
272
  }
247
273
  }
248
274
 
249
-
250
275
  if (verbose) {
251
- printf(" added %zd / %" PRId64 " vectors (%zd -1s)\n", nadd, n, nminus1);
276
+ printf(" added %zd / %" PRId64 " vectors (%zd -1s)\n",
277
+ nadd,
278
+ n,
279
+ nminus1);
252
280
  }
253
281
 
254
282
  ntotal += n;
255
283
  }
256
284
 
257
- void IndexIVF::make_direct_map (bool b)
258
- {
285
+ void IndexIVF::make_direct_map(bool b) {
259
286
  if (b) {
260
- direct_map.set_type (DirectMap::Array, invlists, ntotal);
287
+ direct_map.set_type(DirectMap::Array, invlists, ntotal);
261
288
  } else {
262
- direct_map.set_type (DirectMap::NoMap, invlists, ntotal);
289
+ direct_map.set_type(DirectMap::NoMap, invlists, ntotal);
263
290
  }
264
291
  }
265
292
 
266
-
267
-
268
- void IndexIVF::set_direct_map_type (DirectMap::Type type)
269
- {
270
- direct_map.set_type (type, invlists, ntotal);
293
+ void IndexIVF::set_direct_map_type(DirectMap::Type type) {
294
+ direct_map.set_type(type, invlists, ntotal);
271
295
  }
272
296
 
273
297
  /** It is a sad fact of software that a conceptually simple function like this
274
298
  * becomes very complex when you factor in several ways of parallelizing +
275
299
  * interrupt/error handling + collecting stats + min/max collection. The
276
300
  * codepath that is used 95% of time is the one for parallel_mode = 0 */
277
- void IndexIVF::search (idx_t n, const float *x, idx_t k,
278
- float *distances, idx_t *labels) const
279
- {
301
+ void IndexIVF::search(
302
+ idx_t n,
303
+ const float* x,
304
+ idx_t k,
305
+ float* distances,
306
+ idx_t* labels) const {
307
+ FAISS_THROW_IF_NOT(k > 0);
280
308
 
309
+ const size_t nprobe = std::min(nlist, this->nprobe);
310
+ FAISS_THROW_IF_NOT(nprobe > 0);
281
311
 
282
312
  // search function for a subset of queries
283
- auto sub_search_func = [this, k]
284
- (idx_t n, const float *x, float *distances, idx_t *labels,
285
- IndexIVFStats *ivf_stats) {
286
-
313
+ auto sub_search_func = [this, k, nprobe](
314
+ idx_t n,
315
+ const float* x,
316
+ float* distances,
317
+ idx_t* labels,
318
+ IndexIVFStats* ivf_stats) {
287
319
  std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
288
320
  std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
289
321
 
290
322
  double t0 = getmillisecs();
291
- quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get());
323
+ quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
292
324
 
293
325
  double t1 = getmillisecs();
294
- invlists->prefetch_lists (idx.get(), n * nprobe);
295
-
296
- search_preassigned (n, x, k, idx.get(), coarse_dis.get(),
297
- distances, labels, false, nullptr, ivf_stats);
326
+ invlists->prefetch_lists(idx.get(), n * nprobe);
327
+
328
+ search_preassigned(
329
+ n,
330
+ x,
331
+ k,
332
+ idx.get(),
333
+ coarse_dis.get(),
334
+ distances,
335
+ labels,
336
+ false,
337
+ nullptr,
338
+ ivf_stats);
298
339
  double t2 = getmillisecs();
299
340
  ivf_stats->quantization_time += t1 - t0;
300
341
  ivf_stats->search_time += t2 - t0;
301
342
  };
302
343
 
303
-
304
344
  if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) {
305
345
  int nt = std::min(omp_get_max_threads(), int(n));
306
346
  std::vector<IndexIVFStats> stats(nt);
@@ -308,18 +348,19 @@ void IndexIVF::search (idx_t n, const float *x, idx_t k,
308
348
  std::string exception_string;
309
349
 
310
350
  #pragma omp parallel for if (nt > 1)
311
- for(idx_t slice = 0; slice < nt; slice++) {
351
+ for (idx_t slice = 0; slice < nt; slice++) {
312
352
  IndexIVFStats local_stats;
313
353
  idx_t i0 = n * slice / nt;
314
354
  idx_t i1 = n * (slice + 1) / nt;
315
355
  if (i1 > i0) {
316
356
  try {
317
357
  sub_search_func(
318
- i1 - i0, x + i0 * d,
319
- distances + i0 * k, labels + i0 * k,
320
- &stats[slice]
321
- );
322
- } catch(const std::exception & e) {
358
+ i1 - i0,
359
+ x + i0 * d,
360
+ distances + i0 * k,
361
+ labels + i0 * k,
362
+ &stats[slice]);
363
+ } catch (const std::exception& e) {
323
364
  std::lock_guard<std::mutex> lock(exception_mutex);
324
365
  exception_string = e.what();
325
366
  }
@@ -327,32 +368,38 @@ void IndexIVF::search (idx_t n, const float *x, idx_t k,
327
368
  }
328
369
 
329
370
  if (!exception_string.empty()) {
330
- FAISS_THROW_MSG (exception_string.c_str());
371
+ FAISS_THROW_MSG(exception_string.c_str());
331
372
  }
332
373
 
333
374
  // collect stats
334
- for(idx_t slice = 0; slice < nt; slice++) {
375
+ for (idx_t slice = 0; slice < nt; slice++) {
335
376
  indexIVF_stats.add(stats[slice]);
336
377
  }
337
378
  } else {
338
- // handle paralellization at level below (or don't run in parallel at all)
379
+ // handle paralellization at level below (or don't run in parallel at
380
+ // all)
339
381
  sub_search_func(n, x, distances, labels, &indexIVF_stats);
340
382
  }
341
-
342
-
343
383
  }
344
384
 
345
-
346
- void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
347
- const idx_t *keys,
348
- const float *coarse_dis ,
349
- float *distances, idx_t *labels,
350
- bool store_pairs,
351
- const IVFSearchParameters *params,
352
- IndexIVFStats *ivf_stats) const
353
- {
354
- long nprobe = params ? params->nprobe : this->nprobe;
355
- long max_codes = params ? params->max_codes : this->max_codes;
385
+ void IndexIVF::search_preassigned(
386
+ idx_t n,
387
+ const float* x,
388
+ idx_t k,
389
+ const idx_t* keys,
390
+ const float* coarse_dis,
391
+ float* distances,
392
+ idx_t* labels,
393
+ bool store_pairs,
394
+ const IVFSearchParameters* params,
395
+ IndexIVFStats* ivf_stats) const {
396
+ FAISS_THROW_IF_NOT(k > 0);
397
+
398
+ idx_t nprobe = params ? params->nprobe : this->nprobe;
399
+ nprobe = std::min((idx_t)nlist, nprobe);
400
+ FAISS_THROW_IF_NOT(nprobe > 0);
401
+
402
+ idx_t max_codes = params ? params->max_codes : this->max_codes;
356
403
 
357
404
  size_t nlistv = 0, ndis = 0, nheap = 0;
358
405
 
@@ -366,15 +413,15 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
366
413
  int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
367
414
  bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT);
368
415
 
369
- bool do_parallel = omp_get_max_threads() >= 2 && (
370
- pmode == 0 ? false :
371
- pmode == 3 ? n > 1 :
372
- pmode == 1 ? nprobe > 1 :
373
- nprobe * n > 1);
416
+ bool do_parallel = omp_get_max_threads() >= 2 &&
417
+ (pmode == 0 ? false
418
+ : pmode == 3 ? n > 1
419
+ : pmode == 1 ? nprobe > 1
420
+ : nprobe * n > 1);
374
421
 
375
- #pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis, nheap)
422
+ #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
376
423
  {
377
- InvertedListScanner *scanner = get_InvertedListScanner(store_pairs);
424
+ InvertedListScanner* scanner = get_InvertedListScanner(store_pairs);
378
425
  ScopeDeleter1<InvertedListScanner> del(scanner);
379
426
 
380
427
  /*****************************************************
@@ -385,49 +432,52 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
385
432
 
386
433
  // intialize + reorder a result heap
387
434
 
388
- auto init_result = [&](float *simi, idx_t *idxi) {
389
- if (!do_heap_init) return;
435
+ auto init_result = [&](float* simi, idx_t* idxi) {
436
+ if (!do_heap_init)
437
+ return;
390
438
  if (metric_type == METRIC_INNER_PRODUCT) {
391
- heap_heapify<HeapForIP> (k, simi, idxi);
439
+ heap_heapify<HeapForIP>(k, simi, idxi);
392
440
  } else {
393
- heap_heapify<HeapForL2> (k, simi, idxi);
441
+ heap_heapify<HeapForL2>(k, simi, idxi);
394
442
  }
395
443
  };
396
444
 
397
- auto add_local_results = [&](
398
- const float * local_dis, const idx_t * local_idx,
399
- float *simi, idx_t *idxi)
400
- {
445
+ auto add_local_results = [&](const float* local_dis,
446
+ const idx_t* local_idx,
447
+ float* simi,
448
+ idx_t* idxi) {
401
449
  if (metric_type == METRIC_INNER_PRODUCT) {
402
- heap_addn<HeapForIP>
403
- (k, simi, idxi, local_dis, local_idx, k);
450
+ heap_addn<HeapForIP>(k, simi, idxi, local_dis, local_idx, k);
404
451
  } else {
405
- heap_addn<HeapForL2>
406
- (k, simi, idxi, local_dis, local_idx, k);
452
+ heap_addn<HeapForL2>(k, simi, idxi, local_dis, local_idx, k);
407
453
  }
408
454
  };
409
455
 
410
- auto reorder_result = [&] (float *simi, idx_t *idxi) {
411
- if (!do_heap_init) return;
456
+ auto reorder_result = [&](float* simi, idx_t* idxi) {
457
+ if (!do_heap_init)
458
+ return;
412
459
  if (metric_type == METRIC_INNER_PRODUCT) {
413
- heap_reorder<HeapForIP> (k, simi, idxi);
460
+ heap_reorder<HeapForIP>(k, simi, idxi);
414
461
  } else {
415
- heap_reorder<HeapForL2> (k, simi, idxi);
462
+ heap_reorder<HeapForL2>(k, simi, idxi);
416
463
  }
417
464
  };
418
465
 
419
466
  // single list scan using the current scanner (with query
420
467
  // set porperly) and storing results in simi and idxi
421
- auto scan_one_list = [&] (idx_t key, float coarse_dis_i,
422
- float *simi, idx_t *idxi) {
423
-
468
+ auto scan_one_list = [&](idx_t key,
469
+ float coarse_dis_i,
470
+ float* simi,
471
+ idx_t* idxi) {
424
472
  if (key < 0) {
425
473
  // not enough centroids for multiprobe
426
474
  return (size_t)0;
427
475
  }
428
- FAISS_THROW_IF_NOT_FMT (key < (idx_t) nlist,
429
- "Invalid key=%" PRId64 " nlist=%zd\n",
430
- key, nlist);
476
+ FAISS_THROW_IF_NOT_FMT(
477
+ key < (idx_t)nlist,
478
+ "Invalid key=%" PRId64 " nlist=%zd\n",
479
+ key,
480
+ nlist);
431
481
 
432
482
  size_t list_size = invlists->list_size(key);
433
483
 
@@ -436,28 +486,28 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
436
486
  return (size_t)0;
437
487
  }
438
488
 
439
- scanner->set_list (key, coarse_dis_i);
489
+ scanner->set_list(key, coarse_dis_i);
440
490
 
441
491
  nlistv++;
442
492
 
443
493
  try {
444
- InvertedLists::ScopedCodes scodes (invlists, key);
494
+ InvertedLists::ScopedCodes scodes(invlists, key);
445
495
 
446
496
  std::unique_ptr<InvertedLists::ScopedIds> sids;
447
- const Index::idx_t * ids = nullptr;
497
+ const Index::idx_t* ids = nullptr;
448
498
 
449
- if (!store_pairs) {
450
- sids.reset (new InvertedLists::ScopedIds (invlists, key));
499
+ if (!store_pairs) {
500
+ sids.reset(new InvertedLists::ScopedIds(invlists, key));
451
501
  ids = sids->get();
452
502
  }
453
503
 
454
- nheap += scanner->scan_codes (list_size, scodes.get(),
455
- ids, simi, idxi, k);
504
+ nheap += scanner->scan_codes(
505
+ list_size, scodes.get(), ids, simi, idxi, k);
456
506
 
457
- } catch(const std::exception & e) {
507
+ } catch (const std::exception& e) {
458
508
  std::lock_guard<std::mutex> lock(exception_mutex);
459
509
  exception_string =
460
- demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
510
+ demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
461
511
  interrupt = true;
462
512
  return size_t(0);
463
513
  }
@@ -470,31 +520,28 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
470
520
  ****************************************************/
471
521
 
472
522
  if (pmode == 0 || pmode == 3) {
473
-
474
523
  #pragma omp for
475
524
  for (idx_t i = 0; i < n; i++) {
476
-
477
525
  if (interrupt) {
478
526
  continue;
479
527
  }
480
528
 
481
529
  // loop over queries
482
- scanner->set_query (x + i * d);
483
- float * simi = distances + i * k;
484
- idx_t * idxi = labels + i * k;
530
+ scanner->set_query(x + i * d);
531
+ float* simi = distances + i * k;
532
+ idx_t* idxi = labels + i * k;
485
533
 
486
- init_result (simi, idxi);
534
+ init_result(simi, idxi);
487
535
 
488
- long nscan = 0;
536
+ idx_t nscan = 0;
489
537
 
490
538
  // loop over probes
491
539
  for (size_t ik = 0; ik < nprobe; ik++) {
492
-
493
- nscan += scan_one_list (
494
- keys [i * nprobe + ik],
495
- coarse_dis[i * nprobe + ik],
496
- simi, idxi
497
- );
540
+ nscan += scan_one_list(
541
+ keys[i * nprobe + ik],
542
+ coarse_dis[i * nprobe + ik],
543
+ simi,
544
+ idxi);
498
545
 
499
546
  if (max_codes && nscan >= max_codes) {
500
547
  break;
@@ -502,54 +549,55 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
502
549
  }
503
550
 
504
551
  ndis += nscan;
505
- reorder_result (simi, idxi);
552
+ reorder_result(simi, idxi);
506
553
 
507
- if (InterruptCallback::is_interrupted ()) {
554
+ if (InterruptCallback::is_interrupted()) {
508
555
  interrupt = true;
509
556
  }
510
557
 
511
558
  } // parallel for
512
559
  } else if (pmode == 1) {
513
- std::vector <idx_t> local_idx (k);
514
- std::vector <float> local_dis (k);
560
+ std::vector<idx_t> local_idx(k);
561
+ std::vector<float> local_dis(k);
515
562
 
516
563
  for (size_t i = 0; i < n; i++) {
517
- scanner->set_query (x + i * d);
518
- init_result (local_dis.data(), local_idx.data());
564
+ scanner->set_query(x + i * d);
565
+ init_result(local_dis.data(), local_idx.data());
519
566
 
520
567
  #pragma omp for schedule(dynamic)
521
- for (long ik = 0; ik < nprobe; ik++) {
522
- ndis += scan_one_list
523
- (keys [i * nprobe + ik],
524
- coarse_dis[i * nprobe + ik],
525
- local_dis.data(), local_idx.data());
568
+ for (idx_t ik = 0; ik < nprobe; ik++) {
569
+ ndis += scan_one_list(
570
+ keys[i * nprobe + ik],
571
+ coarse_dis[i * nprobe + ik],
572
+ local_dis.data(),
573
+ local_idx.data());
526
574
 
527
575
  // can't do the test on max_codes
528
576
  }
529
577
  // merge thread-local results
530
578
 
531
- float * simi = distances + i * k;
532
- idx_t * idxi = labels + i * k;
579
+ float* simi = distances + i * k;
580
+ idx_t* idxi = labels + i * k;
533
581
  #pragma omp single
534
- init_result (simi, idxi);
582
+ init_result(simi, idxi);
535
583
 
536
584
  #pragma omp barrier
537
585
  #pragma omp critical
538
586
  {
539
- add_local_results (local_dis.data(), local_idx.data(),
540
- simi, idxi);
587
+ add_local_results(
588
+ local_dis.data(), local_idx.data(), simi, idxi);
541
589
  }
542
590
  #pragma omp barrier
543
591
  #pragma omp single
544
- reorder_result (simi, idxi);
592
+ reorder_result(simi, idxi);
545
593
  }
546
594
  } else if (pmode == 2) {
547
- std::vector <idx_t> local_idx (k);
548
- std::vector <float> local_dis (k);
595
+ std::vector<idx_t> local_idx(k);
596
+ std::vector<float> local_dis(k);
549
597
 
550
598
  #pragma omp single
551
599
  for (int64_t i = 0; i < n; i++) {
552
- init_result (distances + i * k, labels + i * k);
600
+ init_result(distances + i * k, labels + i * k);
553
601
  }
554
602
 
555
603
  #pragma omp for schedule(dynamic)
@@ -557,33 +605,37 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
557
605
  size_t i = ij / nprobe;
558
606
  size_t j = ij % nprobe;
559
607
 
560
- scanner->set_query (x + i * d);
561
- init_result (local_dis.data(), local_idx.data());
562
- ndis += scan_one_list (
563
- keys [ij], coarse_dis[ij],
564
- local_dis.data(), local_idx.data());
608
+ scanner->set_query(x + i * d);
609
+ init_result(local_dis.data(), local_idx.data());
610
+ ndis += scan_one_list(
611
+ keys[ij],
612
+ coarse_dis[ij],
613
+ local_dis.data(),
614
+ local_idx.data());
565
615
  #pragma omp critical
566
616
  {
567
- add_local_results (local_dis.data(), local_idx.data(),
568
- distances + i * k, labels + i * k);
617
+ add_local_results(
618
+ local_dis.data(),
619
+ local_idx.data(),
620
+ distances + i * k,
621
+ labels + i * k);
569
622
  }
570
623
  }
571
624
  #pragma omp single
572
625
  for (int64_t i = 0; i < n; i++) {
573
- reorder_result (distances + i * k, labels + i * k);
626
+ reorder_result(distances + i * k, labels + i * k);
574
627
  }
575
628
  } else {
576
- FAISS_THROW_FMT ("parallel_mode %d not supported\n",
577
- pmode);
629
+ FAISS_THROW_FMT("parallel_mode %d not supported\n", pmode);
578
630
  }
579
631
  } // parallel section
580
632
 
581
633
  if (interrupt) {
582
634
  if (!exception_string.empty()) {
583
- FAISS_THROW_FMT ("search interrupted with: %s",
584
- exception_string.c_str());
635
+ FAISS_THROW_FMT(
636
+ "search interrupted with: %s", exception_string.c_str());
585
637
  } else {
586
- FAISS_THROW_MSG ("computation interrupted");
638
+ FAISS_THROW_MSG("computation interrupted");
587
639
  }
588
640
  }
589
641
 
@@ -595,38 +647,49 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
595
647
  }
596
648
  }
597
649
 
598
-
599
-
600
-
601
- void IndexIVF::range_search (idx_t nx, const float *x, float radius,
602
- RangeSearchResult *result) const
603
- {
604
- std::unique_ptr<idx_t[]> keys (new idx_t[nx * nprobe]);
605
- std::unique_ptr<float []> coarse_dis (new float[nx * nprobe]);
650
+ void IndexIVF::range_search(
651
+ idx_t nx,
652
+ const float* x,
653
+ float radius,
654
+ RangeSearchResult* result) const {
655
+ const size_t nprobe = std::min(nlist, this->nprobe);
656
+ std::unique_ptr<idx_t[]> keys(new idx_t[nx * nprobe]);
657
+ std::unique_ptr<float[]> coarse_dis(new float[nx * nprobe]);
606
658
 
607
659
  double t0 = getmillisecs();
608
- quantizer->search (nx, x, nprobe, coarse_dis.get (), keys.get ());
660
+ quantizer->search(nx, x, nprobe, coarse_dis.get(), keys.get());
609
661
  indexIVF_stats.quantization_time += getmillisecs() - t0;
610
662
 
611
663
  t0 = getmillisecs();
612
- invlists->prefetch_lists (keys.get(), nx * nprobe);
613
-
614
- range_search_preassigned (nx, x, radius, keys.get (), coarse_dis.get (),
615
- result, false, nullptr, &indexIVF_stats);
664
+ invlists->prefetch_lists(keys.get(), nx * nprobe);
665
+
666
+ range_search_preassigned(
667
+ nx,
668
+ x,
669
+ radius,
670
+ keys.get(),
671
+ coarse_dis.get(),
672
+ result,
673
+ false,
674
+ nullptr,
675
+ &indexIVF_stats);
616
676
 
617
677
  indexIVF_stats.search_time += getmillisecs() - t0;
618
678
  }
619
679
 
620
- void IndexIVF::range_search_preassigned (
621
- idx_t nx, const float *x, float radius,
622
- const idx_t *keys, const float *coarse_dis,
623
- RangeSearchResult *result,
624
- bool store_pairs,
625
- const IVFSearchParameters *params,
626
- IndexIVFStats *stats) const
627
- {
628
- long nprobe = params ? params->nprobe : this->nprobe;
629
- long max_codes = params ? params->max_codes : this->max_codes;
680
+ void IndexIVF::range_search_preassigned(
681
+ idx_t nx,
682
+ const float* x,
683
+ float radius,
684
+ const idx_t* keys,
685
+ const float* coarse_dis,
686
+ RangeSearchResult* result,
687
+ bool store_pairs,
688
+ const IVFSearchParameters* params,
689
+ IndexIVFStats* stats) const {
690
+ idx_t nprobe = params ? params->nprobe : this->nprobe;
691
+ nprobe = std::min((idx_t)nlist, nprobe);
692
+ idx_t max_codes = params ? params->max_codes : this->max_codes;
630
693
 
631
694
  size_t nlistv = 0, ndis = 0;
632
695
 
@@ -634,119 +697,116 @@ void IndexIVF::range_search_preassigned (
634
697
  std::mutex exception_mutex;
635
698
  std::string exception_string;
636
699
 
637
- std::vector<RangeSearchPartialResult *> all_pres (omp_get_max_threads());
700
+ std::vector<RangeSearchPartialResult*> all_pres(omp_get_max_threads());
638
701
 
639
702
  int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
640
703
  // don't start parallel section if single query
641
- bool do_parallel = omp_get_max_threads() >= 2 && (
642
- pmode == 3 ? false :
643
- pmode == 0 ? nx > 1 :
644
- pmode == 1 ? nprobe > 1 :
645
- nprobe * nx > 1);
704
+ bool do_parallel = omp_get_max_threads() >= 2 &&
705
+ (pmode == 3 ? false
706
+ : pmode == 0 ? nx > 1
707
+ : pmode == 1 ? nprobe > 1
708
+ : nprobe * nx > 1);
646
709
 
647
- #pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis)
710
+ #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis)
648
711
  {
649
712
  RangeSearchPartialResult pres(result);
650
- std::unique_ptr<InvertedListScanner> scanner
651
- (get_InvertedListScanner(store_pairs));
652
- FAISS_THROW_IF_NOT (scanner.get ());
713
+ std::unique_ptr<InvertedListScanner> scanner(
714
+ get_InvertedListScanner(store_pairs));
715
+ FAISS_THROW_IF_NOT(scanner.get());
653
716
  all_pres[omp_get_thread_num()] = &pres;
654
717
 
655
718
  // prepare the list scanning function
656
719
 
657
- auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult &qres) {
658
-
659
- idx_t key = keys[i * nprobe + ik]; /* select the list */
660
- if (key < 0) return;
661
- FAISS_THROW_IF_NOT_FMT (
662
- key < (idx_t) nlist,
663
- "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
664
- key, ik, nlist);
720
+ auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult& qres) {
721
+ idx_t key = keys[i * nprobe + ik]; /* select the list */
722
+ if (key < 0)
723
+ return;
724
+ FAISS_THROW_IF_NOT_FMT(
725
+ key < (idx_t)nlist,
726
+ "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
727
+ key,
728
+ ik,
729
+ nlist);
665
730
  const size_t list_size = invlists->list_size(key);
666
731
 
667
- if (list_size == 0) return;
732
+ if (list_size == 0)
733
+ return;
668
734
 
669
735
  try {
736
+ InvertedLists::ScopedCodes scodes(invlists, key);
737
+ InvertedLists::ScopedIds ids(invlists, key);
670
738
 
671
- InvertedLists::ScopedCodes scodes (invlists, key);
672
- InvertedLists::ScopedIds ids (invlists, key);
673
-
674
- scanner->set_list (key, coarse_dis[i * nprobe + ik]);
739
+ scanner->set_list(key, coarse_dis[i * nprobe + ik]);
675
740
  nlistv++;
676
741
  ndis += list_size;
677
- scanner->scan_codes_range (list_size, scodes.get(),
678
- ids.get(), radius, qres);
742
+ scanner->scan_codes_range(
743
+ list_size, scodes.get(), ids.get(), radius, qres);
679
744
 
680
- } catch(const std::exception & e) {
745
+ } catch (const std::exception& e) {
681
746
  std::lock_guard<std::mutex> lock(exception_mutex);
682
747
  exception_string =
683
- demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
748
+ demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
684
749
  interrupt = true;
685
750
  }
686
-
687
751
  };
688
752
 
689
753
  if (parallel_mode == 0) {
690
-
691
754
  #pragma omp for
692
755
  for (idx_t i = 0; i < nx; i++) {
693
- scanner->set_query (x + i * d);
756
+ scanner->set_query(x + i * d);
694
757
 
695
- RangeQueryResult & qres = pres.new_result (i);
758
+ RangeQueryResult& qres = pres.new_result(i);
696
759
 
697
760
  for (size_t ik = 0; ik < nprobe; ik++) {
698
- scan_list_func (i, ik, qres);
761
+ scan_list_func(i, ik, qres);
699
762
  }
700
-
701
763
  }
702
764
 
703
765
  } else if (parallel_mode == 1) {
704
-
705
766
  for (size_t i = 0; i < nx; i++) {
706
- scanner->set_query (x + i * d);
767
+ scanner->set_query(x + i * d);
707
768
 
708
- RangeQueryResult & qres = pres.new_result (i);
769
+ RangeQueryResult& qres = pres.new_result(i);
709
770
 
710
771
  #pragma omp for schedule(dynamic)
711
772
  for (int64_t ik = 0; ik < nprobe; ik++) {
712
- scan_list_func (i, ik, qres);
773
+ scan_list_func(i, ik, qres);
713
774
  }
714
775
  }
715
776
  } else if (parallel_mode == 2) {
716
- std::vector<RangeQueryResult *> all_qres (nx);
717
- RangeQueryResult *qres = nullptr;
777
+ std::vector<RangeQueryResult*> all_qres(nx);
778
+ RangeQueryResult* qres = nullptr;
718
779
 
719
780
  #pragma omp for schedule(dynamic)
720
781
  for (idx_t iik = 0; iik < nx * (idx_t)nprobe; iik++) {
721
782
  idx_t i = iik / (idx_t)nprobe;
722
783
  idx_t ik = iik % (idx_t)nprobe;
723
784
  if (qres == nullptr || qres->qno != i) {
724
- FAISS_ASSERT (!qres || i > qres->qno);
725
- qres = &pres.new_result (i);
726
- scanner->set_query (x + i * d);
785
+ FAISS_ASSERT(!qres || i > qres->qno);
786
+ qres = &pres.new_result(i);
787
+ scanner->set_query(x + i * d);
727
788
  }
728
- scan_list_func (i, ik, *qres);
789
+ scan_list_func(i, ik, *qres);
729
790
  }
730
791
  } else {
731
- FAISS_THROW_FMT ("parallel_mode %d not supported\n", parallel_mode);
792
+ FAISS_THROW_FMT("parallel_mode %d not supported\n", parallel_mode);
732
793
  }
733
794
  if (parallel_mode == 0) {
734
- pres.finalize ();
795
+ pres.finalize();
735
796
  } else {
736
797
  #pragma omp barrier
737
798
  #pragma omp single
738
- RangeSearchPartialResult::merge (all_pres, false);
799
+ RangeSearchPartialResult::merge(all_pres, false);
739
800
  #pragma omp barrier
740
-
741
801
  }
742
802
  }
743
803
 
744
804
  if (interrupt) {
745
805
  if (!exception_string.empty()) {
746
- FAISS_THROW_FMT ("search interrupted with: %s",
747
- exception_string.c_str());
806
+ FAISS_THROW_FMT(
807
+ "search interrupted with: %s", exception_string.c_str());
748
808
  } else {
749
- FAISS_THROW_MSG ("computation interrupted");
809
+ FAISS_THROW_MSG("computation interrupted");
750
810
  }
751
811
  }
752
812
 
@@ -757,27 +817,22 @@ void IndexIVF::range_search_preassigned (
757
817
  }
758
818
  }
759
819
 
760
-
761
- InvertedListScanner *IndexIVF::get_InvertedListScanner (
762
- bool /*store_pairs*/) const
763
- {
820
+ InvertedListScanner* IndexIVF::get_InvertedListScanner(
821
+ bool /*store_pairs*/) const {
764
822
  return nullptr;
765
823
  }
766
824
 
767
- void IndexIVF::reconstruct (idx_t key, float* recons) const
768
- {
769
- idx_t lo = direct_map.get (key);
770
- reconstruct_from_offset (lo_listno(lo), lo_offset(lo), recons);
825
+ void IndexIVF::reconstruct(idx_t key, float* recons) const {
826
+ idx_t lo = direct_map.get(key);
827
+ reconstruct_from_offset(lo_listno(lo), lo_offset(lo), recons);
771
828
  }
772
829
 
773
-
774
- void IndexIVF::reconstruct_n (idx_t i0, idx_t ni, float* recons) const
775
- {
776
- FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
830
+ void IndexIVF::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
831
+ FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
777
832
 
778
833
  for (idx_t list_no = 0; list_no < nlist; list_no++) {
779
- size_t list_size = invlists->list_size (list_no);
780
- ScopedIds idlist (invlists, list_no);
834
+ size_t list_size = invlists->list_size(list_no);
835
+ ScopedIds idlist(invlists, list_no);
781
836
 
782
837
  for (idx_t offset = 0; offset < list_size; offset++) {
783
838
  idx_t id = idlist[offset];
@@ -786,46 +841,56 @@ void IndexIVF::reconstruct_n (idx_t i0, idx_t ni, float* recons) const
786
841
  }
787
842
 
788
843
  float* reconstructed = recons + (id - i0) * d;
789
- reconstruct_from_offset (list_no, offset, reconstructed);
844
+ reconstruct_from_offset(list_no, offset, reconstructed);
790
845
  }
791
846
  }
792
847
  }
793
848
 
794
-
795
849
  /* standalone codec interface */
796
- size_t IndexIVF::sa_code_size () const
797
- {
850
+ size_t IndexIVF::sa_code_size() const {
798
851
  size_t coarse_size = coarse_code_size();
799
852
  return code_size + coarse_size;
800
853
  }
801
854
 
802
- void IndexIVF::sa_encode (idx_t n, const float *x,
803
- uint8_t *bytes) const
804
- {
805
- FAISS_THROW_IF_NOT (is_trained);
806
- std::unique_ptr<int64_t []> idx (new int64_t [n]);
807
- quantizer->assign (n, x, idx.get());
808
- encode_vectors (n, x, idx.get(), bytes, true);
855
+ void IndexIVF::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
856
+ FAISS_THROW_IF_NOT(is_trained);
857
+ std::unique_ptr<int64_t[]> idx(new int64_t[n]);
858
+ quantizer->assign(n, x, idx.get());
859
+ encode_vectors(n, x, idx.get(), bytes, true);
809
860
  }
810
861
 
862
+ void IndexIVF::search_and_reconstruct(
863
+ idx_t n,
864
+ const float* x,
865
+ idx_t k,
866
+ float* distances,
867
+ idx_t* labels,
868
+ float* recons) const {
869
+ FAISS_THROW_IF_NOT(k > 0);
811
870
 
812
- void IndexIVF::search_and_reconstruct (idx_t n, const float *x, idx_t k,
813
- float *distances, idx_t *labels,
814
- float *recons) const
815
- {
816
- idx_t * idx = new idx_t [n * nprobe];
817
- ScopeDeleter<idx_t> del (idx);
818
- float * coarse_dis = new float [n * nprobe];
819
- ScopeDeleter<float> del2 (coarse_dis);
871
+ const size_t nprobe = std::min(nlist, this->nprobe);
872
+ FAISS_THROW_IF_NOT(nprobe > 0);
820
873
 
821
- quantizer->search (n, x, nprobe, coarse_dis, idx);
874
+ idx_t* idx = new idx_t[n * nprobe];
875
+ ScopeDeleter<idx_t> del(idx);
876
+ float* coarse_dis = new float[n * nprobe];
877
+ ScopeDeleter<float> del2(coarse_dis);
822
878
 
823
- invlists->prefetch_lists (idx, n * nprobe);
879
+ quantizer->search(n, x, nprobe, coarse_dis, idx);
880
+
881
+ invlists->prefetch_lists(idx, n * nprobe);
824
882
 
825
883
  // search_preassigned() with `store_pairs` enabled to obtain the list_no
826
884
  // and offset into `codes` for reconstruction
827
- search_preassigned (n, x, k, idx, coarse_dis,
828
- distances, labels, true /* store_pairs */);
885
+ search_preassigned(
886
+ n,
887
+ x,
888
+ k,
889
+ idx,
890
+ coarse_dis,
891
+ distances,
892
+ labels,
893
+ true /* store_pairs */);
829
894
  for (idx_t i = 0; i < n; ++i) {
830
895
  for (idx_t j = 0; j < k; ++j) {
831
896
  idx_t ij = i * k + j;
@@ -835,165 +900,151 @@ void IndexIVF::search_and_reconstruct (idx_t n, const float *x, idx_t k,
835
900
  // Fill with NaNs
836
901
  memset(reconstructed, -1, sizeof(*reconstructed) * d);
837
902
  } else {
838
- int list_no = lo_listno (key);
839
- int offset = lo_offset (key);
903
+ int list_no = lo_listno(key);
904
+ int offset = lo_offset(key);
840
905
 
841
906
  // Update label to the actual id
842
- labels[ij] = invlists->get_single_id (list_no, offset);
907
+ labels[ij] = invlists->get_single_id(list_no, offset);
843
908
 
844
- reconstruct_from_offset (list_no, offset, reconstructed);
909
+ reconstruct_from_offset(list_no, offset, reconstructed);
845
910
  }
846
911
  }
847
912
  }
848
913
  }
849
914
 
850
915
  void IndexIVF::reconstruct_from_offset(
851
- int64_t /*list_no*/,
852
- int64_t /*offset*/,
853
- float* /*recons*/) const {
854
- FAISS_THROW_MSG ("reconstruct_from_offset not implemented");
916
+ int64_t /*list_no*/,
917
+ int64_t /*offset*/,
918
+ float* /*recons*/) const {
919
+ FAISS_THROW_MSG("reconstruct_from_offset not implemented");
855
920
  }
856
921
 
857
- void IndexIVF::reset ()
858
- {
859
- direct_map.clear ();
860
- invlists->reset ();
922
+ void IndexIVF::reset() {
923
+ direct_map.clear();
924
+ invlists->reset();
861
925
  ntotal = 0;
862
926
  }
863
927
 
864
-
865
- size_t IndexIVF::remove_ids (const IDSelector & sel)
866
- {
867
- size_t nremove = direct_map.remove_ids (sel, invlists);
928
+ size_t IndexIVF::remove_ids(const IDSelector& sel) {
929
+ size_t nremove = direct_map.remove_ids(sel, invlists);
868
930
  ntotal -= nremove;
869
931
  return nremove;
870
932
  }
871
933
 
872
-
873
- void IndexIVF::update_vectors (int n, const idx_t *new_ids, const float *x)
874
- {
875
-
934
+ void IndexIVF::update_vectors(int n, const idx_t* new_ids, const float* x) {
876
935
  if (direct_map.type == DirectMap::Hashtable) {
877
936
  // just remove then add
878
937
  IDSelectorArray sel(n, new_ids);
879
- size_t nremove = remove_ids (sel);
880
- FAISS_THROW_IF_NOT_MSG (nremove == n,
881
- "did not find all entries to remove");
882
- add_with_ids (n, x, new_ids);
938
+ size_t nremove = remove_ids(sel);
939
+ FAISS_THROW_IF_NOT_MSG(
940
+ nremove == n, "did not find all entries to remove");
941
+ add_with_ids(n, x, new_ids);
883
942
  return;
884
943
  }
885
944
 
886
- FAISS_THROW_IF_NOT (direct_map.type == DirectMap::Array);
945
+ FAISS_THROW_IF_NOT(direct_map.type == DirectMap::Array);
887
946
  // here it is more tricky because we don't want to introduce holes
888
947
  // in continuous range of ids
889
948
 
890
- FAISS_THROW_IF_NOT (is_trained);
891
- std::vector<idx_t> assign (n);
892
- quantizer->assign (n, x, assign.data());
893
-
894
- std::vector<uint8_t> flat_codes (n * code_size);
895
- encode_vectors (n, x, assign.data(), flat_codes.data());
949
+ FAISS_THROW_IF_NOT(is_trained);
950
+ std::vector<idx_t> assign(n);
951
+ quantizer->assign(n, x, assign.data());
896
952
 
897
- direct_map.update_codes (invlists, n, new_ids, assign.data(), flat_codes.data());
953
+ std::vector<uint8_t> flat_codes(n * code_size);
954
+ encode_vectors(n, x, assign.data(), flat_codes.data());
898
955
 
956
+ direct_map.update_codes(
957
+ invlists, n, new_ids, assign.data(), flat_codes.data());
899
958
  }
900
959
 
901
-
902
-
903
-
904
- void IndexIVF::train (idx_t n, const float *x)
905
- {
960
+ void IndexIVF::train(idx_t n, const float* x) {
906
961
  if (verbose)
907
- printf ("Training level-1 quantizer\n");
962
+ printf("Training level-1 quantizer\n");
908
963
 
909
- train_q1 (n, x, verbose, metric_type);
964
+ train_q1(n, x, verbose, metric_type);
910
965
 
911
966
  if (verbose)
912
- printf ("Training IVF residual\n");
967
+ printf("Training IVF residual\n");
913
968
 
914
- train_residual (n, x);
969
+ train_residual(n, x);
915
970
  is_trained = true;
916
-
917
971
  }
918
972
 
919
973
  void IndexIVF::train_residual(idx_t /*n*/, const float* /*x*/) {
920
- if (verbose)
921
- printf("IndexIVF: no residual training\n");
922
- // does nothing by default
974
+ if (verbose)
975
+ printf("IndexIVF: no residual training\n");
976
+ // does nothing by default
923
977
  }
924
978
 
925
-
926
- void IndexIVF::check_compatible_for_merge (const IndexIVF &other) const
927
- {
979
+ void IndexIVF::check_compatible_for_merge(const IndexIVF& other) const {
928
980
  // minimal sanity checks
929
- FAISS_THROW_IF_NOT (other.d == d);
930
- FAISS_THROW_IF_NOT (other.nlist == nlist);
931
- FAISS_THROW_IF_NOT (other.code_size == code_size);
932
- FAISS_THROW_IF_NOT_MSG (typeid (*this) == typeid (other),
933
- "can only merge indexes of the same type");
934
- FAISS_THROW_IF_NOT_MSG (this->direct_map.no() && other.direct_map.no(),
935
- "merge direct_map not implemented");
981
+ FAISS_THROW_IF_NOT(other.d == d);
982
+ FAISS_THROW_IF_NOT(other.nlist == nlist);
983
+ FAISS_THROW_IF_NOT(other.code_size == code_size);
984
+ FAISS_THROW_IF_NOT_MSG(
985
+ typeid(*this) == typeid(other),
986
+ "can only merge indexes of the same type");
987
+ FAISS_THROW_IF_NOT_MSG(
988
+ this->direct_map.no() && other.direct_map.no(),
989
+ "merge direct_map not implemented");
936
990
  }
937
991
 
992
+ void IndexIVF::merge_from(IndexIVF& other, idx_t add_id) {
993
+ check_compatible_for_merge(other);
938
994
 
939
- void IndexIVF::merge_from (IndexIVF &other, idx_t add_id)
940
- {
941
- check_compatible_for_merge (other);
942
-
943
- invlists->merge_from (other.invlists, add_id);
995
+ invlists->merge_from(other.invlists, add_id);
944
996
 
945
997
  ntotal += other.ntotal;
946
998
  other.ntotal = 0;
947
999
  }
948
1000
 
949
-
950
- void IndexIVF::replace_invlists (InvertedLists *il, bool own)
951
- {
1001
+ void IndexIVF::replace_invlists(InvertedLists* il, bool own) {
952
1002
  if (own_invlists) {
953
1003
  delete invlists;
954
1004
  invlists = nullptr;
955
1005
  }
956
1006
  // FAISS_THROW_IF_NOT (ntotal == 0);
957
1007
  if (il) {
958
- FAISS_THROW_IF_NOT (il->nlist == nlist);
959
- FAISS_THROW_IF_NOT (
960
- il->code_size == code_size ||
961
- il->code_size == InvertedLists::INVALID_CODE_SIZE
962
- );
1008
+ FAISS_THROW_IF_NOT(il->nlist == nlist);
1009
+ FAISS_THROW_IF_NOT(
1010
+ il->code_size == code_size ||
1011
+ il->code_size == InvertedLists::INVALID_CODE_SIZE);
963
1012
  }
964
1013
  invlists = il;
965
1014
  own_invlists = own;
966
1015
  }
967
1016
 
968
-
969
- void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
970
- idx_t a1, idx_t a2) const
971
- {
972
-
973
- FAISS_THROW_IF_NOT (nlist == other.nlist);
974
- FAISS_THROW_IF_NOT (code_size == other.code_size);
975
- FAISS_THROW_IF_NOT (other.direct_map.no());
976
- FAISS_THROW_IF_NOT_FMT (
977
- subset_type == 0 || subset_type == 1 || subset_type == 2,
978
- "subset type %d not implemented", subset_type);
1017
+ void IndexIVF::copy_subset_to(
1018
+ IndexIVF& other,
1019
+ int subset_type,
1020
+ idx_t a1,
1021
+ idx_t a2) const {
1022
+ FAISS_THROW_IF_NOT(nlist == other.nlist);
1023
+ FAISS_THROW_IF_NOT(code_size == other.code_size);
1024
+ FAISS_THROW_IF_NOT(other.direct_map.no());
1025
+ FAISS_THROW_IF_NOT_FMT(
1026
+ subset_type == 0 || subset_type == 1 || subset_type == 2,
1027
+ "subset type %d not implemented",
1028
+ subset_type);
979
1029
 
980
1030
  size_t accu_n = 0;
981
1031
  size_t accu_a1 = 0;
982
1032
  size_t accu_a2 = 0;
983
1033
 
984
- InvertedLists *oivf = other.invlists;
1034
+ InvertedLists* oivf = other.invlists;
985
1035
 
986
1036
  for (idx_t list_no = 0; list_no < nlist; list_no++) {
987
- size_t n = invlists->list_size (list_no);
988
- ScopedIds ids_in (invlists, list_no);
1037
+ size_t n = invlists->list_size(list_no);
1038
+ ScopedIds ids_in(invlists, list_no);
989
1039
 
990
1040
  if (subset_type == 0) {
991
1041
  for (idx_t i = 0; i < n; i++) {
992
1042
  idx_t id = ids_in[i];
993
1043
  if (a1 <= id && id < a2) {
994
- oivf->add_entry (list_no,
995
- invlists->get_single_id (list_no, i),
996
- ScopedCodes (invlists, list_no, i).get());
1044
+ oivf->add_entry(
1045
+ list_no,
1046
+ invlists->get_single_id(list_no, i),
1047
+ ScopedCodes(invlists, list_no, i).get());
997
1048
  other.ntotal++;
998
1049
  }
999
1050
  }
@@ -1001,9 +1052,10 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
1001
1052
  for (idx_t i = 0; i < n; i++) {
1002
1053
  idx_t id = ids_in[i];
1003
1054
  if (id % a1 == a2) {
1004
- oivf->add_entry (list_no,
1005
- invlists->get_single_id (list_no, i),
1006
- ScopedCodes (invlists, list_no, i).get());
1055
+ oivf->add_entry(
1056
+ list_no,
1057
+ invlists->get_single_id(list_no, i),
1058
+ ScopedCodes(invlists, list_no, i).get());
1007
1059
  other.ntotal++;
1008
1060
  }
1009
1061
  }
@@ -1016,9 +1068,10 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
1016
1068
  size_t i2 = next_accu_a2 - accu_a2;
1017
1069
 
1018
1070
  for (idx_t i = i1; i < i2; i++) {
1019
- oivf->add_entry (list_no,
1020
- invlists->get_single_id (list_no, i),
1021
- ScopedCodes (invlists, list_no, i).get());
1071
+ oivf->add_entry(
1072
+ list_no,
1073
+ invlists->get_single_id(list_no, i),
1074
+ ScopedCodes(invlists, list_no, i).get());
1022
1075
  }
1023
1076
 
1024
1077
  other.ntotal += i2 - i1;
@@ -1028,48 +1081,87 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
1028
1081
  accu_n += n;
1029
1082
  }
1030
1083
  FAISS_ASSERT(accu_n == ntotal);
1031
-
1032
1084
  }
1033
1085
 
1034
-
1035
-
1036
-
1037
- IndexIVF::~IndexIVF()
1038
- {
1086
+ IndexIVF::~IndexIVF() {
1039
1087
  if (own_invlists) {
1040
1088
  delete invlists;
1041
1089
  }
1042
1090
  }
1043
1091
 
1092
+ /*************************************************************************
1093
+ * IndexIVFStats
1094
+ *************************************************************************/
1044
1095
 
1045
- void IndexIVFStats::reset()
1046
- {
1047
- memset ((void*)this, 0, sizeof (*this));
1096
+ void IndexIVFStats::reset() {
1097
+ memset((void*)this, 0, sizeof(*this));
1048
1098
  }
1049
1099
 
1050
- void IndexIVFStats::add (const IndexIVFStats & other)
1051
- {
1100
+ void IndexIVFStats::add(const IndexIVFStats& other) {
1052
1101
  nq += other.nq;
1053
1102
  nlist += other.nlist;
1054
1103
  ndis += other.ndis;
1055
1104
  nheap_updates += other.nheap_updates;
1056
1105
  quantization_time += other.quantization_time;
1057
1106
  search_time += other.search_time;
1058
-
1059
1107
  }
1060
1108
 
1061
-
1062
1109
  IndexIVFStats indexIVF_stats;
1063
1110
 
1064
- void InvertedListScanner::scan_codes_range (size_t ,
1065
- const uint8_t *,
1066
- const idx_t *,
1067
- float ,
1068
- RangeQueryResult &) const
1069
- {
1070
- FAISS_THROW_MSG ("scan_codes_range not implemented");
1111
+ /*************************************************************************
1112
+ * InvertedListScanner
1113
+ *************************************************************************/
1114
+
1115
+ size_t InvertedListScanner::scan_codes(
1116
+ size_t list_size,
1117
+ const uint8_t* codes,
1118
+ const idx_t* ids,
1119
+ float* simi,
1120
+ idx_t* idxi,
1121
+ size_t k) const {
1122
+ size_t nup = 0;
1123
+
1124
+ if (!keep_max) {
1125
+ for (size_t j = 0; j < list_size; j++) {
1126
+ float dis = distance_to_code(codes);
1127
+ if (dis < simi[0]) {
1128
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1129
+ maxheap_replace_top(k, simi, idxi, dis, id);
1130
+ nup++;
1131
+ }
1132
+ codes += code_size;
1133
+ }
1134
+ } else {
1135
+ for (size_t j = 0; j < list_size; j++) {
1136
+ float dis = distance_to_code(codes);
1137
+ if (dis > simi[0]) {
1138
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1139
+ minheap_replace_top(k, simi, idxi, dis, id);
1140
+ nup++;
1141
+ }
1142
+ codes += code_size;
1143
+ }
1144
+ }
1145
+ return nup;
1071
1146
  }
1072
1147
 
1073
-
1148
+ void InvertedListScanner::scan_codes_range(
1149
+ size_t list_size,
1150
+ const uint8_t* codes,
1151
+ const idx_t* ids,
1152
+ float radius,
1153
+ RangeQueryResult& res) const {
1154
+ for (size_t j = 0; j < list_size; j++) {
1155
+ float dis = distance_to_code(codes);
1156
+ bool keep = !keep_max
1157
+ ? dis < radius
1158
+ : dis > radius; // TODO templatize to remove this test
1159
+ if (keep) {
1160
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1161
+ res.add(dis, id);
1162
+ }
1163
+ codes += code_size;
1164
+ }
1165
+ }
1074
1166
 
1075
1167
  } // namespace faiss