faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -9,17 +9,17 @@
9
9
 
10
10
  #include <faiss/IndexIVFPQ.h>
11
11
 
12
+ #include <stdint.h>
13
+ #include <cassert>
12
14
  #include <cinttypes>
13
15
  #include <cmath>
14
16
  #include <cstdio>
15
- #include <cassert>
16
- #include <stdint.h>
17
17
 
18
18
  #include <algorithm>
19
19
 
20
20
  #include <faiss/utils/Heap.h>
21
- #include <faiss/utils/utils.h>
22
21
  #include <faiss/utils/distances.h>
22
+ #include <faiss/utils/utils.h>
23
23
 
24
24
  #include <faiss/Clustering.h>
25
25
  #include <faiss/IndexFlat.h>
@@ -36,12 +36,15 @@ namespace faiss {
36
36
  * IndexIVFPQ implementation
37
37
  ******************************************/
38
38
 
39
- IndexIVFPQ::IndexIVFPQ (Index * quantizer, size_t d, size_t nlist,
40
- size_t M, size_t nbits_per_idx, MetricType metric):
41
- IndexIVF (quantizer, d, nlist, 0, metric),
42
- pq (d, M, nbits_per_idx)
43
- {
44
- FAISS_THROW_IF_NOT (nbits_per_idx <= 8);
39
+ IndexIVFPQ::IndexIVFPQ(
40
+ Index* quantizer,
41
+ size_t d,
42
+ size_t nlist,
43
+ size_t M,
44
+ size_t nbits_per_idx,
45
+ MetricType metric)
46
+ : IndexIVF(quantizer, d, nlist, 0, metric), pq(d, M, nbits_per_idx) {
47
+ FAISS_THROW_IF_NOT(nbits_per_idx <= 8);
45
48
  code_size = pq.code_size;
46
49
  invlists->code_size = code_size;
47
50
  is_trained = false;
@@ -52,202 +55,197 @@ IndexIVFPQ::IndexIVFPQ (Index * quantizer, size_t d, size_t nlist,
52
55
  polysemous_training = nullptr;
53
56
  do_polysemous_training = false;
54
57
  polysemous_ht = 0;
55
-
56
58
  }
57
59
 
58
-
59
60
  /****************************************************************
60
61
  * training */
61
62
 
62
- void IndexIVFPQ::train_residual (idx_t n, const float *x)
63
- {
64
- train_residual_o (n, x, nullptr);
63
+ void IndexIVFPQ::train_residual(idx_t n, const float* x) {
64
+ train_residual_o(n, x, nullptr);
65
65
  }
66
66
 
67
+ void IndexIVFPQ::train_residual_o(idx_t n, const float* x, float* residuals_2) {
68
+ const float* x_in = x;
67
69
 
68
- void IndexIVFPQ::train_residual_o (idx_t n, const float *x, float *residuals_2)
69
- {
70
- const float * x_in = x;
71
-
72
- x = fvecs_maybe_subsample (
73
- d, (size_t*)&n, pq.cp.max_points_per_centroid * pq.ksub,
74
- x, verbose, pq.cp.seed);
70
+ x = fvecs_maybe_subsample(
71
+ d,
72
+ (size_t*)&n,
73
+ pq.cp.max_points_per_centroid * pq.ksub,
74
+ x,
75
+ verbose,
76
+ pq.cp.seed);
75
77
 
76
- ScopeDeleter<float> del_x (x_in == x ? nullptr : x);
78
+ ScopeDeleter<float> del_x(x_in == x ? nullptr : x);
77
79
 
78
- const float *trainset;
80
+ const float* trainset;
79
81
  ScopeDeleter<float> del_residuals;
80
82
  if (by_residual) {
81
- if(verbose) printf("computing residuals\n");
82
- idx_t * assign = new idx_t [n]; // assignement to coarse centroids
83
- ScopeDeleter<idx_t> del (assign);
84
- quantizer->assign (n, x, assign);
85
- float *residuals = new float [n * d];
86
- del_residuals.set (residuals);
83
+ if (verbose)
84
+ printf("computing residuals\n");
85
+ idx_t* assign = new idx_t[n]; // assignement to coarse centroids
86
+ ScopeDeleter<idx_t> del(assign);
87
+ quantizer->assign(n, x, assign);
88
+ float* residuals = new float[n * d];
89
+ del_residuals.set(residuals);
87
90
  for (idx_t i = 0; i < n; i++)
88
- quantizer->compute_residual (x + i * d, residuals+i*d, assign[i]);
91
+ quantizer->compute_residual(
92
+ x + i * d, residuals + i * d, assign[i]);
89
93
 
90
94
  trainset = residuals;
91
95
  } else {
92
96
  trainset = x;
93
97
  }
94
98
  if (verbose)
95
- printf ("training %zdx%zd product quantizer on %" PRId64 " vectors in %dD\n",
96
- pq.M, pq.ksub, n, d);
99
+ printf("training %zdx%zd product quantizer on %" PRId64
100
+ " vectors in %dD\n",
101
+ pq.M,
102
+ pq.ksub,
103
+ n,
104
+ d);
97
105
  pq.verbose = verbose;
98
- pq.train (n, trainset);
106
+ pq.train(n, trainset);
99
107
 
100
108
  if (do_polysemous_training) {
101
109
  if (verbose)
102
110
  printf("doing polysemous training for PQ\n");
103
111
  PolysemousTraining default_pt;
104
- PolysemousTraining *pt = polysemous_training;
105
- if (!pt) pt = &default_pt;
106
- pt->optimize_pq_for_hamming (pq, n, trainset);
112
+ PolysemousTraining* pt = polysemous_training;
113
+ if (!pt)
114
+ pt = &default_pt;
115
+ pt->optimize_pq_for_hamming(pq, n, trainset);
107
116
  }
108
117
 
109
118
  // prepare second-level residuals for refine PQ
110
119
  if (residuals_2) {
111
- uint8_t *train_codes = new uint8_t [pq.code_size * n];
112
- ScopeDeleter<uint8_t> del (train_codes);
113
- pq.compute_codes (trainset, train_codes, n);
120
+ uint8_t* train_codes = new uint8_t[pq.code_size * n];
121
+ ScopeDeleter<uint8_t> del(train_codes);
122
+ pq.compute_codes(trainset, train_codes, n);
114
123
 
115
124
  for (idx_t i = 0; i < n; i++) {
116
- const float *xx = trainset + i * d;
117
- float * res = residuals_2 + i * d;
118
- pq.decode (train_codes + i * pq.code_size, res);
125
+ const float* xx = trainset + i * d;
126
+ float* res = residuals_2 + i * d;
127
+ pq.decode(train_codes + i * pq.code_size, res);
119
128
  for (int j = 0; j < d; j++)
120
129
  res[j] = xx[j] - res[j];
121
130
  }
122
-
123
131
  }
124
132
 
125
133
  if (by_residual) {
126
- precompute_table ();
134
+ precompute_table();
127
135
  }
128
-
129
136
  }
130
137
 
131
-
132
-
133
-
134
-
135
-
136
138
  /****************************************************************
137
139
  * IVFPQ as codec */
138
140
 
139
-
140
141
  /* produce a binary signature based on the residual vector */
141
- void IndexIVFPQ::encode (idx_t key, const float * x, uint8_t * code) const
142
- {
142
+ void IndexIVFPQ::encode(idx_t key, const float* x, uint8_t* code) const {
143
143
  if (by_residual) {
144
144
  std::vector<float> residual_vec(d);
145
- quantizer->compute_residual (x, residual_vec.data(), key);
146
- pq.compute_code (residual_vec.data(), code);
147
- }
148
- else pq.compute_code (x, code);
145
+ quantizer->compute_residual(x, residual_vec.data(), key);
146
+ pq.compute_code(residual_vec.data(), code);
147
+ } else
148
+ pq.compute_code(x, code);
149
149
  }
150
150
 
151
- void IndexIVFPQ::encode_multiple (size_t n, idx_t *keys,
152
- const float * x, uint8_t * xcodes,
153
- bool compute_keys) const
154
- {
151
+ void IndexIVFPQ::encode_multiple(
152
+ size_t n,
153
+ idx_t* keys,
154
+ const float* x,
155
+ uint8_t* xcodes,
156
+ bool compute_keys) const {
155
157
  if (compute_keys)
156
- quantizer->assign (n, x, keys);
158
+ quantizer->assign(n, x, keys);
157
159
 
158
- encode_vectors (n, x, keys, xcodes);
160
+ encode_vectors(n, x, keys, xcodes);
159
161
  }
160
162
 
161
- void IndexIVFPQ::decode_multiple (size_t n, const idx_t *keys,
162
- const uint8_t * xcodes, float * x) const
163
- {
164
- pq.decode (xcodes, x, n);
163
+ void IndexIVFPQ::decode_multiple(
164
+ size_t n,
165
+ const idx_t* keys,
166
+ const uint8_t* xcodes,
167
+ float* x) const {
168
+ pq.decode(xcodes, x, n);
165
169
  if (by_residual) {
166
- std::vector<float> centroid (d);
170
+ std::vector<float> centroid(d);
167
171
  for (size_t i = 0; i < n; i++) {
168
- quantizer->reconstruct (keys[i], centroid.data());
169
- float *xi = x + i * d;
172
+ quantizer->reconstruct(keys[i], centroid.data());
173
+ float* xi = x + i * d;
170
174
  for (size_t j = 0; j < d; j++) {
171
- xi [j] += centroid [j];
175
+ xi[j] += centroid[j];
172
176
  }
173
177
  }
174
178
  }
175
179
  }
176
180
 
177
-
178
-
179
-
180
181
  /****************************************************************
181
182
  * add */
182
183
 
183
-
184
- void IndexIVFPQ::add_with_ids (idx_t n, const float * x, const idx_t *xids)
185
- {
186
- add_core_o (n, x, xids, nullptr);
184
+ void IndexIVFPQ::add_core(
185
+ idx_t n,
186
+ const float* x,
187
+ const idx_t* xids,
188
+ const idx_t* coarse_idx) {
189
+ add_core_o(n, x, xids, nullptr, coarse_idx);
187
190
  }
188
191
 
189
-
190
- static float * compute_residuals (
191
- const Index *quantizer,
192
- Index::idx_t n, const float* x,
193
- const Index::idx_t *list_nos)
194
- {
192
+ static float* compute_residuals(
193
+ const Index* quantizer,
194
+ Index::idx_t n,
195
+ const float* x,
196
+ const Index::idx_t* list_nos) {
195
197
  size_t d = quantizer->d;
196
- float *residuals = new float [n * d];
198
+ float* residuals = new float[n * d];
197
199
  // TODO: parallelize?
198
200
  for (size_t i = 0; i < n; i++) {
199
201
  if (list_nos[i] < 0)
200
- memset (residuals + i * d, 0, sizeof(*residuals) * d);
202
+ memset(residuals + i * d, 0, sizeof(*residuals) * d);
201
203
  else
202
- quantizer->compute_residual (
203
- x + i * d, residuals + i * d, list_nos[i]);
204
+ quantizer->compute_residual(
205
+ x + i * d, residuals + i * d, list_nos[i]);
204
206
  }
205
207
  return residuals;
206
208
  }
207
209
 
208
- void IndexIVFPQ::encode_vectors(idx_t n, const float* x,
209
- const idx_t *list_nos,
210
- uint8_t * codes,
211
- bool include_listnos) const
212
- {
210
+ void IndexIVFPQ::encode_vectors(
211
+ idx_t n,
212
+ const float* x,
213
+ const idx_t* list_nos,
214
+ uint8_t* codes,
215
+ bool include_listnos) const {
213
216
  if (by_residual) {
214
- float *to_encode = compute_residuals (quantizer, n, x, list_nos);
215
- ScopeDeleter<float> del (to_encode);
216
- pq.compute_codes (to_encode, codes, n);
217
+ float* to_encode = compute_residuals(quantizer, n, x, list_nos);
218
+ ScopeDeleter<float> del(to_encode);
219
+ pq.compute_codes(to_encode, codes, n);
217
220
  } else {
218
- pq.compute_codes (x, codes, n);
221
+ pq.compute_codes(x, codes, n);
219
222
  }
220
223
 
221
224
  if (include_listnos) {
222
225
  size_t coarse_size = coarse_code_size();
223
226
  for (idx_t i = n - 1; i >= 0; i--) {
224
- uint8_t * code = codes + i * (coarse_size + code_size);
225
- memmove (code + coarse_size,
226
- codes + i * code_size, code_size);
227
- encode_listno (list_nos[i], code);
227
+ uint8_t* code = codes + i * (coarse_size + code_size);
228
+ memmove(code + coarse_size, codes + i * code_size, code_size);
229
+ encode_listno(list_nos[i], code);
228
230
  }
229
231
  }
230
232
  }
231
233
 
232
-
233
-
234
- void IndexIVFPQ::sa_decode (idx_t n, const uint8_t *codes,
235
- float *x) const
236
- {
237
- size_t coarse_size = coarse_code_size ();
234
+ void IndexIVFPQ::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
235
+ size_t coarse_size = coarse_code_size();
238
236
 
239
237
  #pragma omp parallel
240
238
  {
241
- std::vector<float> residual (d);
239
+ std::vector<float> residual(d);
242
240
 
243
241
  #pragma omp for
244
242
  for (idx_t i = 0; i < n; i++) {
245
- const uint8_t *code = codes + i * (code_size + coarse_size);
246
- int64_t list_no = decode_listno (code);
247
- float *xi = x + i * d;
248
- pq.decode (code + coarse_size, xi);
243
+ const uint8_t* code = codes + i * (code_size + coarse_size);
244
+ int64_t list_no = decode_listno(code);
245
+ float* xi = x + i * d;
246
+ pq.decode(code + coarse_size, xi);
249
247
  if (by_residual) {
250
- quantizer->reconstruct (list_no, residual.data());
248
+ quantizer->reconstruct(list_no, residual.data());
251
249
  for (size_t j = 0; j < d; j++) {
252
250
  xi[j] += residual[j];
253
251
  }
@@ -256,120 +254,127 @@ void IndexIVFPQ::sa_decode (idx_t n, const uint8_t *codes,
256
254
  }
257
255
  }
258
256
 
259
-
260
- void IndexIVFPQ::add_core_o (idx_t n, const float * x, const idx_t *xids,
261
- float *residuals_2, const idx_t *precomputed_idx)
262
- {
263
-
257
+ void IndexIVFPQ::add_core_o(
258
+ idx_t n,
259
+ const float* x,
260
+ const idx_t* xids,
261
+ float* residuals_2,
262
+ const idx_t* precomputed_idx) {
264
263
  idx_t bs = 32768;
265
264
  if (n > bs) {
266
265
  for (idx_t i0 = 0; i0 < n; i0 += bs) {
267
266
  idx_t i1 = std::min(i0 + bs, n);
268
267
  if (verbose) {
269
- printf("IndexIVFPQ::add_core_o: adding %" PRId64 ":%" PRId64 " / %" PRId64 "\n",
270
- i0, i1, n);
268
+ printf("IndexIVFPQ::add_core_o: adding %" PRId64 ":%" PRId64
269
+ " / %" PRId64 "\n",
270
+ i0,
271
+ i1,
272
+ n);
271
273
  }
272
- add_core_o (i1 - i0, x + i0 * d,
273
- xids ? xids + i0 : nullptr,
274
- residuals_2 ? residuals_2 + i0 * d : nullptr,
275
- precomputed_idx ? precomputed_idx + i0 : nullptr);
274
+ add_core_o(
275
+ i1 - i0,
276
+ x + i0 * d,
277
+ xids ? xids + i0 : nullptr,
278
+ residuals_2 ? residuals_2 + i0 * d : nullptr,
279
+ precomputed_idx ? precomputed_idx + i0 : nullptr);
276
280
  }
277
281
  return;
278
282
  }
279
283
 
280
284
  InterruptCallback::check();
281
285
 
282
- direct_map.check_can_add (xids);
286
+ direct_map.check_can_add(xids);
283
287
 
284
- FAISS_THROW_IF_NOT (is_trained);
285
- double t0 = getmillisecs ();
286
- const idx_t * idx;
288
+ FAISS_THROW_IF_NOT(is_trained);
289
+ double t0 = getmillisecs();
290
+ const idx_t* idx;
287
291
  ScopeDeleter<idx_t> del_idx;
288
292
 
289
293
  if (precomputed_idx) {
290
294
  idx = precomputed_idx;
291
295
  } else {
292
- idx_t * idx0 = new idx_t [n];
293
- del_idx.set (idx0);
294
- quantizer->assign (n, x, idx0);
296
+ idx_t* idx0 = new idx_t[n];
297
+ del_idx.set(idx0);
298
+ quantizer->assign(n, x, idx0);
295
299
  idx = idx0;
296
300
  }
297
301
 
298
- double t1 = getmillisecs ();
299
- uint8_t * xcodes = new uint8_t [n * code_size];
300
- ScopeDeleter<uint8_t> del_xcodes (xcodes);
302
+ double t1 = getmillisecs();
303
+ uint8_t* xcodes = new uint8_t[n * code_size];
304
+ ScopeDeleter<uint8_t> del_xcodes(xcodes);
301
305
 
302
- const float *to_encode = nullptr;
306
+ const float* to_encode = nullptr;
303
307
  ScopeDeleter<float> del_to_encode;
304
308
 
305
309
  if (by_residual) {
306
- to_encode = compute_residuals (quantizer, n, x, idx);
307
- del_to_encode.set (to_encode);
310
+ to_encode = compute_residuals(quantizer, n, x, idx);
311
+ del_to_encode.set(to_encode);
308
312
  } else {
309
313
  to_encode = x;
310
314
  }
311
- pq.compute_codes (to_encode, xcodes, n);
315
+ pq.compute_codes(to_encode, xcodes, n);
312
316
 
313
- double t2 = getmillisecs ();
317
+ double t2 = getmillisecs();
314
318
  // TODO: parallelize?
315
319
  size_t n_ignore = 0;
316
320
  for (size_t i = 0; i < n; i++) {
317
321
  idx_t key = idx[i];
318
322
  idx_t id = xids ? xids[i] : ntotal + i;
319
323
  if (key < 0) {
320
- direct_map.add_single_id (id, -1, 0);
321
- n_ignore ++;
324
+ direct_map.add_single_id(id, -1, 0);
325
+ n_ignore++;
322
326
  if (residuals_2)
323
- memset (residuals_2, 0, sizeof(*residuals_2) * d);
327
+ memset(residuals_2, 0, sizeof(*residuals_2) * d);
324
328
  continue;
325
329
  }
326
330
 
327
- uint8_t *code = xcodes + i * code_size;
328
- size_t offset = invlists->add_entry (key, id, code);
331
+ uint8_t* code = xcodes + i * code_size;
332
+ size_t offset = invlists->add_entry(key, id, code);
329
333
 
330
334
  if (residuals_2) {
331
- float *res2 = residuals_2 + i * d;
332
- const float *xi = to_encode + i * d;
333
- pq.decode (code, res2);
335
+ float* res2 = residuals_2 + i * d;
336
+ const float* xi = to_encode + i * d;
337
+ pq.decode(code, res2);
334
338
  for (int j = 0; j < d; j++)
335
339
  res2[j] = xi[j] - res2[j];
336
340
  }
337
341
 
338
- direct_map.add_single_id (id, key, offset);
342
+ direct_map.add_single_id(id, key, offset);
339
343
  }
340
344
 
341
- double t3 = getmillisecs ();
342
- if(verbose) {
345
+ double t3 = getmillisecs();
346
+ if (verbose) {
343
347
  char comment[100] = {0};
344
348
  if (n_ignore > 0)
345
- snprintf (comment, 100, "(%zd vectors ignored)", n_ignore);
349
+ snprintf(comment, 100, "(%zd vectors ignored)", n_ignore);
346
350
  printf(" add_core times: %.3f %.3f %.3f %s\n",
347
- t1 - t0, t2 - t1, t3 - t2, comment);
351
+ t1 - t0,
352
+ t2 - t1,
353
+ t3 - t2,
354
+ comment);
348
355
  }
349
356
  ntotal += n;
350
357
  }
351
358
 
352
-
353
- void IndexIVFPQ::reconstruct_from_offset (int64_t list_no, int64_t offset,
354
- float* recons) const
355
- {
356
- const uint8_t* code = invlists->get_single_code (list_no, offset);
359
+ void IndexIVFPQ::reconstruct_from_offset(
360
+ int64_t list_no,
361
+ int64_t offset,
362
+ float* recons) const {
363
+ const uint8_t* code = invlists->get_single_code(list_no, offset);
357
364
 
358
365
  if (by_residual) {
359
366
  std::vector<float> centroid(d);
360
- quantizer->reconstruct (list_no, centroid.data());
367
+ quantizer->reconstruct(list_no, centroid.data());
361
368
 
362
- pq.decode (code, recons);
369
+ pq.decode(code, recons);
363
370
  for (int i = 0; i < d; ++i) {
364
371
  recons[i] += centroid[i];
365
372
  }
366
373
  } else {
367
- pq.decode (code, recons);
374
+ pq.decode(code, recons);
368
375
  }
369
376
  }
370
377
 
371
-
372
-
373
378
  /// 2G by default, accommodates tables up to PQ32 w/ 65536 centroids
374
379
  size_t precomputed_table_max_bytes = ((size_t)1) << 31;
375
380
 
@@ -403,20 +408,18 @@ size_t precomputed_table_max_bytes = ((size_t)1) << 31;
403
408
  * is faster when the length of the lists is > ksub * M.
404
409
  */
405
410
 
406
- void initialize_IVFPQ_precomputed_table (
407
- int &use_precomputed_table,
408
- const Index *quantizer,
409
- const ProductQuantizer &pq,
410
- AlignedTable<float> & precomputed_table,
411
- bool verbose
412
- )
413
- {
411
+ void initialize_IVFPQ_precomputed_table(
412
+ int& use_precomputed_table,
413
+ const Index* quantizer,
414
+ const ProductQuantizer& pq,
415
+ AlignedTable<float>& precomputed_table,
416
+ bool verbose) {
414
417
  size_t nlist = quantizer->ntotal;
415
418
  size_t d = quantizer->d;
416
419
  FAISS_THROW_IF_NOT(d == pq.d);
417
420
 
418
421
  if (use_precomputed_table == -1) {
419
- precomputed_table.resize (0);
422
+ precomputed_table.resize(0);
420
423
  return;
421
424
  }
422
425
 
@@ -424,23 +427,23 @@ void initialize_IVFPQ_precomputed_table (
424
427
  if (quantizer->metric_type == METRIC_INNER_PRODUCT) {
425
428
  if (verbose) {
426
429
  printf("IndexIVFPQ::precompute_table: precomputed "
427
- "tables not needed for inner product quantizers\n");
430
+ "tables not needed for inner product quantizers\n");
428
431
  }
429
- precomputed_table.resize (0);
432
+ precomputed_table.resize(0);
430
433
  return;
431
434
  }
432
- const MultiIndexQuantizer *miq =
433
- dynamic_cast<const MultiIndexQuantizer *> (quantizer);
435
+ const MultiIndexQuantizer* miq =
436
+ dynamic_cast<const MultiIndexQuantizer*>(quantizer);
434
437
  if (miq && pq.M % miq->pq.M == 0)
435
438
  use_precomputed_table = 2;
436
439
  else {
437
440
  size_t table_size = pq.M * pq.ksub * nlist * sizeof(float);
438
441
  if (table_size > precomputed_table_max_bytes) {
439
442
  if (verbose) {
440
- printf(
441
- "IndexIVFPQ::precompute_table: not precomputing table, "
442
- "it would be too big: %zd bytes (max %zd)\n",
443
- table_size, precomputed_table_max_bytes);
443
+ printf("IndexIVFPQ::precompute_table: not precomputing table, "
444
+ "it would be too big: %zd bytes (max %zd)\n",
445
+ table_size,
446
+ precomputed_table_max_bytes);
444
447
  use_precomputed_table = 0;
445
448
  }
446
449
  return;
@@ -450,80 +453,68 @@ void initialize_IVFPQ_precomputed_table (
450
453
  } // otherwise assume user has set appropriate flag on input
451
454
 
452
455
  if (verbose) {
453
- printf ("precomputing IVFPQ tables type %d\n",
454
- use_precomputed_table);
456
+ printf("precomputing IVFPQ tables type %d\n", use_precomputed_table);
455
457
  }
456
458
 
457
459
  // squared norms of the PQ centroids
458
- std::vector<float> r_norms (pq.M * pq.ksub, NAN);
460
+ std::vector<float> r_norms(pq.M * pq.ksub, NAN);
459
461
  for (int m = 0; m < pq.M; m++)
460
462
  for (int j = 0; j < pq.ksub; j++)
461
- r_norms [m * pq.ksub + j] =
462
- fvec_norm_L2sqr (pq.get_centroids (m, j), pq.dsub);
463
+ r_norms[m * pq.ksub + j] =
464
+ fvec_norm_L2sqr(pq.get_centroids(m, j), pq.dsub);
463
465
 
464
466
  if (use_precomputed_table == 1) {
465
-
466
- precomputed_table.resize (nlist * pq.M * pq.ksub);
467
- std::vector<float> centroid (d);
467
+ precomputed_table.resize(nlist * pq.M * pq.ksub);
468
+ std::vector<float> centroid(d);
468
469
 
469
470
  for (size_t i = 0; i < nlist; i++) {
470
- quantizer->reconstruct (i, centroid.data());
471
+ quantizer->reconstruct(i, centroid.data());
471
472
 
472
- float *tab = &precomputed_table[i * pq.M * pq.ksub];
473
- pq.compute_inner_prod_table (centroid.data(), tab);
474
- fvec_madd (pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
473
+ float* tab = &precomputed_table[i * pq.M * pq.ksub];
474
+ pq.compute_inner_prod_table(centroid.data(), tab);
475
+ fvec_madd(pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
475
476
  }
476
477
  } else if (use_precomputed_table == 2) {
477
- const MultiIndexQuantizer *miq =
478
- dynamic_cast<const MultiIndexQuantizer *> (quantizer);
479
- FAISS_THROW_IF_NOT (miq);
480
- const ProductQuantizer &cpq = miq->pq;
481
- FAISS_THROW_IF_NOT (pq.M % cpq.M == 0);
478
+ const MultiIndexQuantizer* miq =
479
+ dynamic_cast<const MultiIndexQuantizer*>(quantizer);
480
+ FAISS_THROW_IF_NOT(miq);
481
+ const ProductQuantizer& cpq = miq->pq;
482
+ FAISS_THROW_IF_NOT(pq.M % cpq.M == 0);
482
483
 
483
484
  precomputed_table.resize(cpq.ksub * pq.M * pq.ksub);
484
485
 
485
486
  // reorder PQ centroid table
486
- std::vector<float> centroids (d * cpq.ksub, NAN);
487
+ std::vector<float> centroids(d * cpq.ksub, NAN);
487
488
 
488
489
  for (int m = 0; m < cpq.M; m++) {
489
490
  for (size_t i = 0; i < cpq.ksub; i++) {
490
- memcpy (centroids.data() + i * d + m * cpq.dsub,
491
- cpq.get_centroids (m, i),
492
- sizeof (*centroids.data()) * cpq.dsub);
491
+ memcpy(centroids.data() + i * d + m * cpq.dsub,
492
+ cpq.get_centroids(m, i),
493
+ sizeof(*centroids.data()) * cpq.dsub);
493
494
  }
494
495
  }
495
496
 
496
- pq.compute_inner_prod_tables (cpq.ksub, centroids.data (),
497
- precomputed_table.data ());
497
+ pq.compute_inner_prod_tables(
498
+ cpq.ksub, centroids.data(), precomputed_table.data());
498
499
 
499
500
  for (size_t i = 0; i < cpq.ksub; i++) {
500
- float *tab = &precomputed_table[i * pq.M * pq.ksub];
501
- fvec_madd (pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
501
+ float* tab = &precomputed_table[i * pq.M * pq.ksub];
502
+ fvec_madd(pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
502
503
  }
503
-
504
504
  }
505
-
506
505
  }
507
506
 
508
- void IndexIVFPQ::precompute_table ()
509
- {
510
- initialize_IVFPQ_precomputed_table (
511
- use_precomputed_table, quantizer, pq, precomputed_table,
512
- verbose
513
- );
507
+ void IndexIVFPQ::precompute_table() {
508
+ initialize_IVFPQ_precomputed_table(
509
+ use_precomputed_table, quantizer, pq, precomputed_table, verbose);
514
510
  }
515
511
 
516
-
517
-
518
512
  namespace {
519
513
 
520
514
  using idx_t = Index::idx_t;
521
515
 
522
-
523
516
  #define TIC t0 = get_cycles()
524
- #define TOC get_cycles () - t0
525
-
526
-
517
+ #define TOC get_cycles() - t0
527
518
 
528
519
  /** QueryTables manages the various ways of searching an
529
520
  * IndexIVFPQ. The code contains a lot of branches, depending on:
@@ -533,43 +524,42 @@ using idx_t = Index::idx_t;
533
524
  * - polysemous_ht: are we filtering with polysemous codes?
534
525
  */
535
526
  struct QueryTables {
536
-
537
527
  /*****************************************************
538
528
  * General data from the IVFPQ
539
529
  *****************************************************/
540
530
 
541
- const IndexIVFPQ & ivfpq;
542
- const IVFSearchParameters *params;
531
+ const IndexIVFPQ& ivfpq;
532
+ const IVFSearchParameters* params;
543
533
 
544
534
  // copied from IndexIVFPQ for easier access
545
535
  int d;
546
- const ProductQuantizer & pq;
536
+ const ProductQuantizer& pq;
547
537
  MetricType metric_type;
548
538
  bool by_residual;
549
539
  int use_precomputed_table;
550
540
  int polysemous_ht;
551
541
 
552
542
  // pre-allocated data buffers
553
- float * sim_table, * sim_table_2;
554
- float * residual_vec, *decoded_vec;
543
+ float *sim_table, *sim_table_2;
544
+ float *residual_vec, *decoded_vec;
555
545
 
556
546
  // single data buffer
557
547
  std::vector<float> mem;
558
548
 
559
549
  // for table pointers
560
- std::vector<const float *> sim_table_ptrs;
561
-
562
- explicit QueryTables (const IndexIVFPQ & ivfpq,
563
- const IVFSearchParameters *params):
564
- ivfpq(ivfpq),
565
- d(ivfpq.d),
566
- pq (ivfpq.pq),
567
- metric_type (ivfpq.metric_type),
568
- by_residual (ivfpq.by_residual),
569
- use_precomputed_table (ivfpq.use_precomputed_table)
570
- {
571
- mem.resize (pq.ksub * pq.M * 2 + d * 2);
572
- sim_table = mem.data ();
550
+ std::vector<const float*> sim_table_ptrs;
551
+
552
+ explicit QueryTables(
553
+ const IndexIVFPQ& ivfpq,
554
+ const IVFSearchParameters* params)
555
+ : ivfpq(ivfpq),
556
+ d(ivfpq.d),
557
+ pq(ivfpq.pq),
558
+ metric_type(ivfpq.metric_type),
559
+ by_residual(ivfpq.by_residual),
560
+ use_precomputed_table(ivfpq.use_precomputed_table) {
561
+ mem.resize(pq.ksub * pq.M * 2 + d * 2);
562
+ sim_table = mem.data();
573
563
  sim_table_2 = sim_table + pq.ksub * pq.M;
574
564
  residual_vec = sim_table_2 + pq.ksub * pq.M;
575
565
  decoded_vec = residual_vec + d;
@@ -577,14 +567,14 @@ struct QueryTables {
577
567
  // for polysemous
578
568
  polysemous_ht = ivfpq.polysemous_ht;
579
569
  if (auto ivfpq_params =
580
- dynamic_cast<const IVFPQSearchParameters *>(params)) {
570
+ dynamic_cast<const IVFPQSearchParameters*>(params)) {
581
571
  polysemous_ht = ivfpq_params->polysemous_ht;
582
572
  }
583
- if (polysemous_ht != 0) {
584
- q_code.resize (pq.code_size);
573
+ if (polysemous_ht != 0) {
574
+ q_code.resize(pq.code_size);
585
575
  }
586
576
  init_list_cycles = 0;
587
- sim_table_ptrs.resize (pq.M);
577
+ sim_table_ptrs.resize(pq.M);
588
578
  }
589
579
 
590
580
  /*****************************************************
@@ -592,29 +582,29 @@ struct QueryTables {
592
582
  *****************************************************/
593
583
 
594
584
  // field specific to query
595
- const float * qi;
585
+ const float* qi;
596
586
 
597
- // query-specific intialization
598
- void init_query (const float * qi) {
587
+ // query-specific initialization
588
+ void init_query(const float* qi) {
599
589
  this->qi = qi;
600
590
  if (metric_type == METRIC_INNER_PRODUCT)
601
- init_query_IP ();
591
+ init_query_IP();
602
592
  else
603
- init_query_L2 ();
593
+ init_query_L2();
604
594
  if (!by_residual && polysemous_ht != 0)
605
- pq.compute_code (qi, q_code.data());
595
+ pq.compute_code(qi, q_code.data());
606
596
  }
607
597
 
608
- void init_query_IP () {
598
+ void init_query_IP() {
609
599
  // precompute some tables specific to the query qi
610
- pq.compute_inner_prod_table (qi, sim_table);
600
+ pq.compute_inner_prod_table(qi, sim_table);
611
601
  }
612
602
 
613
- void init_query_L2 () {
603
+ void init_query_L2() {
614
604
  if (!by_residual) {
615
- pq.compute_distance_table (qi, sim_table);
605
+ pq.compute_distance_table(qi, sim_table);
616
606
  } else if (use_precomputed_table) {
617
- pq.compute_inner_prod_table (qi, sim_table_2);
607
+ pq.compute_inner_prod_table(qi, sim_table_2);
618
608
  }
619
609
  }
620
610
 
@@ -632,96 +622,95 @@ struct QueryTables {
632
622
  /// once we know the query and the centroid, we can prepare the
633
623
  /// sim_table that will be used for accumulation
634
624
  /// and dis0, the initial value
635
- float precompute_list_tables () {
625
+ float precompute_list_tables() {
636
626
  float dis0 = 0;
637
- uint64_t t0; TIC;
627
+ uint64_t t0;
628
+ TIC;
638
629
  if (by_residual) {
639
630
  if (metric_type == METRIC_INNER_PRODUCT)
640
- dis0 = precompute_list_tables_IP ();
631
+ dis0 = precompute_list_tables_IP();
641
632
  else
642
- dis0 = precompute_list_tables_L2 ();
633
+ dis0 = precompute_list_tables_L2();
643
634
  }
644
635
  init_list_cycles += TOC;
645
636
  return dis0;
646
- }
637
+ }
647
638
 
648
- float precompute_list_table_pointers () {
639
+ float precompute_list_table_pointers() {
649
640
  float dis0 = 0;
650
- uint64_t t0; TIC;
641
+ uint64_t t0;
642
+ TIC;
651
643
  if (by_residual) {
652
644
  if (metric_type == METRIC_INNER_PRODUCT)
653
- FAISS_THROW_MSG ("not implemented");
645
+ FAISS_THROW_MSG("not implemented");
654
646
  else
655
- dis0 = precompute_list_table_pointers_L2 ();
647
+ dis0 = precompute_list_table_pointers_L2();
656
648
  }
657
649
  init_list_cycles += TOC;
658
650
  return dis0;
659
- }
651
+ }
660
652
 
661
653
  /*****************************************************
662
654
  * compute tables for inner prod
663
655
  *****************************************************/
664
656
 
665
- float precompute_list_tables_IP ()
666
- {
657
+ float precompute_list_tables_IP() {
667
658
  // prepare the sim_table that will be used for accumulation
668
659
  // and dis0, the initial value
669
- ivfpq.quantizer->reconstruct (key, decoded_vec);
660
+ ivfpq.quantizer->reconstruct(key, decoded_vec);
670
661
  // decoded_vec = centroid
671
- float dis0 = fvec_inner_product (qi, decoded_vec, d);
662
+ float dis0 = fvec_inner_product(qi, decoded_vec, d);
672
663
 
673
664
  if (polysemous_ht) {
674
665
  for (int i = 0; i < d; i++) {
675
- residual_vec [i] = qi[i] - decoded_vec[i];
666
+ residual_vec[i] = qi[i] - decoded_vec[i];
676
667
  }
677
- pq.compute_code (residual_vec, q_code.data());
668
+ pq.compute_code(residual_vec, q_code.data());
678
669
  }
679
670
  return dis0;
680
671
  }
681
672
 
682
-
683
673
  /*****************************************************
684
674
  * compute tables for L2 distance
685
675
  *****************************************************/
686
676
 
687
- float precompute_list_tables_L2 ()
688
- {
677
+ float precompute_list_tables_L2() {
689
678
  float dis0 = 0;
690
679
 
691
680
  if (use_precomputed_table == 0 || use_precomputed_table == -1) {
692
- ivfpq.quantizer->compute_residual (qi, residual_vec, key);
693
- pq.compute_distance_table (residual_vec, sim_table);
681
+ ivfpq.quantizer->compute_residual(qi, residual_vec, key);
682
+ pq.compute_distance_table(residual_vec, sim_table);
694
683
 
695
684
  if (polysemous_ht != 0) {
696
- pq.compute_code (residual_vec, q_code.data());
685
+ pq.compute_code(residual_vec, q_code.data());
697
686
  }
698
687
 
699
688
  } else if (use_precomputed_table == 1) {
700
689
  dis0 = coarse_dis;
701
690
 
702
- fvec_madd (
691
+ fvec_madd(
703
692
  pq.M * pq.ksub,
704
693
  ivfpq.precomputed_table.data() + key * pq.ksub * pq.M,
705
- -2.0, sim_table_2,
706
- sim_table
707
- );
694
+ -2.0,
695
+ sim_table_2,
696
+ sim_table);
708
697
 
709
698
  if (polysemous_ht != 0) {
710
- ivfpq.quantizer->compute_residual (qi, residual_vec, key);
711
- pq.compute_code (residual_vec, q_code.data());
699
+ ivfpq.quantizer->compute_residual(qi, residual_vec, key);
700
+ pq.compute_code(residual_vec, q_code.data());
712
701
  }
713
702
 
714
703
  } else if (use_precomputed_table == 2) {
715
704
  dis0 = coarse_dis;
716
705
 
717
- const MultiIndexQuantizer *miq =
718
- dynamic_cast<const MultiIndexQuantizer *> (ivfpq.quantizer);
719
- FAISS_THROW_IF_NOT (miq);
720
- const ProductQuantizer &cpq = miq->pq;
706
+ const MultiIndexQuantizer* miq =
707
+ dynamic_cast<const MultiIndexQuantizer*>(ivfpq.quantizer);
708
+ FAISS_THROW_IF_NOT(miq);
709
+ const ProductQuantizer& cpq = miq->pq;
721
710
  int Mf = pq.M / cpq.M;
722
711
 
723
- const float *qtab = sim_table_2; // query-specific table
724
- float *ltab = sim_table; // (output) list-specific table
712
+ const float* qtab = sim_table_2; // query-specific table
713
+ float* ltab = sim_table; // (output) list-specific table
725
714
 
726
715
  long k = key;
727
716
  for (int cm = 0; cm < cpq.M; cm++) {
@@ -730,54 +719,48 @@ struct QueryTables {
730
719
  k >>= cpq.nbits;
731
720
 
732
721
  // get corresponding table
733
- const float *pc = ivfpq.precomputed_table.data() +
734
- (ki * pq.M + cm * Mf) * pq.ksub;
722
+ const float* pc = ivfpq.precomputed_table.data() +
723
+ (ki * pq.M + cm * Mf) * pq.ksub;
735
724
 
736
725
  if (polysemous_ht == 0) {
737
-
738
726
  // sum up with query-specific table
739
- fvec_madd (Mf * pq.ksub,
740
- pc,
741
- -2.0, qtab,
742
- ltab);
727
+ fvec_madd(Mf * pq.ksub, pc, -2.0, qtab, ltab);
743
728
  ltab += Mf * pq.ksub;
744
729
  qtab += Mf * pq.ksub;
745
730
  } else {
746
731
  for (int m = cm * Mf; m < (cm + 1) * Mf; m++) {
747
- q_code[m] = fvec_madd_and_argmin
748
- (pq.ksub, pc, -2, qtab, ltab);
732
+ q_code[m] = fvec_madd_and_argmin(
733
+ pq.ksub, pc, -2, qtab, ltab);
749
734
  pc += pq.ksub;
750
735
  ltab += pq.ksub;
751
736
  qtab += pq.ksub;
752
737
  }
753
738
  }
754
-
755
739
  }
756
740
  }
757
741
 
758
742
  return dis0;
759
743
  }
760
744
 
761
- float precompute_list_table_pointers_L2 ()
762
- {
745
+ float precompute_list_table_pointers_L2() {
763
746
  float dis0 = 0;
764
747
 
765
748
  if (use_precomputed_table == 1) {
766
749
  dis0 = coarse_dis;
767
750
 
768
- const float * s = ivfpq.precomputed_table.data() +
769
- key * pq.ksub * pq.M;
751
+ const float* s =
752
+ ivfpq.precomputed_table.data() + key * pq.ksub * pq.M;
770
753
  for (int m = 0; m < pq.M; m++) {
771
- sim_table_ptrs [m] = s;
754
+ sim_table_ptrs[m] = s;
772
755
  s += pq.ksub;
773
756
  }
774
757
  } else if (use_precomputed_table == 2) {
775
758
  dis0 = coarse_dis;
776
759
 
777
- const MultiIndexQuantizer *miq =
778
- dynamic_cast<const MultiIndexQuantizer *> (ivfpq.quantizer);
779
- FAISS_THROW_IF_NOT (miq);
780
- const ProductQuantizer &cpq = miq->pq;
760
+ const MultiIndexQuantizer* miq =
761
+ dynamic_cast<const MultiIndexQuantizer*>(ivfpq.quantizer);
762
+ FAISS_THROW_IF_NOT(miq);
763
+ const ProductQuantizer& cpq = miq->pq;
781
764
  int Mf = pq.M / cpq.M;
782
765
 
783
766
  long k = key;
@@ -786,21 +769,21 @@ struct QueryTables {
786
769
  int ki = k & ((uint64_t(1) << cpq.nbits) - 1);
787
770
  k >>= cpq.nbits;
788
771
 
789
- const float *pc = ivfpq.precomputed_table.data() +
790
- (ki * pq.M + cm * Mf) * pq.ksub;
772
+ const float* pc = ivfpq.precomputed_table.data() +
773
+ (ki * pq.M + cm * Mf) * pq.ksub;
791
774
 
792
775
  for (int m = m0; m < m0 + Mf; m++) {
793
- sim_table_ptrs [m] = pc;
776
+ sim_table_ptrs[m] = pc;
794
777
  pc += pq.ksub;
795
778
  }
796
779
  m0 += Mf;
797
780
  }
798
781
  } else {
799
- FAISS_THROW_MSG ("need precomputed tables");
782
+ FAISS_THROW_MSG("need precomputed tables");
800
783
  }
801
784
 
802
785
  if (polysemous_ht) {
803
- FAISS_THROW_MSG ("not implemented");
786
+ FAISS_THROW_MSG("not implemented");
804
787
  // Not clear that it makes sense to implemente this,
805
788
  // because it costs M * ksub, which is what we wanted to
806
789
  // avoid with the tables pointers.
@@ -808,82 +791,72 @@ struct QueryTables {
808
791
 
809
792
  return dis0;
810
793
  }
811
-
812
-
813
794
  };
814
795
 
815
-
816
-
817
- template<class C>
796
+ template <class C>
818
797
  struct KnnSearchResults {
819
798
  idx_t key;
820
- const idx_t *ids;
799
+ const idx_t* ids;
821
800
 
822
801
  // heap params
823
802
  size_t k;
824
- float * heap_sim;
825
- idx_t * heap_ids;
803
+ float* heap_sim;
804
+ idx_t* heap_ids;
826
805
 
827
806
  size_t nup;
828
807
 
829
- inline void add (idx_t j, float dis) {
830
- if (C::cmp (heap_sim[0], dis)) {
831
- idx_t id = ids ? ids[j] : lo_build (key, j);
832
- heap_replace_top<C> (k, heap_sim, heap_ids, dis, id);
808
+ inline void add(idx_t j, float dis) {
809
+ if (C::cmp(heap_sim[0], dis)) {
810
+ idx_t id = ids ? ids[j] : lo_build(key, j);
811
+ heap_replace_top<C>(k, heap_sim, heap_ids, dis, id);
833
812
  nup++;
834
813
  }
835
814
  }
836
-
837
815
  };
838
816
 
839
- template<class C>
817
+ template <class C>
840
818
  struct RangeSearchResults {
841
819
  idx_t key;
842
- const idx_t *ids;
820
+ const idx_t* ids;
843
821
 
844
822
  // wrapped result structure
845
823
  float radius;
846
- RangeQueryResult & rres;
824
+ RangeQueryResult& rres;
847
825
 
848
- inline void add (idx_t j, float dis) {
849
- if (C::cmp (radius, dis)) {
850
- idx_t id = ids ? ids[j] : lo_build (key, j);
851
- rres.add (dis, id);
826
+ inline void add(idx_t j, float dis) {
827
+ if (C::cmp(radius, dis)) {
828
+ idx_t id = ids ? ids[j] : lo_build(key, j);
829
+ rres.add(dis, id);
852
830
  }
853
831
  }
854
832
  };
855
833
 
856
-
857
-
858
834
  /*****************************************************
859
835
  * Scaning the codes.
860
836
  * The scanning functions call their favorite precompute_*
861
837
  * function to precompute the tables they need.
862
838
  *****************************************************/
863
839
  template <typename IDType, MetricType METRIC_TYPE, class PQDecoder>
864
- struct IVFPQScannerT: QueryTables {
865
-
866
- const uint8_t * list_codes;
867
- const IDType * list_ids;
840
+ struct IVFPQScannerT : QueryTables {
841
+ const uint8_t* list_codes;
842
+ const IDType* list_ids;
868
843
  size_t list_size;
869
844
 
870
- IVFPQScannerT (const IndexIVFPQ & ivfpq, const IVFSearchParameters *params):
871
- QueryTables (ivfpq, params)
872
- {
845
+ IVFPQScannerT(const IndexIVFPQ& ivfpq, const IVFSearchParameters* params)
846
+ : QueryTables(ivfpq, params) {
873
847
  assert(METRIC_TYPE == metric_type);
874
848
  }
875
849
 
876
850
  float dis0;
877
851
 
878
- void init_list (idx_t list_no, float coarse_dis,
879
- int mode) {
852
+ void init_list(idx_t list_no, float coarse_dis, int mode) {
880
853
  this->key = list_no;
881
854
  this->coarse_dis = coarse_dis;
882
855
 
883
856
  if (mode == 2) {
884
- dis0 = precompute_list_tables ();
857
+ dis0 = precompute_list_tables();
885
858
  } else if (mode == 1) {
886
- dis0 = precompute_list_table_pointers ();
859
+ dis0 = precompute_list_table_pointers();
887
860
  }
888
861
  }
889
862
 
@@ -892,15 +865,16 @@ struct IVFPQScannerT: QueryTables {
892
865
  *****************************************************/
893
866
 
894
867
  /// version of the scan where we use precomputed tables
895
- template<class SearchResultType>
896
- void scan_list_with_table (size_t ncode, const uint8_t *codes,
897
- SearchResultType & res) const
898
- {
868
+ template <class SearchResultType>
869
+ void scan_list_with_table(
870
+ size_t ncode,
871
+ const uint8_t* codes,
872
+ SearchResultType& res) const {
899
873
  for (size_t j = 0; j < ncode; j++) {
900
874
  PQDecoder decoder(codes, pq.nbits);
901
875
  codes += pq.code_size;
902
876
  float dis = dis0;
903
- const float *tab = sim_table;
877
+ const float* tab = sim_table;
904
878
 
905
879
  for (size_t m = 0; m < pq.M; m++) {
906
880
  dis += tab[decoder.decode()];
@@ -911,43 +885,43 @@ struct IVFPQScannerT: QueryTables {
911
885
  }
912
886
  }
913
887
 
914
-
915
888
  /// tables are not precomputed, but pointers are provided to the
916
889
  /// relevant X_c|x_r tables
917
- template<class SearchResultType>
918
- void scan_list_with_pointer (size_t ncode, const uint8_t *codes,
919
- SearchResultType & res) const
920
- {
890
+ template <class SearchResultType>
891
+ void scan_list_with_pointer(
892
+ size_t ncode,
893
+ const uint8_t* codes,
894
+ SearchResultType& res) const {
921
895
  for (size_t j = 0; j < ncode; j++) {
922
896
  PQDecoder decoder(codes, pq.nbits);
923
897
  codes += pq.code_size;
924
898
 
925
899
  float dis = dis0;
926
- const float *tab = sim_table_2;
900
+ const float* tab = sim_table_2;
927
901
 
928
902
  for (size_t m = 0; m < pq.M; m++) {
929
903
  int ci = decoder.decode();
930
- dis += sim_table_ptrs [m][ci] - 2 * tab [ci];
904
+ dis += sim_table_ptrs[m][ci] - 2 * tab[ci];
931
905
  tab += pq.ksub;
932
906
  }
933
- res.add (j, dis);
907
+ res.add(j, dis);
934
908
  }
935
909
  }
936
910
 
937
-
938
911
  /// nothing is precomputed: access residuals on-the-fly
939
- template<class SearchResultType>
940
- void scan_on_the_fly_dist (size_t ncode, const uint8_t *codes,
941
- SearchResultType &res) const
942
- {
943
- const float *dvec;
912
+ template <class SearchResultType>
913
+ void scan_on_the_fly_dist(
914
+ size_t ncode,
915
+ const uint8_t* codes,
916
+ SearchResultType& res) const {
917
+ const float* dvec;
944
918
  float dis0 = 0;
945
919
  if (by_residual) {
946
920
  if (METRIC_TYPE == METRIC_INNER_PRODUCT) {
947
- ivfpq.quantizer->reconstruct (key, residual_vec);
948
- dis0 = fvec_inner_product (residual_vec, qi, d);
921
+ ivfpq.quantizer->reconstruct(key, residual_vec);
922
+ dis0 = fvec_inner_product(residual_vec, qi, d);
949
923
  } else {
950
- ivfpq.quantizer->compute_residual (qi, residual_vec, key);
924
+ ivfpq.quantizer->compute_residual(qi, residual_vec, key);
951
925
  }
952
926
  dvec = residual_vec;
953
927
  } else {
@@ -956,17 +930,16 @@ struct IVFPQScannerT: QueryTables {
956
930
  }
957
931
 
958
932
  for (size_t j = 0; j < ncode; j++) {
959
-
960
- pq.decode (codes, decoded_vec);
933
+ pq.decode(codes, decoded_vec);
961
934
  codes += pq.code_size;
962
935
 
963
936
  float dis;
964
937
  if (METRIC_TYPE == METRIC_INNER_PRODUCT) {
965
- dis = dis0 + fvec_inner_product (decoded_vec, qi, d);
938
+ dis = dis0 + fvec_inner_product(decoded_vec, qi, d);
966
939
  } else {
967
- dis = fvec_L2sqr (decoded_vec, dvec, d);
940
+ dis = fvec_L2sqr(decoded_vec, dvec, d);
968
941
  }
969
- res.add (j, dis);
942
+ res.add(j, dis);
970
943
  }
971
944
  }
972
945
 
@@ -975,110 +948,99 @@ struct IVFPQScannerT: QueryTables {
975
948
  *****************************************************/
976
949
 
977
950
  template <class HammingComputer, class SearchResultType>
978
- void scan_list_polysemous_hc (
979
- size_t ncode, const uint8_t *codes,
980
- SearchResultType & res) const
981
- {
951
+ void scan_list_polysemous_hc(
952
+ size_t ncode,
953
+ const uint8_t* codes,
954
+ SearchResultType& res) const {
982
955
  int ht = ivfpq.polysemous_ht;
983
956
  size_t n_hamming_pass = 0, nup = 0;
984
957
 
985
958
  int code_size = pq.code_size;
986
959
 
987
- HammingComputer hc (q_code.data(), code_size);
960
+ HammingComputer hc(q_code.data(), code_size);
988
961
 
989
962
  for (size_t j = 0; j < ncode; j++) {
990
- const uint8_t *b_code = codes;
991
- int hd = hc.hamming (b_code);
963
+ const uint8_t* b_code = codes;
964
+ int hd = hc.hamming(b_code);
992
965
  if (hd < ht) {
993
- n_hamming_pass ++;
966
+ n_hamming_pass++;
994
967
  PQDecoder decoder(codes, pq.nbits);
995
968
 
996
969
  float dis = dis0;
997
- const float *tab = sim_table;
970
+ const float* tab = sim_table;
998
971
 
999
972
  for (size_t m = 0; m < pq.M; m++) {
1000
973
  dis += tab[decoder.decode()];
1001
974
  tab += pq.ksub;
1002
975
  }
1003
976
 
1004
- res.add (j, dis);
977
+ res.add(j, dis);
1005
978
  }
1006
979
  codes += code_size;
1007
980
  }
1008
981
  #pragma omp critical
1009
- {
1010
- indexIVFPQ_stats.n_hamming_pass += n_hamming_pass;
1011
- }
982
+ { indexIVFPQ_stats.n_hamming_pass += n_hamming_pass; }
1012
983
  }
1013
984
 
1014
- template<class SearchResultType>
1015
- void scan_list_polysemous (
1016
- size_t ncode, const uint8_t *codes,
1017
- SearchResultType &res) const
1018
- {
985
+ template <class SearchResultType>
986
+ void scan_list_polysemous(
987
+ size_t ncode,
988
+ const uint8_t* codes,
989
+ SearchResultType& res) const {
1019
990
  switch (pq.code_size) {
1020
991
  #define HANDLE_CODE_SIZE(cs) \
1021
- case cs: \
1022
- scan_list_polysemous_hc \
1023
- <HammingComputer ## cs, SearchResultType> \
1024
- (ncode, codes, res); \
1025
- break
1026
- HANDLE_CODE_SIZE(4);
1027
- HANDLE_CODE_SIZE(8);
1028
- HANDLE_CODE_SIZE(16);
1029
- HANDLE_CODE_SIZE(20);
1030
- HANDLE_CODE_SIZE(32);
1031
- HANDLE_CODE_SIZE(64);
992
+ case cs: \
993
+ scan_list_polysemous_hc<HammingComputer##cs, SearchResultType>( \
994
+ ncode, codes, res); \
995
+ break
996
+ HANDLE_CODE_SIZE(4);
997
+ HANDLE_CODE_SIZE(8);
998
+ HANDLE_CODE_SIZE(16);
999
+ HANDLE_CODE_SIZE(20);
1000
+ HANDLE_CODE_SIZE(32);
1001
+ HANDLE_CODE_SIZE(64);
1032
1002
  #undef HANDLE_CODE_SIZE
1033
- default:
1034
- if (pq.code_size % 8 == 0)
1035
- scan_list_polysemous_hc
1036
- <HammingComputerM8, SearchResultType>
1037
- (ncode, codes, res);
1038
- else
1039
- scan_list_polysemous_hc
1040
- <HammingComputerM4, SearchResultType>
1041
- (ncode, codes, res);
1042
- break;
1003
+ default:
1004
+ scan_list_polysemous_hc<
1005
+ HammingComputerDefault,
1006
+ SearchResultType>(ncode, codes, res);
1007
+ break;
1043
1008
  }
1044
1009
  }
1045
-
1046
1010
  };
1047
1011
 
1048
-
1049
1012
  /* We put as many parameters as possible in template. Hopefully the
1050
1013
  * gain in runtime is worth the code bloat. C is the comparator < or
1051
1014
  * >, it is directly related to METRIC_TYPE. precompute_mode is how
1052
1015
  * much we precompute (2 = precompute distance tables, 1 = precompute
1053
1016
  * pointers to distances, 0 = compute distances one by one).
1054
1017
  * Currently only 2 is supported */
1055
- template<MetricType METRIC_TYPE, class C, class PQDecoder>
1056
- struct IVFPQScanner:
1057
- IVFPQScannerT<Index::idx_t, METRIC_TYPE, PQDecoder>,
1058
- InvertedListScanner
1059
- {
1060
- bool store_pairs;
1018
+ template <MetricType METRIC_TYPE, class C, class PQDecoder>
1019
+ struct IVFPQScanner : IVFPQScannerT<Index::idx_t, METRIC_TYPE, PQDecoder>,
1020
+ InvertedListScanner {
1061
1021
  int precompute_mode;
1062
1022
 
1063
- IVFPQScanner(const IndexIVFPQ & ivfpq, bool store_pairs,
1064
- int precompute_mode):
1065
- IVFPQScannerT<Index::idx_t, METRIC_TYPE, PQDecoder>(ivfpq, nullptr),
1066
- store_pairs(store_pairs), precompute_mode(precompute_mode)
1067
- {
1023
+ IVFPQScanner(const IndexIVFPQ& ivfpq, bool store_pairs, int precompute_mode)
1024
+ : IVFPQScannerT<Index::idx_t, METRIC_TYPE, PQDecoder>(
1025
+ ivfpq,
1026
+ nullptr),
1027
+ precompute_mode(precompute_mode) {
1028
+ this->store_pairs = store_pairs;
1068
1029
  }
1069
1030
 
1070
- void set_query (const float *query) override {
1071
- this->init_query (query);
1031
+ void set_query(const float* query) override {
1032
+ this->init_query(query);
1072
1033
  }
1073
1034
 
1074
- void set_list (idx_t list_no, float coarse_dis) override {
1075
- this->init_list (list_no, coarse_dis, precompute_mode);
1035
+ void set_list(idx_t list_no, float coarse_dis) override {
1036
+ this->list_no = list_no;
1037
+ this->init_list(list_no, coarse_dis, precompute_mode);
1076
1038
  }
1077
1039
 
1078
- float distance_to_code (const uint8_t *code) const override {
1040
+ float distance_to_code(const uint8_t* code) const override {
1079
1041
  assert(precompute_mode == 2);
1080
1042
  float dis = this->dis0;
1081
- const float *tab = this->sim_table;
1043
+ const float* tab = this->sim_table;
1082
1044
  PQDecoder decoder(code, this->pq.nbits);
1083
1045
 
1084
1046
  for (size_t m = 0; m < this->pq.M; m++) {
@@ -1088,112 +1050,100 @@ struct IVFPQScanner:
1088
1050
  return dis;
1089
1051
  }
1090
1052
 
1091
- size_t scan_codes (size_t ncode,
1092
- const uint8_t *codes,
1093
- const idx_t *ids,
1094
- float *heap_sim, idx_t *heap_ids,
1095
- size_t k) const override
1096
- {
1053
+ size_t scan_codes(
1054
+ size_t ncode,
1055
+ const uint8_t* codes,
1056
+ const idx_t* ids,
1057
+ float* heap_sim,
1058
+ idx_t* heap_ids,
1059
+ size_t k) const override {
1097
1060
  KnnSearchResults<C> res = {
1098
- /* key */ this->key,
1099
- /* ids */ this->store_pairs ? nullptr : ids,
1100
- /* k */ k,
1101
- /* heap_sim */ heap_sim,
1102
- /* heap_ids */ heap_ids,
1103
- /* nup */ 0
1104
- };
1061
+ /* key */ this->key,
1062
+ /* ids */ this->store_pairs ? nullptr : ids,
1063
+ /* k */ k,
1064
+ /* heap_sim */ heap_sim,
1065
+ /* heap_ids */ heap_ids,
1066
+ /* nup */ 0};
1105
1067
 
1106
1068
  if (this->polysemous_ht > 0) {
1107
1069
  assert(precompute_mode == 2);
1108
- this->scan_list_polysemous (ncode, codes, res);
1070
+ this->scan_list_polysemous(ncode, codes, res);
1109
1071
  } else if (precompute_mode == 2) {
1110
- this->scan_list_with_table (ncode, codes, res);
1072
+ this->scan_list_with_table(ncode, codes, res);
1111
1073
  } else if (precompute_mode == 1) {
1112
- this->scan_list_with_pointer (ncode, codes, res);
1074
+ this->scan_list_with_pointer(ncode, codes, res);
1113
1075
  } else if (precompute_mode == 0) {
1114
- this->scan_on_the_fly_dist (ncode, codes, res);
1076
+ this->scan_on_the_fly_dist(ncode, codes, res);
1115
1077
  } else {
1116
1078
  FAISS_THROW_MSG("bad precomp mode");
1117
1079
  }
1118
1080
  return res.nup;
1119
1081
  }
1120
1082
 
1121
- void scan_codes_range (size_t ncode,
1122
- const uint8_t *codes,
1123
- const idx_t *ids,
1124
- float radius,
1125
- RangeQueryResult & rres) const override
1126
- {
1083
+ void scan_codes_range(
1084
+ size_t ncode,
1085
+ const uint8_t* codes,
1086
+ const idx_t* ids,
1087
+ float radius,
1088
+ RangeQueryResult& rres) const override {
1127
1089
  RangeSearchResults<C> res = {
1128
- /* key */ this->key,
1129
- /* ids */ this->store_pairs ? nullptr : ids,
1130
- /* radius */ radius,
1131
- /* rres */ rres
1132
- };
1090
+ /* key */ this->key,
1091
+ /* ids */ this->store_pairs ? nullptr : ids,
1092
+ /* radius */ radius,
1093
+ /* rres */ rres};
1133
1094
 
1134
1095
  if (this->polysemous_ht > 0) {
1135
1096
  assert(precompute_mode == 2);
1136
- this->scan_list_polysemous (ncode, codes, res);
1097
+ this->scan_list_polysemous(ncode, codes, res);
1137
1098
  } else if (precompute_mode == 2) {
1138
- this->scan_list_with_table (ncode, codes, res);
1099
+ this->scan_list_with_table(ncode, codes, res);
1139
1100
  } else if (precompute_mode == 1) {
1140
- this->scan_list_with_pointer (ncode, codes, res);
1101
+ this->scan_list_with_pointer(ncode, codes, res);
1141
1102
  } else if (precompute_mode == 0) {
1142
- this->scan_on_the_fly_dist (ncode, codes, res);
1103
+ this->scan_on_the_fly_dist(ncode, codes, res);
1143
1104
  } else {
1144
1105
  FAISS_THROW_MSG("bad precomp mode");
1145
1106
  }
1146
-
1147
1107
  }
1148
1108
  };
1149
1109
 
1150
- template<class PQDecoder>
1151
- InvertedListScanner *get_InvertedListScanner1 (const IndexIVFPQ &index,
1152
- bool store_pairs)
1153
- {
1154
-
1155
- if (index.metric_type == METRIC_INNER_PRODUCT) {
1156
- return new IVFPQScanner
1157
- <METRIC_INNER_PRODUCT, CMin<float, idx_t>, PQDecoder>
1158
- (index, store_pairs, 2);
1110
+ template <class PQDecoder>
1111
+ InvertedListScanner* get_InvertedListScanner1(
1112
+ const IndexIVFPQ& index,
1113
+ bool store_pairs) {
1114
+ if (index.metric_type == METRIC_INNER_PRODUCT) {
1115
+ return new IVFPQScanner<
1116
+ METRIC_INNER_PRODUCT,
1117
+ CMin<float, idx_t>,
1118
+ PQDecoder>(index, store_pairs, 2);
1159
1119
  } else if (index.metric_type == METRIC_L2) {
1160
- return new IVFPQScanner
1161
- <METRIC_L2, CMax<float, idx_t>, PQDecoder>
1162
- (index, store_pairs, 2);
1120
+ return new IVFPQScanner<METRIC_L2, CMax<float, idx_t>, PQDecoder>(
1121
+ index, store_pairs, 2);
1163
1122
  }
1164
1123
  return nullptr;
1165
1124
  }
1166
1125
 
1167
-
1168
1126
  } // anonymous namespace
1169
1127
 
1170
- InvertedListScanner *
1171
- IndexIVFPQ::get_InvertedListScanner (bool store_pairs) const
1172
- {
1173
-
1128
+ InvertedListScanner* IndexIVFPQ::get_InvertedListScanner(
1129
+ bool store_pairs) const {
1174
1130
  if (pq.nbits == 8) {
1175
- return get_InvertedListScanner1<PQDecoder8> (*this, store_pairs);
1131
+ return get_InvertedListScanner1<PQDecoder8>(*this, store_pairs);
1176
1132
  } else if (pq.nbits == 16) {
1177
- return get_InvertedListScanner1<PQDecoder16> (*this, store_pairs);
1133
+ return get_InvertedListScanner1<PQDecoder16>(*this, store_pairs);
1178
1134
  } else {
1179
- return get_InvertedListScanner1<PQDecoderGeneric> (*this, store_pairs);
1135
+ return get_InvertedListScanner1<PQDecoderGeneric>(*this, store_pairs);
1180
1136
  }
1181
1137
  return nullptr;
1182
-
1183
1138
  }
1184
1139
 
1185
-
1186
-
1187
1140
  IndexIVFPQStats indexIVFPQ_stats;
1188
1141
 
1189
- void IndexIVFPQStats::reset () {
1190
- memset (this, 0, sizeof (*this));
1142
+ void IndexIVFPQStats::reset() {
1143
+ memset(this, 0, sizeof(*this));
1191
1144
  }
1192
1145
 
1193
-
1194
-
1195
- IndexIVFPQ::IndexIVFPQ ()
1196
- {
1146
+ IndexIVFPQ::IndexIVFPQ() {
1197
1147
  // initialize some runtime values
1198
1148
  use_precomputed_table = 0;
1199
1149
  scan_table_threshold = 0;
@@ -1202,43 +1152,40 @@ IndexIVFPQ::IndexIVFPQ ()
1202
1152
  polysemous_training = nullptr;
1203
1153
  }
1204
1154
 
1205
-
1206
1155
  struct CodeCmp {
1207
- const uint8_t *tab;
1156
+ const uint8_t* tab;
1208
1157
  size_t code_size;
1209
- bool operator () (int a, int b) const {
1210
- return cmp (a, b) > 0;
1158
+ bool operator()(int a, int b) const {
1159
+ return cmp(a, b) > 0;
1211
1160
  }
1212
- int cmp (int a, int b) const {
1213
- return memcmp (tab + a * code_size, tab + b * code_size,
1214
- code_size);
1161
+ int cmp(int a, int b) const {
1162
+ return memcmp(tab + a * code_size, tab + b * code_size, code_size);
1215
1163
  }
1216
1164
  };
1217
1165
 
1218
-
1219
- size_t IndexIVFPQ::find_duplicates (idx_t *dup_ids, size_t *lims) const
1220
- {
1166
+ size_t IndexIVFPQ::find_duplicates(idx_t* dup_ids, size_t* lims) const {
1221
1167
  size_t ngroup = 0;
1222
1168
  lims[0] = 0;
1223
1169
  for (size_t list_no = 0; list_no < nlist; list_no++) {
1224
- size_t n = invlists->list_size (list_no);
1225
- std::vector<int> ord (n);
1226
- for (int i = 0; i < n; i++) ord[i] = i;
1227
- InvertedLists::ScopedCodes codes (invlists, list_no);
1228
- CodeCmp cs = { codes.get(), code_size };
1229
- std::sort (ord.begin(), ord.end(), cs);
1230
-
1231
- InvertedLists::ScopedIds list_ids (invlists, list_no);
1232
- int prev = -1; // all elements from prev to i-1 are equal
1170
+ size_t n = invlists->list_size(list_no);
1171
+ std::vector<int> ord(n);
1172
+ for (int i = 0; i < n; i++)
1173
+ ord[i] = i;
1174
+ InvertedLists::ScopedCodes codes(invlists, list_no);
1175
+ CodeCmp cs = {codes.get(), code_size};
1176
+ std::sort(ord.begin(), ord.end(), cs);
1177
+
1178
+ InvertedLists::ScopedIds list_ids(invlists, list_no);
1179
+ int prev = -1; // all elements from prev to i-1 are equal
1233
1180
  for (int i = 0; i < n; i++) {
1234
- if (prev >= 0 && cs.cmp (ord [prev], ord [i]) == 0) {
1181
+ if (prev >= 0 && cs.cmp(ord[prev], ord[i]) == 0) {
1235
1182
  // same as previous => remember
1236
1183
  if (prev + 1 == i) { // start new group
1237
1184
  ngroup++;
1238
1185
  lims[ngroup] = lims[ngroup - 1];
1239
- dup_ids [lims [ngroup]++] = list_ids [ord [prev]];
1186
+ dup_ids[lims[ngroup]++] = list_ids[ord[prev]];
1240
1187
  }
1241
- dup_ids [lims [ngroup]++] = list_ids [ord [i]];
1188
+ dup_ids[lims[ngroup]++] = list_ids[ord[i]];
1242
1189
  } else { // not same as previous.
1243
1190
  prev = i;
1244
1191
  }
@@ -1247,9 +1194,4 @@ size_t IndexIVFPQ::find_duplicates (idx_t *dup_ids, size_t *lims) const
1247
1194
  return ngroup;
1248
1195
  }
1249
1196
 
1250
-
1251
-
1252
-
1253
-
1254
-
1255
1197
  } // namespace faiss