faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -10,15 +10,15 @@
10
10
  #include <faiss/IndexPQ.h>
11
11
 
12
12
  #include <cinttypes>
13
+ #include <cmath>
13
14
  #include <cstddef>
14
- #include <cstring>
15
15
  #include <cstdio>
16
- #include <cmath>
16
+ #include <cstring>
17
17
 
18
18
  #include <algorithm>
19
19
 
20
- #include <faiss/impl/FaissAssert.h>
21
20
  #include <faiss/impl/AuxIndexStructures.h>
21
+ #include <faiss/impl/FaissAssert.h>
22
22
  #include <faiss/utils/hamming.h>
23
23
 
24
24
  namespace faiss {
@@ -27,19 +27,17 @@ namespace faiss {
27
27
  * IndexPQ implementation
28
28
  ********************************************************/
29
29
 
30
-
31
- IndexPQ::IndexPQ (int d, size_t M, size_t nbits, MetricType metric):
32
- Index(d, metric), pq(d, M, nbits)
33
- {
30
+ IndexPQ::IndexPQ(int d, size_t M, size_t nbits, MetricType metric)
31
+ : IndexFlatCodes(0, d, metric), pq(d, M, nbits) {
34
32
  is_trained = false;
35
33
  do_polysemous_training = false;
36
34
  polysemous_ht = nbits * M + 1;
37
35
  search_type = ST_PQ;
38
36
  encode_signs = false;
37
+ code_size = pq.code_size;
39
38
  }
40
39
 
41
- IndexPQ::IndexPQ ()
42
- {
40
+ IndexPQ::IndexPQ() {
43
41
  metric_type = METRIC_L2;
44
42
  is_trained = false;
45
43
  do_polysemous_training = false;
@@ -48,10 +46,8 @@ IndexPQ::IndexPQ ()
48
46
  encode_signs = false;
49
47
  }
50
48
 
51
-
52
- void IndexPQ::train (idx_t n, const float *x)
53
- {
54
- if (!do_polysemous_training) { // standard training
49
+ void IndexPQ::train(idx_t n, const float* x) {
50
+ if (!do_polysemous_training) { // standard training
55
51
  pq.train(n, x);
56
52
  } else {
57
53
  idx_t ntrain_perm = polysemous_training.ntrain_permutation;
@@ -59,92 +55,38 @@ void IndexPQ::train (idx_t n, const float *x)
59
55
  if (ntrain_perm > n / 4)
60
56
  ntrain_perm = n / 4;
61
57
  if (verbose) {
62
- printf ("PQ training on %" PRId64 " points, remains %" PRId64 " points: "
63
- "training polysemous on %s\n",
64
- n - ntrain_perm, ntrain_perm,
65
- ntrain_perm == 0 ? "centroids" : "these");
58
+ printf("PQ training on %" PRId64 " points, remains %" PRId64
59
+ " points: "
60
+ "training polysemous on %s\n",
61
+ n - ntrain_perm,
62
+ ntrain_perm,
63
+ ntrain_perm == 0 ? "centroids" : "these");
66
64
  }
67
65
  pq.train(n - ntrain_perm, x);
68
66
 
69
- polysemous_training.optimize_pq_for_hamming (
70
- pq, ntrain_perm, x + (n - ntrain_perm) * d);
67
+ polysemous_training.optimize_pq_for_hamming(
68
+ pq, ntrain_perm, x + (n - ntrain_perm) * d);
71
69
  }
72
70
  is_trained = true;
73
71
  }
74
72
 
75
-
76
- void IndexPQ::add (idx_t n, const float *x)
77
- {
78
- FAISS_THROW_IF_NOT (is_trained);
79
- codes.resize ((n + ntotal) * pq.code_size);
80
- pq.compute_codes (x, &codes[ntotal * pq.code_size], n);
81
- ntotal += n;
82
- }
83
-
84
-
85
- size_t IndexPQ::remove_ids (const IDSelector & sel)
86
- {
87
- idx_t j = 0;
88
- for (idx_t i = 0; i < ntotal; i++) {
89
- if (sel.is_member (i)) {
90
- // should be removed
91
- } else {
92
- if (i > j) {
93
- memmove (&codes[pq.code_size * j], &codes[pq.code_size * i], pq.code_size);
94
- }
95
- j++;
96
- }
97
- }
98
- size_t nremove = ntotal - j;
99
- if (nremove > 0) {
100
- ntotal = j;
101
- codes.resize (ntotal * pq.code_size);
102
- }
103
- return nremove;
104
- }
105
-
106
-
107
- void IndexPQ::reset()
108
- {
109
- codes.clear();
110
- ntotal = 0;
111
- }
112
-
113
- void IndexPQ::reconstruct_n (idx_t i0, idx_t ni, float *recons) const
114
- {
115
- FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
116
- for (idx_t i = 0; i < ni; i++) {
117
- const uint8_t * code = &codes[(i0 + i) * pq.code_size];
118
- pq.decode (code, recons + i * d);
119
- }
120
- }
121
-
122
-
123
- void IndexPQ::reconstruct (idx_t key, float * recons) const
124
- {
125
- FAISS_THROW_IF_NOT (key >= 0 && key < ntotal);
126
- pq.decode (&codes[key * pq.code_size], recons);
127
- }
128
-
129
-
130
73
  namespace {
131
74
 
132
- template<class PQDecoder>
133
- struct PQDistanceComputer: DistanceComputer {
75
+ template <class PQDecoder>
76
+ struct PQDistanceComputer : DistanceComputer {
134
77
  size_t d;
135
78
  MetricType metric;
136
79
  Index::idx_t nb;
137
- const uint8_t *codes;
80
+ const uint8_t* codes;
138
81
  size_t code_size;
139
- const ProductQuantizer & pq;
140
- const float *sdc;
82
+ const ProductQuantizer& pq;
83
+ const float* sdc;
141
84
  std::vector<float> precomputed_table;
142
85
  size_t ndis;
143
86
 
144
- float operator () (idx_t i) override
145
- {
146
- const uint8_t *code = codes + i * code_size;
147
- const float *dt = precomputed_table.data();
87
+ float operator()(idx_t i) override {
88
+ const uint8_t* code = codes + i * code_size;
89
+ const float* dt = precomputed_table.data();
148
90
  PQDecoder decoder(code, pq.nbits);
149
91
  float accu = 0;
150
92
  for (int j = 0; j < pq.M; j++) {
@@ -155,13 +97,12 @@ struct PQDistanceComputer: DistanceComputer {
155
97
  return accu;
156
98
  }
157
99
 
158
- float symmetric_dis(idx_t i, idx_t j) override
159
- {
100
+ float symmetric_dis(idx_t i, idx_t j) override {
160
101
  FAISS_THROW_IF_NOT(sdc);
161
- const float * sdci = sdc;
102
+ const float* sdci = sdc;
162
103
  float accu = 0;
163
- PQDecoder codei (codes + i * code_size, pq.nbits);
164
- PQDecoder codej (codes + j * code_size, pq.nbits);
104
+ PQDecoder codei(codes + i * code_size, pq.nbits);
105
+ PQDecoder codej(codes + j * code_size, pq.nbits);
165
106
 
166
107
  for (int l = 0; l < pq.M; l++) {
167
108
  accu += sdci[codei.decode() + (codej.decode() << codei.nbits)];
@@ -171,8 +112,7 @@ struct PQDistanceComputer: DistanceComputer {
171
112
  return accu;
172
113
  }
173
114
 
174
- explicit PQDistanceComputer(const IndexPQ& storage)
175
- : pq(storage.pq) {
115
+ explicit PQDistanceComputer(const IndexPQ& storage) : pq(storage.pq) {
176
116
  precomputed_table.resize(pq.M * pq.ksub);
177
117
  nb = storage.ntotal;
178
118
  d = storage.d;
@@ -187,21 +127,18 @@ struct PQDistanceComputer: DistanceComputer {
187
127
  ndis = 0;
188
128
  }
189
129
 
190
- void set_query(const float *x) override {
130
+ void set_query(const float* x) override {
191
131
  if (metric == METRIC_L2) {
192
132
  pq.compute_distance_table(x, precomputed_table.data());
193
133
  } else {
194
134
  pq.compute_inner_prod_table(x, precomputed_table.data());
195
135
  }
196
-
197
136
  }
198
137
  };
199
138
 
139
+ } // namespace
200
140
 
201
- } // namespace
202
-
203
-
204
- DistanceComputer * IndexPQ::get_distance_computer() const {
141
+ DistanceComputer* IndexPQ::get_distance_computer() const {
205
142
  if (pq.nbits == 8) {
206
143
  return new PQDistanceComputer<PQDecoder8>(*this);
207
144
  } else if (pq.nbits == 16) {
@@ -211,142 +148,142 @@ DistanceComputer * IndexPQ::get_distance_computer() const {
211
148
  }
212
149
  }
213
150
 
214
-
215
151
  /*****************************************
216
152
  * IndexPQ polysemous search routines
217
153
  ******************************************/
218
154
 
155
+ void IndexPQ::search(
156
+ idx_t n,
157
+ const float* x,
158
+ idx_t k,
159
+ float* distances,
160
+ idx_t* labels) const {
161
+ FAISS_THROW_IF_NOT(k > 0);
219
162
 
220
-
221
-
222
-
223
- void IndexPQ::search (idx_t n, const float *x, idx_t k,
224
- float *distances, idx_t *labels) const
225
- {
226
- FAISS_THROW_IF_NOT (is_trained);
227
- if (search_type == ST_PQ) { // Simple PQ search
163
+ FAISS_THROW_IF_NOT(is_trained);
164
+ if (search_type == ST_PQ) { // Simple PQ search
228
165
 
229
166
  if (metric_type == METRIC_L2) {
230
167
  float_maxheap_array_t res = {
231
- size_t(n), size_t(k), labels, distances };
232
- pq.search (x, n, codes.data(), ntotal, &res, true);
168
+ size_t(n), size_t(k), labels, distances};
169
+ pq.search(x, n, codes.data(), ntotal, &res, true);
233
170
  } else {
234
171
  float_minheap_array_t res = {
235
- size_t(n), size_t(k), labels, distances };
236
- pq.search_ip (x, n, codes.data(), ntotal, &res, true);
172
+ size_t(n), size_t(k), labels, distances};
173
+ pq.search_ip(x, n, codes.data(), ntotal, &res, true);
237
174
  }
238
175
  indexPQ_stats.nq += n;
239
176
  indexPQ_stats.ncode += n * ntotal;
240
177
 
241
- } else if (search_type == ST_polysemous ||
242
- search_type == ST_polysemous_generalize) {
243
-
244
- FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
178
+ } else if (
179
+ search_type == ST_polysemous ||
180
+ search_type == ST_polysemous_generalize) {
181
+ FAISS_THROW_IF_NOT(metric_type == METRIC_L2);
245
182
 
246
- search_core_polysemous (n, x, k, distances, labels);
183
+ search_core_polysemous(n, x, k, distances, labels);
247
184
 
248
185
  } else { // code-to-code distances
249
186
 
250
- uint8_t * q_codes = new uint8_t [n * pq.code_size];
251
- ScopeDeleter<uint8_t> del (q_codes);
252
-
187
+ uint8_t* q_codes = new uint8_t[n * pq.code_size];
188
+ ScopeDeleter<uint8_t> del(q_codes);
253
189
 
254
190
  if (!encode_signs) {
255
- pq.compute_codes (x, q_codes, n);
191
+ pq.compute_codes(x, q_codes, n);
256
192
  } else {
257
- FAISS_THROW_IF_NOT (d == pq.nbits * pq.M);
258
- memset (q_codes, 0, n * pq.code_size);
193
+ FAISS_THROW_IF_NOT(d == pq.nbits * pq.M);
194
+ memset(q_codes, 0, n * pq.code_size);
259
195
  for (size_t i = 0; i < n; i++) {
260
- const float *xi = x + i * d;
261
- uint8_t *code = q_codes + i * pq.code_size;
196
+ const float* xi = x + i * d;
197
+ uint8_t* code = q_codes + i * pq.code_size;
262
198
  for (int j = 0; j < d; j++)
263
- if (xi[j] > 0) code [j>>3] |= 1 << (j & 7);
199
+ if (xi[j] > 0)
200
+ code[j >> 3] |= 1 << (j & 7);
264
201
  }
265
202
  }
266
203
 
267
- if (search_type == ST_SDC) {
268
-
204
+ if (search_type == ST_SDC) {
269
205
  float_maxheap_array_t res = {
270
- size_t(n), size_t(k), labels, distances};
206
+ size_t(n), size_t(k), labels, distances};
271
207
 
272
- pq.search_sdc (q_codes, n, codes.data(), ntotal, &res, true);
208
+ pq.search_sdc(q_codes, n, codes.data(), ntotal, &res, true);
273
209
 
274
210
  } else {
275
- int * idistances = new int [n * k];
276
- ScopeDeleter<int> del (idistances);
211
+ int* idistances = new int[n * k];
212
+ ScopeDeleter<int> del(idistances);
277
213
 
278
214
  int_maxheap_array_t res = {
279
- size_t (n), size_t (k), labels, idistances};
215
+ size_t(n), size_t(k), labels, idistances};
280
216
 
281
217
  if (search_type == ST_HE) {
282
-
283
- hammings_knn_hc (&res, q_codes, codes.data(),
284
- ntotal, pq.code_size, true);
218
+ hammings_knn_hc(
219
+ &res,
220
+ q_codes,
221
+ codes.data(),
222
+ ntotal,
223
+ pq.code_size,
224
+ true);
285
225
 
286
226
  } else if (search_type == ST_generalized_HE) {
287
-
288
- generalized_hammings_knn_hc (&res, q_codes, codes.data(),
289
- ntotal, pq.code_size, true);
227
+ generalized_hammings_knn_hc(
228
+ &res,
229
+ q_codes,
230
+ codes.data(),
231
+ ntotal,
232
+ pq.code_size,
233
+ true);
290
234
  }
291
235
 
292
236
  // convert distances to floats
293
237
  for (int i = 0; i < k * n; i++)
294
238
  distances[i] = idistances[i];
295
-
296
239
  }
297
240
 
298
-
299
241
  indexPQ_stats.nq += n;
300
242
  indexPQ_stats.ncode += n * ntotal;
301
243
  }
302
244
  }
303
245
 
304
-
305
-
306
-
307
-
308
- void IndexPQStats::reset()
309
- {
246
+ void IndexPQStats::reset() {
310
247
  nq = ncode = n_hamming_pass = 0;
311
248
  }
312
249
 
313
250
  IndexPQStats indexPQ_stats;
314
251
 
315
-
316
252
  template <class HammingComputer>
317
- static size_t polysemous_inner_loop (
318
- const IndexPQ & index,
319
- const float *dis_table_qi, const uint8_t *q_code,
320
- size_t k, float *heap_dis, int64_t *heap_ids)
321
- {
322
-
253
+ static size_t polysemous_inner_loop(
254
+ const IndexPQ& index,
255
+ const float* dis_table_qi,
256
+ const uint8_t* q_code,
257
+ size_t k,
258
+ float* heap_dis,
259
+ int64_t* heap_ids) {
323
260
  int M = index.pq.M;
324
261
  int code_size = index.pq.code_size;
325
262
  int ksub = index.pq.ksub;
326
263
  size_t ntotal = index.ntotal;
327
264
  int ht = index.polysemous_ht;
328
265
 
329
- const uint8_t *b_code = index.codes.data();
266
+ const uint8_t* b_code = index.codes.data();
330
267
 
331
268
  size_t n_pass_i = 0;
332
269
 
333
- HammingComputer hc (q_code, code_size);
270
+ HammingComputer hc(q_code, code_size);
334
271
 
335
272
  for (int64_t bi = 0; bi < ntotal; bi++) {
336
- int hd = hc.hamming (b_code);
273
+ int hd = hc.hamming(b_code);
337
274
 
338
275
  if (hd < ht) {
339
- n_pass_i ++;
276
+ n_pass_i++;
340
277
 
341
278
  float dis = 0;
342
- const float * dis_table = dis_table_qi;
279
+ const float* dis_table = dis_table_qi;
343
280
  for (int m = 0; m < M; m++) {
344
- dis += dis_table [b_code[m]];
281
+ dis += dis_table[b_code[m]];
345
282
  dis_table += ksub;
346
283
  }
347
284
 
348
285
  if (dis < heap_dis[0]) {
349
- maxheap_replace_top (k, heap_dis, heap_ids, dis, bi);
286
+ maxheap_replace_top(k, heap_dis, heap_ids, dis, bi);
350
287
  }
351
288
  }
352
289
  b_code += code_size;
@@ -354,201 +291,201 @@ static size_t polysemous_inner_loop (
354
291
  return n_pass_i;
355
292
  }
356
293
 
294
+ void IndexPQ::search_core_polysemous(
295
+ idx_t n,
296
+ const float* x,
297
+ idx_t k,
298
+ float* distances,
299
+ idx_t* labels) const {
300
+ FAISS_THROW_IF_NOT(k > 0);
357
301
 
358
- void IndexPQ::search_core_polysemous (idx_t n, const float *x, idx_t k,
359
- float *distances, idx_t *labels) const
360
- {
361
- FAISS_THROW_IF_NOT (pq.nbits == 8);
302
+ FAISS_THROW_IF_NOT(pq.nbits == 8);
362
303
 
363
304
  // PQ distance tables
364
- float * dis_tables = new float [n * pq.ksub * pq.M];
365
- ScopeDeleter<float> del (dis_tables);
366
- pq.compute_distance_tables (n, x, dis_tables);
305
+ float* dis_tables = new float[n * pq.ksub * pq.M];
306
+ ScopeDeleter<float> del(dis_tables);
307
+ pq.compute_distance_tables(n, x, dis_tables);
367
308
 
368
309
  // Hamming embedding queries
369
- uint8_t * q_codes = new uint8_t [n * pq.code_size];
370
- ScopeDeleter<uint8_t> del2 (q_codes);
310
+ uint8_t* q_codes = new uint8_t[n * pq.code_size];
311
+ ScopeDeleter<uint8_t> del2(q_codes);
371
312
 
372
313
  if (false) {
373
- pq.compute_codes (x, q_codes, n);
314
+ pq.compute_codes(x, q_codes, n);
374
315
  } else {
375
316
  #pragma omp parallel for
376
317
  for (idx_t qi = 0; qi < n; qi++) {
377
- pq.compute_code_from_distance_table
378
- (dis_tables + qi * pq.M * pq.ksub,
379
- q_codes + qi * pq.code_size);
318
+ pq.compute_code_from_distance_table(
319
+ dis_tables + qi * pq.M * pq.ksub,
320
+ q_codes + qi * pq.code_size);
380
321
  }
381
322
  }
382
323
 
383
324
  size_t n_pass = 0;
384
325
 
385
- #pragma omp parallel for reduction (+: n_pass)
326
+ #pragma omp parallel for reduction(+ : n_pass)
386
327
  for (idx_t qi = 0; qi < n; qi++) {
387
- const uint8_t * q_code = q_codes + qi * pq.code_size;
328
+ const uint8_t* q_code = q_codes + qi * pq.code_size;
388
329
 
389
- const float * dis_table_qi = dis_tables + qi * pq.M * pq.ksub;
330
+ const float* dis_table_qi = dis_tables + qi * pq.M * pq.ksub;
390
331
 
391
- int64_t * heap_ids = labels + qi * k;
392
- float *heap_dis = distances + qi * k;
393
- maxheap_heapify (k, heap_dis, heap_ids);
332
+ int64_t* heap_ids = labels + qi * k;
333
+ float* heap_dis = distances + qi * k;
334
+ maxheap_heapify(k, heap_dis, heap_ids);
394
335
 
395
336
  if (search_type == ST_polysemous) {
396
-
397
337
  switch (pq.code_size) {
398
- case 4:
399
- n_pass += polysemous_inner_loop<HammingComputer4>
400
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
401
- break;
402
- case 8:
403
- n_pass += polysemous_inner_loop<HammingComputer8>
404
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
405
- break;
406
- case 16:
407
- n_pass += polysemous_inner_loop<HammingComputer16>
408
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
409
- break;
410
- case 32:
411
- n_pass += polysemous_inner_loop<HammingComputer32>
412
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
413
- break;
414
- case 20:
415
- n_pass += polysemous_inner_loop<HammingComputer20>
416
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
417
- break;
418
- default:
419
- if (pq.code_size % 8 == 0) {
420
- n_pass += polysemous_inner_loop<HammingComputerM8>
421
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
422
- } else if (pq.code_size % 4 == 0) {
423
- n_pass += polysemous_inner_loop<HammingComputerM4>
424
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
425
- } else {
426
- FAISS_THROW_FMT(
427
- "code size %zd not supported for polysemous",
428
- pq.code_size);
429
- }
430
- break;
338
+ case 4:
339
+ n_pass += polysemous_inner_loop<HammingComputer4>(
340
+ *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
341
+ break;
342
+ case 8:
343
+ n_pass += polysemous_inner_loop<HammingComputer8>(
344
+ *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
345
+ break;
346
+ case 16:
347
+ n_pass += polysemous_inner_loop<HammingComputer16>(
348
+ *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
349
+ break;
350
+ case 32:
351
+ n_pass += polysemous_inner_loop<HammingComputer32>(
352
+ *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
353
+ break;
354
+ case 20:
355
+ n_pass += polysemous_inner_loop<HammingComputer20>(
356
+ *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
357
+ break;
358
+ default:
359
+ if (pq.code_size % 4 == 0) {
360
+ n_pass += polysemous_inner_loop<HammingComputerDefault>(
361
+ *this,
362
+ dis_table_qi,
363
+ q_code,
364
+ k,
365
+ heap_dis,
366
+ heap_ids);
367
+ } else {
368
+ FAISS_THROW_FMT(
369
+ "code size %zd not supported for polysemous",
370
+ pq.code_size);
371
+ }
372
+ break;
431
373
  }
432
374
  } else {
433
375
  switch (pq.code_size) {
434
- case 8:
435
- n_pass += polysemous_inner_loop<GenHammingComputer8>
436
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
437
- break;
438
- case 16:
439
- n_pass += polysemous_inner_loop<GenHammingComputer16>
440
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
441
- break;
442
- case 32:
443
- n_pass += polysemous_inner_loop<GenHammingComputer32>
444
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
445
- break;
446
- default:
447
- if (pq.code_size % 8 == 0) {
448
- n_pass += polysemous_inner_loop<GenHammingComputerM8>
449
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
450
- } else {
451
- FAISS_THROW_FMT(
452
- "code size %zd not supported for polysemous",
453
- pq.code_size);
454
- }
455
- break;
376
+ case 8:
377
+ n_pass += polysemous_inner_loop<GenHammingComputer8>(
378
+ *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
379
+ break;
380
+ case 16:
381
+ n_pass += polysemous_inner_loop<GenHammingComputer16>(
382
+ *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
383
+ break;
384
+ case 32:
385
+ n_pass += polysemous_inner_loop<GenHammingComputer32>(
386
+ *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
387
+ break;
388
+ default:
389
+ if (pq.code_size % 8 == 0) {
390
+ n_pass += polysemous_inner_loop<GenHammingComputerM8>(
391
+ *this,
392
+ dis_table_qi,
393
+ q_code,
394
+ k,
395
+ heap_dis,
396
+ heap_ids);
397
+ } else {
398
+ FAISS_THROW_FMT(
399
+ "code size %zd not supported for polysemous",
400
+ pq.code_size);
401
+ }
402
+ break;
456
403
  }
457
404
  }
458
- maxheap_reorder (k, heap_dis, heap_ids);
405
+ maxheap_reorder(k, heap_dis, heap_ids);
459
406
  }
460
407
 
461
408
  indexPQ_stats.nq += n;
462
409
  indexPQ_stats.ncode += n * ntotal;
463
410
  indexPQ_stats.n_hamming_pass += n_pass;
464
-
465
-
466
411
  }
467
412
 
468
-
469
413
  /* The standalone codec interface (just remaps to the PQ functions) */
470
- size_t IndexPQ::sa_code_size () const
471
- {
472
- return pq.code_size;
473
- }
474
414
 
475
- void IndexPQ::sa_encode (idx_t n, const float *x, uint8_t *bytes) const
476
- {
477
- pq.compute_codes (x, bytes, n);
415
+ void IndexPQ::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
416
+ pq.compute_codes(x, bytes, n);
478
417
  }
479
418
 
480
- void IndexPQ::sa_decode (idx_t n, const uint8_t *bytes, float *x) const
481
- {
482
- pq.decode (bytes, x, n);
419
+ void IndexPQ::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
420
+ pq.decode(bytes, x, n);
483
421
  }
484
422
 
485
-
486
-
487
-
488
423
  /*****************************************
489
424
  * Stats of IndexPQ codes
490
425
  ******************************************/
491
426
 
427
+ void IndexPQ::hamming_distance_table(idx_t n, const float* x, int32_t* dis)
428
+ const {
429
+ uint8_t* q_codes = new uint8_t[n * pq.code_size];
430
+ ScopeDeleter<uint8_t> del(q_codes);
492
431
 
432
+ pq.compute_codes(x, q_codes, n);
493
433
 
494
-
495
- void IndexPQ::hamming_distance_table (idx_t n, const float *x,
496
- int32_t *dis) const
497
- {
498
- uint8_t * q_codes = new uint8_t [n * pq.code_size];
499
- ScopeDeleter<uint8_t> del (q_codes);
500
-
501
- pq.compute_codes (x, q_codes, n);
502
-
503
- hammings (q_codes, codes.data(), n, ntotal, pq.code_size, dis);
434
+ hammings(q_codes, codes.data(), n, ntotal, pq.code_size, dis);
504
435
  }
505
436
 
506
-
507
- void IndexPQ::hamming_distance_histogram (idx_t n, const float *x,
508
- idx_t nb, const float *xb,
509
- int64_t *hist)
510
- {
511
- FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
512
- FAISS_THROW_IF_NOT (pq.code_size % 8 == 0);
513
- FAISS_THROW_IF_NOT (pq.nbits == 8);
437
+ void IndexPQ::hamming_distance_histogram(
438
+ idx_t n,
439
+ const float* x,
440
+ idx_t nb,
441
+ const float* xb,
442
+ int64_t* hist) {
443
+ FAISS_THROW_IF_NOT(metric_type == METRIC_L2);
444
+ FAISS_THROW_IF_NOT(pq.code_size % 8 == 0);
445
+ FAISS_THROW_IF_NOT(pq.nbits == 8);
514
446
 
515
447
  // Hamming embedding queries
516
- uint8_t * q_codes = new uint8_t [n * pq.code_size];
517
- ScopeDeleter <uint8_t> del (q_codes);
518
- pq.compute_codes (x, q_codes, n);
448
+ uint8_t* q_codes = new uint8_t[n * pq.code_size];
449
+ ScopeDeleter<uint8_t> del(q_codes);
450
+ pq.compute_codes(x, q_codes, n);
519
451
 
520
- uint8_t * b_codes ;
521
- ScopeDeleter <uint8_t> del_b_codes;
452
+ uint8_t* b_codes;
453
+ ScopeDeleter<uint8_t> del_b_codes;
522
454
 
523
455
  if (xb) {
524
- b_codes = new uint8_t [nb * pq.code_size];
525
- del_b_codes.set (b_codes);
526
- pq.compute_codes (xb, b_codes, nb);
456
+ b_codes = new uint8_t[nb * pq.code_size];
457
+ del_b_codes.set(b_codes);
458
+ pq.compute_codes(xb, b_codes, nb);
527
459
  } else {
528
460
  nb = ntotal;
529
461
  b_codes = codes.data();
530
462
  }
531
463
  int nbits = pq.M * pq.nbits;
532
- memset (hist, 0, sizeof(*hist) * (nbits + 1));
464
+ memset(hist, 0, sizeof(*hist) * (nbits + 1));
533
465
  size_t bs = 256;
534
466
 
535
467
  #pragma omp parallel
536
468
  {
537
- std::vector<int64_t> histi (nbits + 1);
538
- hamdis_t *distances = new hamdis_t [nb * bs];
539
- ScopeDeleter<hamdis_t> del (distances);
469
+ std::vector<int64_t> histi(nbits + 1);
470
+ hamdis_t* distances = new hamdis_t[nb * bs];
471
+ ScopeDeleter<hamdis_t> del(distances);
540
472
  #pragma omp for
541
473
  for (idx_t q0 = 0; q0 < n; q0 += bs) {
542
474
  // printf ("dis stats: %zd/%zd\n", q0, n);
543
475
  size_t q1 = q0 + bs;
544
- if (q1 > n) q1 = n;
476
+ if (q1 > n)
477
+ q1 = n;
545
478
 
546
- hammings (q_codes + q0 * pq.code_size, b_codes,
547
- q1 - q0, nb,
548
- pq.code_size, distances);
479
+ hammings(
480
+ q_codes + q0 * pq.code_size,
481
+ b_codes,
482
+ q1 - q0,
483
+ nb,
484
+ pq.code_size,
485
+ distances);
549
486
 
550
487
  for (size_t i = 0; i < nb * (q1 - q0); i++)
551
- histi [distances [i]]++;
488
+ histi[distances[i]]++;
552
489
  }
553
490
  #pragma omp critical
554
491
  {
@@ -556,28 +493,8 @@ void IndexPQ::hamming_distance_histogram (idx_t n, const float *x,
556
493
  hist[i] += histi[i];
557
494
  }
558
495
  }
559
-
560
496
  }
561
497
 
562
-
563
-
564
-
565
-
566
-
567
-
568
-
569
-
570
-
571
-
572
-
573
-
574
-
575
-
576
-
577
-
578
-
579
-
580
-
581
498
  /*****************************************
582
499
  * MultiIndexQuantizer
583
500
  ******************************************/
@@ -586,90 +503,87 @@ namespace {
586
503
 
587
504
  template <typename T>
588
505
  struct PreSortedArray {
589
-
590
- const T * x;
506
+ const T* x;
591
507
  int N;
592
508
 
593
- explicit PreSortedArray (int N): N(N) {
594
- }
595
- void init (const T*x) {
509
+ explicit PreSortedArray(int N) : N(N) {}
510
+ void init(const T* x) {
596
511
  this->x = x;
597
512
  }
598
513
  // get smallest value
599
- T get_0 () {
514
+ T get_0() {
600
515
  return x[0];
601
516
  }
602
517
 
603
518
  // get delta between n-smallest and n-1 -smallest
604
- T get_diff (int n) {
519
+ T get_diff(int n) {
605
520
  return x[n] - x[n - 1];
606
521
  }
607
522
 
608
523
  // remap orders counted from smallest to indices in array
609
- int get_ord (int n) {
524
+ int get_ord(int n) {
610
525
  return n;
611
526
  }
612
-
613
527
  };
614
528
 
615
529
  template <typename T>
616
530
  struct ArgSort {
617
- const T * x;
618
- bool operator() (size_t i, size_t j) {
531
+ const T* x;
532
+ bool operator()(size_t i, size_t j) {
619
533
  return x[i] < x[j];
620
534
  }
621
535
  };
622
536
 
623
-
624
537
  /** Array that maintains a permutation of its elements so that the
625
538
  * array's elements are sorted
626
539
  */
627
540
  template <typename T>
628
541
  struct SortedArray {
629
- const T * x;
542
+ const T* x;
630
543
  int N;
631
544
  std::vector<int> perm;
632
545
 
633
- explicit SortedArray (int N) {
546
+ explicit SortedArray(int N) {
634
547
  this->N = N;
635
- perm.resize (N);
548
+ perm.resize(N);
636
549
  }
637
550
 
638
- void init (const T*x) {
551
+ void init(const T* x) {
639
552
  this->x = x;
640
553
  for (int n = 0; n < N; n++)
641
554
  perm[n] = n;
642
- ArgSort<T> cmp = {x };
643
- std::sort (perm.begin(), perm.end(), cmp);
555
+ ArgSort<T> cmp = {x};
556
+ std::sort(perm.begin(), perm.end(), cmp);
644
557
  }
645
558
 
646
559
  // get smallest value
647
- T get_0 () {
560
+ T get_0() {
648
561
  return x[perm[0]];
649
562
  }
650
563
 
651
564
  // get delta between n-smallest and n-1 -smallest
652
- T get_diff (int n) {
565
+ T get_diff(int n) {
653
566
  return x[perm[n]] - x[perm[n - 1]];
654
567
  }
655
568
 
656
569
  // remap orders counted from smallest to indices in array
657
- int get_ord (int n) {
570
+ int get_ord(int n) {
658
571
  return perm[n];
659
572
  }
660
573
  };
661
574
 
662
-
663
-
664
575
  /** Array has n values. Sort the k first ones and copy the other ones
665
576
  * into elements k..n-1
666
577
  */
667
578
  template <class C>
668
- void partial_sort (int k, int n,
669
- const typename C::T * vals, typename C::TI * perm) {
579
+ void partial_sort(
580
+ int k,
581
+ int n,
582
+ const typename C::T* vals,
583
+ typename C::TI* perm) {
670
584
  // insert first k elts in heap
671
585
  for (int i = 1; i < k; i++) {
672
- indirect_heap_push<C> (i + 1, vals, perm, perm[i]);
586
+ indirect_heap_push<C>(i + 1, vals, perm, perm[i]);
673
587
  }
674
588
 
675
589
  // insert next n - k elts in heap
@@ -678,8 +592,8 @@ void partial_sort (int k, int n,
678
592
  typename C::TI top = perm[0];
679
593
 
680
594
  if (C::cmp(vals[top], vals[id])) {
681
- indirect_heap_pop<C> (k, vals, perm);
682
- indirect_heap_push<C> (k, vals, perm, id);
595
+ indirect_heap_pop<C>(k, vals, perm);
596
+ indirect_heap_push<C>(k, vals, perm, id);
683
597
  perm[i] = top;
684
598
  } else {
685
599
  // nothing, elt at i is good where it is.
@@ -689,7 +603,7 @@ void partial_sort (int k, int n,
689
603
  // order the k first elements in heap
690
604
  for (int i = k - 1; i > 0; i--) {
691
605
  typename C::TI top = perm[0];
692
- indirect_heap_pop<C> (i + 1, vals, perm);
606
+ indirect_heap_pop<C>(i + 1, vals, perm);
693
607
  perm[i] = top;
694
608
  }
695
609
  }
@@ -697,69 +611,67 @@ void partial_sort (int k, int n,
697
611
  /** same as SortedArray, but only the k first elements are sorted */
698
612
  template <typename T>
699
613
  struct SemiSortedArray {
700
- const T * x;
614
+ const T* x;
701
615
  int N;
702
616
 
703
617
  // type of the heap: CMax = sort ascending
704
618
  typedef CMax<T, int> HC;
705
619
  std::vector<int> perm;
706
620
 
707
- int k; // k elements are sorted
621
+ int k; // k elements are sorted
708
622
 
709
623
  int initial_k, k_factor;
710
624
 
711
- explicit SemiSortedArray (int N) {
625
+ explicit SemiSortedArray(int N) {
712
626
  this->N = N;
713
- perm.resize (N);
714
- perm.resize (N);
627
+ perm.resize(N);
628
+ perm.resize(N);
715
629
  initial_k = 3;
716
630
  k_factor = 4;
717
631
  }
718
632
 
719
- void init (const T*x) {
633
+ void init(const T* x) {
720
634
  this->x = x;
721
635
  for (int n = 0; n < N; n++)
722
636
  perm[n] = n;
723
637
  k = 0;
724
- grow (initial_k);
638
+ grow(initial_k);
725
639
  }
726
640
 
727
641
  /// grow the sorted part of the array to size next_k
728
- void grow (int next_k) {
642
+ void grow(int next_k) {
729
643
  if (next_k < N) {
730
- partial_sort<HC> (next_k - k, N - k, x, &perm[k]);
644
+ partial_sort<HC>(next_k - k, N - k, x, &perm[k]);
731
645
  k = next_k;
732
646
  } else { // full sort of remainder of array
733
- ArgSort<T> cmp = {x };
734
- std::sort (perm.begin() + k, perm.end(), cmp);
647
+ ArgSort<T> cmp = {x};
648
+ std::sort(perm.begin() + k, perm.end(), cmp);
735
649
  k = N;
736
650
  }
737
651
  }
738
652
 
739
653
  // get smallest value
740
- T get_0 () {
654
+ T get_0() {
741
655
  return x[perm[0]];
742
656
  }
743
657
 
744
658
  // get delta between n-smallest and n-1 -smallest
745
- T get_diff (int n) {
659
+ T get_diff(int n) {
746
660
  if (n >= k) {
747
661
  // want to keep powers of 2 - 1
748
662
  int next_k = (k + 1) * k_factor - 1;
749
- grow (next_k);
663
+ grow(next_k);
750
664
  }
751
665
  return x[perm[n]] - x[perm[n - 1]];
752
666
  }
753
667
 
754
668
  // remap orders counted from smallest to indices in array
755
- int get_ord (int n) {
756
- assert (n < k);
669
+ int get_ord(int n) {
670
+ assert(n < k);
757
671
  return perm[n];
758
672
  }
759
673
  };
760
674
 
761
-
762
-
763
675
  /*****************************************
764
676
  * Find the k smallest sums of M terms, where each term is taken in a
765
677
  * table x of n values.
@@ -779,19 +691,19 @@ struct SemiSortedArray {
779
691
  * occasionally several t's are returned.
780
692
  *
781
693
  * @param x size M * n, values to add up
782
- * @parms k nb of results to retrieve
694
+ * @param k nb of results to retrieve
783
695
  * @param M nb of terms
784
696
  * @param n nb of distinct values
785
697
  * @param sums output, size k, sorted
786
- * @prarm terms output, size k, with encoding as above
698
+ * @param terms output, size k, with encoding as above
787
699
  *
788
700
  ******************************************/
789
701
  template <typename T, class SSA, bool use_seen>
790
702
  struct MinSumK {
791
- int K; ///< nb of sums to return
792
- int M; ///< nb of elements to sum up
703
+ int K; ///< nb of sums to return
704
+ int M; ///< nb of elements to sum up
793
705
  int nbit; ///< nb of bits to encode one entry
794
- int N; ///< nb of possible elements for each of the M terms
706
+ int N; ///< nb of possible elements for each of the M terms
795
707
 
796
708
  /** the heap.
797
709
  * We use a heap to maintain a queue of sums, with the associated
@@ -799,21 +711,20 @@ struct MinSumK {
799
711
  */
800
712
  typedef CMin<T, int64_t> HC;
801
713
  size_t heap_capacity, heap_size;
802
- T *bh_val;
803
- int64_t *bh_ids;
714
+ T* bh_val;
715
+ int64_t* bh_ids;
804
716
 
805
- std::vector <SSA> ssx;
717
+ std::vector<SSA> ssx;
806
718
 
807
719
  // all results get pushed several times. When there are ties, they
808
720
  // are popped interleaved with others, so it is not easy to
809
721
  // identify them. Therefore, this bit array just marks elements
810
722
  // that were seen before.
811
- std::vector <uint8_t> seen;
723
+ std::vector<uint8_t> seen;
812
724
 
813
- MinSumK (int K, int M, int nbit, int N):
814
- K(K), M(M), nbit(nbit), N(N) {
725
+ MinSumK(int K, int M, int nbit, int N) : K(K), M(M), nbit(nbit), N(N) {
815
726
  heap_capacity = K * M;
816
- assert (N <= (1 << nbit));
727
+ assert(N <= (1 << nbit));
817
728
 
818
729
  // we'll do k steps, each step pushes at most M vals
819
730
  bh_val = new T[heap_capacity];
@@ -821,29 +732,27 @@ struct MinSumK {
821
732
 
822
733
  if (use_seen) {
823
734
  int64_t n_ids = weight(M);
824
- seen.resize ((n_ids + 7) / 8);
735
+ seen.resize((n_ids + 7) / 8);
825
736
  }
826
737
 
827
738
  for (int m = 0; m < M; m++)
828
- ssx.push_back (SSA(N));
829
-
739
+ ssx.push_back(SSA(N));
830
740
  }
831
741
 
832
- int64_t weight (int i) {
742
+ int64_t weight(int i) {
833
743
  return 1 << (i * nbit);
834
744
  }
835
745
 
836
- bool is_seen (int64_t i) {
746
+ bool is_seen(int64_t i) {
837
747
  return (seen[i >> 3] >> (i & 7)) & 1;
838
748
  }
839
749
 
840
- void mark_seen (int64_t i) {
750
+ void mark_seen(int64_t i) {
841
751
  if (use_seen)
842
- seen [i >> 3] |= 1 << (i & 7);
752
+ seen[i >> 3] |= 1 << (i & 7);
843
753
  }
844
754
 
845
- void run (const T *x, int64_t ldx,
846
- T * sums, int64_t * terms) {
755
+ void run(const T* x, int64_t ldx, T* sums, int64_t* terms) {
847
756
  heap_size = 0;
848
757
 
849
758
  for (int m = 0; m < M; m++) {
@@ -854,38 +763,41 @@ struct MinSumK {
854
763
  { // initial result: take min for all elements
855
764
  T sum = 0;
856
765
  terms[0] = 0;
857
- mark_seen (0);
766
+ mark_seen(0);
858
767
  for (int m = 0; m < M; m++) {
859
768
  sum += ssx[m].get_0();
860
769
  }
861
770
  sums[0] = sum;
862
771
  for (int m = 0; m < M; m++) {
863
- heap_push<HC> (++heap_size, bh_val, bh_ids,
864
- sum + ssx[m].get_diff(1),
865
- weight(m));
772
+ heap_push<HC>(
773
+ ++heap_size,
774
+ bh_val,
775
+ bh_ids,
776
+ sum + ssx[m].get_diff(1),
777
+ weight(m));
866
778
  }
867
779
  }
868
780
 
869
781
  for (int k = 1; k < K; k++) {
870
782
  // pop smallest value from heap
871
- if (use_seen) {// skip already seen elements
872
- while (is_seen (bh_ids[0])) {
873
- assert (heap_size > 0);
874
- heap_pop<HC> (heap_size--, bh_val, bh_ids);
783
+ if (use_seen) { // skip already seen elements
784
+ while (is_seen(bh_ids[0])) {
785
+ assert(heap_size > 0);
786
+ heap_pop<HC>(heap_size--, bh_val, bh_ids);
875
787
  }
876
788
  }
877
- assert (heap_size > 0);
789
+ assert(heap_size > 0);
878
790
 
879
791
  T sum = sums[k] = bh_val[0];
880
792
  int64_t ti = terms[k] = bh_ids[0];
881
793
 
882
794
  if (use_seen) {
883
- mark_seen (ti);
884
- heap_pop<HC> (heap_size--, bh_val, bh_ids);
795
+ mark_seen(ti);
796
+ heap_pop<HC>(heap_size--, bh_val, bh_ids);
885
797
  } else {
886
798
  do {
887
- heap_pop<HC> (heap_size--, bh_val, bh_ids);
888
- } while (heap_size > 0 && bh_ids[0] == ti);
799
+ heap_pop<HC>(heap_size--, bh_val, bh_ids);
800
+ } while (heap_size > 0 && bh_ids[0] == ti);
889
801
  }
890
802
 
891
803
  // enqueue followers
@@ -893,9 +805,10 @@ struct MinSumK {
893
805
  for (int m = 0; m < M; m++) {
894
806
  int64_t n = ii & ((1L << nbit) - 1);
895
807
  ii >>= nbit;
896
- if (n + 1 >= N) continue;
808
+ if (n + 1 >= N)
809
+ continue;
897
810
 
898
- enqueue_follower (ti, m, n, sum);
811
+ enqueue_follower(ti, m, n, sum);
899
812
  }
900
813
  }
901
814
 
@@ -922,37 +835,29 @@ struct MinSumK {
922
835
  }
923
836
  }
924
837
 
925
-
926
- void enqueue_follower (int64_t ti, int m, int n, T sum) {
838
+ void enqueue_follower(int64_t ti, int m, int n, T sum) {
927
839
  T next_sum = sum + ssx[m].get_diff(n + 1);
928
840
  int64_t next_ti = ti + weight(m);
929
- heap_push<HC> (++heap_size, bh_val, bh_ids, next_sum, next_ti);
841
+ heap_push<HC>(++heap_size, bh_val, bh_ids, next_sum, next_ti);
930
842
  }
931
843
 
932
- ~MinSumK () {
933
- delete [] bh_ids;
934
- delete [] bh_val;
844
+ ~MinSumK() {
845
+ delete[] bh_ids;
846
+ delete[] bh_val;
935
847
  }
936
848
  };
937
849
 
938
850
  } // anonymous namespace
939
851
 
940
-
941
- MultiIndexQuantizer::MultiIndexQuantizer (int d,
942
- size_t M,
943
- size_t nbits):
944
- Index(d, METRIC_L2), pq(d, M, nbits)
945
- {
852
+ MultiIndexQuantizer::MultiIndexQuantizer(int d, size_t M, size_t nbits)
853
+ : Index(d, METRIC_L2), pq(d, M, nbits) {
946
854
  is_trained = false;
947
855
  pq.verbose = verbose;
948
856
  }
949
857
 
950
-
951
-
952
- void MultiIndexQuantizer::train(idx_t n, const float *x)
953
- {
858
+ void MultiIndexQuantizer::train(idx_t n, const float* x) {
954
859
  pq.verbose = verbose;
955
- pq.train (n, x);
860
+ pq.train(n, x);
956
861
  is_trained = true;
957
862
  // count virtual elements in index
958
863
  ntotal = 1;
@@ -960,10 +865,16 @@ void MultiIndexQuantizer::train(idx_t n, const float *x)
960
865
  ntotal *= pq.ksub;
961
866
  }
962
867
 
868
+ void MultiIndexQuantizer::search(
869
+ idx_t n,
870
+ const float* x,
871
+ idx_t k,
872
+ float* distances,
873
+ idx_t* labels) const {
874
+ if (n == 0)
875
+ return;
963
876
 
964
- void MultiIndexQuantizer::search (idx_t n, const float *x, idx_t k,
965
- float *distances, idx_t *labels) const {
966
- if (n == 0) return;
877
+ FAISS_THROW_IF_NOT(k > 0);
967
878
 
968
879
  // the allocation just below can be severe...
969
880
  idx_t bs = 32768;
@@ -971,27 +882,28 @@ void MultiIndexQuantizer::search (idx_t n, const float *x, idx_t k,
971
882
  for (idx_t i0 = 0; i0 < n; i0 += bs) {
972
883
  idx_t i1 = std::min(i0 + bs, n);
973
884
  if (verbose) {
974
- printf("MultiIndexQuantizer::search: %" PRId64 ":%" PRId64 " / %" PRId64 "\n",
975
- i0, i1, n);
885
+ printf("MultiIndexQuantizer::search: %" PRId64 ":%" PRId64
886
+ " / %" PRId64 "\n",
887
+ i0,
888
+ i1,
889
+ n);
976
890
  }
977
- search (i1 - i0, x + i0 * d, k,
978
- distances + i0 * k,
979
- labels + i0 * k);
891
+ search(i1 - i0, x + i0 * d, k, distances + i0 * k, labels + i0 * k);
980
892
  }
981
893
  return;
982
894
  }
983
895
 
984
- float * dis_tables = new float [n * pq.ksub * pq.M];
985
- ScopeDeleter<float> del (dis_tables);
896
+ float* dis_tables = new float[n * pq.ksub * pq.M];
897
+ ScopeDeleter<float> del(dis_tables);
986
898
 
987
- pq.compute_distance_tables (n, x, dis_tables);
899
+ pq.compute_distance_tables(n, x, dis_tables);
988
900
 
989
901
  if (k == 1) {
990
902
  // simple version that just finds the min in each table
991
903
 
992
904
  #pragma omp parallel for
993
905
  for (int i = 0; i < n; i++) {
994
- const float * dis_table = dis_tables + i * pq.ksub * pq.M;
906
+ const float* dis_table = dis_tables + i * pq.ksub * pq.M;
995
907
  float dis = 0;
996
908
  idx_t label = 0;
997
909
 
@@ -1010,32 +922,27 @@ void MultiIndexQuantizer::search (idx_t n, const float *x, idx_t k,
1010
922
  dis_table += pq.ksub;
1011
923
  }
1012
924
 
1013
- distances [i] = dis;
1014
- labels [i] = label;
925
+ distances[i] = dis;
926
+ labels[i] = label;
1015
927
  }
1016
928
 
1017
-
1018
929
  } else {
1019
-
1020
- #pragma omp parallel if(n > 1)
930
+ #pragma omp parallel if (n > 1)
1021
931
  {
1022
- MinSumK <float, SemiSortedArray<float>, false>
1023
- msk(k, pq.M, pq.nbits, pq.ksub);
932
+ MinSumK<float, SemiSortedArray<float>, false> msk(
933
+ k, pq.M, pq.nbits, pq.ksub);
1024
934
  #pragma omp for
1025
935
  for (int i = 0; i < n; i++) {
1026
- msk.run (dis_tables + i * pq.ksub * pq.M, pq.ksub,
1027
- distances + i * k, labels + i * k);
1028
-
936
+ msk.run(dis_tables + i * pq.ksub * pq.M,
937
+ pq.ksub,
938
+ distances + i * k,
939
+ labels + i * k);
1029
940
  }
1030
941
  }
1031
942
  }
1032
-
1033
943
  }
1034
944
 
1035
-
1036
- void MultiIndexQuantizer::reconstruct (idx_t key, float * recons) const
1037
- {
1038
-
945
+ void MultiIndexQuantizer::reconstruct(idx_t key, float* recons) const {
1039
946
  int64_t jj = key;
1040
947
  for (int m = 0; m < pq.M; m++) {
1041
948
  int64_t n = jj & ((1L << pq.nbits) - 1);
@@ -1046,65 +953,53 @@ void MultiIndexQuantizer::reconstruct (idx_t key, float * recons) const
1046
953
  }
1047
954
 
1048
955
  void MultiIndexQuantizer::add(idx_t /*n*/, const float* /*x*/) {
1049
- FAISS_THROW_MSG(
1050
- "This index has virtual elements, "
1051
- "it does not support add");
956
+ FAISS_THROW_MSG(
957
+ "This index has virtual elements, "
958
+ "it does not support add");
1052
959
  }
1053
960
 
1054
- void MultiIndexQuantizer::reset ()
1055
- {
1056
- FAISS_THROW_MSG ( "This index has virtual elements, "
1057
- "it does not support reset");
961
+ void MultiIndexQuantizer::reset() {
962
+ FAISS_THROW_MSG(
963
+ "This index has virtual elements, "
964
+ "it does not support reset");
1058
965
  }
1059
966
 
1060
-
1061
-
1062
-
1063
-
1064
-
1065
-
1066
-
1067
-
1068
-
1069
967
  /*****************************************
1070
968
  * MultiIndexQuantizer2
1071
969
  ******************************************/
1072
970
 
1073
-
1074
-
1075
- MultiIndexQuantizer2::MultiIndexQuantizer2 (
1076
- int d, size_t M, size_t nbits,
1077
- Index **indexes):
1078
- MultiIndexQuantizer (d, M, nbits)
1079
- {
1080
- assign_indexes.resize (M);
971
+ MultiIndexQuantizer2::MultiIndexQuantizer2(
972
+ int d,
973
+ size_t M,
974
+ size_t nbits,
975
+ Index** indexes)
976
+ : MultiIndexQuantizer(d, M, nbits) {
977
+ assign_indexes.resize(M);
1081
978
  for (int i = 0; i < M; i++) {
1082
979
  FAISS_THROW_IF_NOT_MSG(
1083
- indexes[i]->d == pq.dsub,
1084
- "Provided sub-index has incorrect size");
980
+ indexes[i]->d == pq.dsub,
981
+ "Provided sub-index has incorrect size");
1085
982
  assign_indexes[i] = indexes[i];
1086
983
  }
1087
984
  own_fields = false;
1088
985
  }
1089
986
 
1090
- MultiIndexQuantizer2::MultiIndexQuantizer2 (
1091
- int d, size_t nbits,
1092
- Index *assign_index_0,
1093
- Index *assign_index_1):
1094
- MultiIndexQuantizer (d, 2, nbits)
1095
- {
987
+ MultiIndexQuantizer2::MultiIndexQuantizer2(
988
+ int d,
989
+ size_t nbits,
990
+ Index* assign_index_0,
991
+ Index* assign_index_1)
992
+ : MultiIndexQuantizer(d, 2, nbits) {
1096
993
  FAISS_THROW_IF_NOT_MSG(
1097
- assign_index_0->d == pq.dsub &&
1098
- assign_index_1->d == pq.dsub,
994
+ assign_index_0->d == pq.dsub && assign_index_1->d == pq.dsub,
1099
995
  "Provided sub-index has incorrect size");
1100
- assign_indexes.resize (2);
1101
- assign_indexes [0] = assign_index_0;
1102
- assign_indexes [1] = assign_index_1;
996
+ assign_indexes.resize(2);
997
+ assign_indexes[0] = assign_index_0;
998
+ assign_indexes[1] = assign_index_1;
1103
999
  own_fields = false;
1104
1000
  }
1105
1001
 
1106
- void MultiIndexQuantizer2::train(idx_t n, const float* x)
1107
- {
1002
+ void MultiIndexQuantizer2::train(idx_t n, const float* x) {
1108
1003
  MultiIndexQuantizer::train(n, x);
1109
1004
  // add centroids to sub-indexes
1110
1005
  for (int i = 0; i < pq.M; i++) {
@@ -1112,15 +1007,17 @@ void MultiIndexQuantizer2::train(idx_t n, const float* x)
1112
1007
  }
1113
1008
  }
1114
1009
 
1115
-
1116
1010
  void MultiIndexQuantizer2::search(
1117
- idx_t n, const float* x, idx_t K,
1118
- float* distances, idx_t* labels) const
1119
- {
1120
-
1121
- if (n == 0) return;
1011
+ idx_t n,
1012
+ const float* x,
1013
+ idx_t K,
1014
+ float* distances,
1015
+ idx_t* labels) const {
1016
+ if (n == 0)
1017
+ return;
1122
1018
 
1123
1019
  int k2 = std::min(K, int64_t(pq.ksub));
1020
+ FAISS_THROW_IF_NOT(k2);
1124
1021
 
1125
1022
  int64_t M = pq.M;
1126
1023
  int64_t dsub = pq.dsub, ksub = pq.ksub;
@@ -1131,8 +1028,8 @@ void MultiIndexQuantizer2::search(
1131
1028
  std::vector<float> xsub(n * dsub);
1132
1029
 
1133
1030
  for (int m = 0; m < M; m++) {
1134
- float *xdest = xsub.data();
1135
- const float *xsrc = x + m * dsub;
1031
+ float* xdest = xsub.data();
1032
+ const float* xsrc = x + m * dsub;
1136
1033
  for (int j = 0; j < n; j++) {
1137
1034
  memcpy(xdest, xsrc, dsub * sizeof(xdest[0]));
1138
1035
  xsrc += d;
@@ -1140,14 +1037,12 @@ void MultiIndexQuantizer2::search(
1140
1037
  }
1141
1038
 
1142
1039
  assign_indexes[m]->search(
1143
- n, xsub.data(), k2,
1144
- &sub_dis[k2 * n * m],
1145
- &sub_ids[k2 * n * m]);
1040
+ n, xsub.data(), k2, &sub_dis[k2 * n * m], &sub_ids[k2 * n * m]);
1146
1041
  }
1147
1042
 
1148
1043
  if (K == 1) {
1149
1044
  // simple version that just finds the min in each table
1150
- assert (k2 == 1);
1045
+ assert(k2 == 1);
1151
1046
 
1152
1047
  for (int i = 0; i < n; i++) {
1153
1048
  float dis = 0;
@@ -1159,30 +1054,28 @@ void MultiIndexQuantizer2::search(
1159
1054
  dis += vmin;
1160
1055
  label |= lmin << (m * pq.nbits);
1161
1056
  }
1162
- distances [i] = dis;
1163
- labels [i] = label;
1057
+ distances[i] = dis;
1058
+ labels[i] = label;
1164
1059
  }
1165
1060
 
1166
1061
  } else {
1167
-
1168
- #pragma omp parallel if(n > 1)
1062
+ #pragma omp parallel if (n > 1)
1169
1063
  {
1170
- MinSumK <float, PreSortedArray<float>, false>
1171
- msk(K, pq.M, pq.nbits, k2);
1064
+ MinSumK<float, PreSortedArray<float>, false> msk(
1065
+ K, pq.M, pq.nbits, k2);
1172
1066
  #pragma omp for
1173
1067
  for (int i = 0; i < n; i++) {
1174
- idx_t *li = labels + i * K;
1175
- msk.run (&sub_dis[i * k2], k2 * n,
1176
- distances + i * K, li);
1068
+ idx_t* li = labels + i * K;
1069
+ msk.run(&sub_dis[i * k2], k2 * n, distances + i * K, li);
1177
1070
 
1178
1071
  // remap ids
1179
1072
 
1180
- const idx_t *idmap0 = sub_ids.data() + i * k2;
1073
+ const idx_t* idmap0 = sub_ids.data() + i * k2;
1181
1074
  int64_t ld_idmap = k2 * n;
1182
1075
  int64_t mask1 = ksub - 1L;
1183
1076
 
1184
1077
  for (int k = 0; k < K; k++) {
1185
- const idx_t *idmap = idmap0;
1078
+ const idx_t* idmap = idmap0;
1186
1079
  int64_t vin = li[k];
1187
1080
  int64_t vout = 0;
1188
1081
  int bs = 0;
@@ -1200,5 +1093,4 @@ void MultiIndexQuantizer2::search(
1200
1093
  }
1201
1094
  }
1202
1095
 
1203
-
1204
1096
  } // namespace faiss