faiss 0.1.7 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +18 -0
  3. data/README.md +7 -7
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +8 -2
  6. data/ext/faiss/index.cpp +102 -69
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +0 -5
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +26 -12
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
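The bulk of this release is the vendored FAISS upgrade: new index and quantizer implementations arrive as added files (IndexNNDescent, IndexNSG, IndexResidual, AdditiveQuantizer, LocalSearchQuantizer, ResidualQuantizer, NSG/NNDescent impls, simdlib_neon.h), and the existing C++ sources are reformatted. The hunks below are from data/vendor/faiss/faiss/IndexPQ.cpp (entry 62 above); besides formatting, the search entry points now reject non-positive k (FAISS_THROW_IF_NOT(k > 0)). As a minimal sketch, assuming a standard FAISS build, the IndexPQ interface touched in this file (constructor, train, add, search) can be exercised as follows; the dimensions and data are illustrative only:

#include <faiss/IndexPQ.h>

#include <vector>

int main() {
    int d = 64;       // vector dimensionality
    size_t M = 8;     // sub-quantizers
    size_t nbits = 8; // bits per sub-quantizer code
    faiss::IndexPQ index(d, M, nbits, faiss::METRIC_L2);

    // toy training/database vectors; real data comes from the application
    std::vector<float> xb(1000 * d);
    for (size_t i = 0; i < xb.size(); i++) {
        xb[i] = (i % 97) / 97.0f;
    }

    index.train(1000, xb.data()); // is_trained starts out false
    index.add(1000, xb.data());

    faiss::Index::idx_t k = 5; // must be > 0 after this upgrade
    std::vector<float> distances(k);
    std::vector<faiss::Index::idx_t> labels(k);
    index.search(1, xb.data(), k, distances.data(), labels.data());
    return 0;
}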
@@ -10,15 +10,15 @@
  #include <faiss/IndexPQ.h>

  #include <cinttypes>
+ #include <cmath>
  #include <cstddef>
- #include <cstring>
  #include <cstdio>
- #include <cmath>
+ #include <cstring>

  #include <algorithm>

- #include <faiss/impl/FaissAssert.h>
  #include <faiss/impl/AuxIndexStructures.h>
+ #include <faiss/impl/FaissAssert.h>
  #include <faiss/utils/hamming.h>

  namespace faiss {
@@ -27,10 +27,8 @@ namespace faiss {
  * IndexPQ implementation
  ********************************************************/

-
- IndexPQ::IndexPQ (int d, size_t M, size_t nbits, MetricType metric):
- Index(d, metric), pq(d, M, nbits)
- {
+ IndexPQ::IndexPQ(int d, size_t M, size_t nbits, MetricType metric)
+ : Index(d, metric), pq(d, M, nbits) {
  is_trained = false;
  do_polysemous_training = false;
  polysemous_ht = nbits * M + 1;
@@ -38,8 +36,7 @@ IndexPQ::IndexPQ (int d, size_t M, size_t nbits, MetricType metric):
  encode_signs = false;
  }

- IndexPQ::IndexPQ ()
- {
+ IndexPQ::IndexPQ() {
  metric_type = METRIC_L2;
  is_trained = false;
  do_polysemous_training = false;
@@ -48,10 +45,8 @@ IndexPQ::IndexPQ ()
  encode_signs = false;
  }

-
- void IndexPQ::train (idx_t n, const float *x)
- {
- if (!do_polysemous_training) { // standard training
+ void IndexPQ::train(idx_t n, const float* x) {
+ if (!do_polysemous_training) { // standard training
  pq.train(n, x);
  } else {
  idx_t ntrain_perm = polysemous_training.ntrain_permutation;
@@ -59,38 +54,38 @@ void IndexPQ::train (idx_t n, const float *x)
  if (ntrain_perm > n / 4)
  ntrain_perm = n / 4;
  if (verbose) {
- printf ("PQ training on %" PRId64 " points, remains %" PRId64 " points: "
- "training polysemous on %s\n",
- n - ntrain_perm, ntrain_perm,
- ntrain_perm == 0 ? "centroids" : "these");
+ printf("PQ training on %" PRId64 " points, remains %" PRId64
+ " points: "
+ "training polysemous on %s\n",
+ n - ntrain_perm,
+ ntrain_perm,
+ ntrain_perm == 0 ? "centroids" : "these");
  }
  pq.train(n - ntrain_perm, x);

- polysemous_training.optimize_pq_for_hamming (
- pq, ntrain_perm, x + (n - ntrain_perm) * d);
+ polysemous_training.optimize_pq_for_hamming(
+ pq, ntrain_perm, x + (n - ntrain_perm) * d);
  }
  is_trained = true;
  }

-
- void IndexPQ::add (idx_t n, const float *x)
- {
- FAISS_THROW_IF_NOT (is_trained);
- codes.resize ((n + ntotal) * pq.code_size);
- pq.compute_codes (x, &codes[ntotal * pq.code_size], n);
+ void IndexPQ::add(idx_t n, const float* x) {
+ FAISS_THROW_IF_NOT(is_trained);
+ codes.resize((n + ntotal) * pq.code_size);
+ pq.compute_codes(x, &codes[ntotal * pq.code_size], n);
  ntotal += n;
  }

-
- size_t IndexPQ::remove_ids (const IDSelector & sel)
- {
+ size_t IndexPQ::remove_ids(const IDSelector& sel) {
  idx_t j = 0;
  for (idx_t i = 0; i < ntotal; i++) {
- if (sel.is_member (i)) {
+ if (sel.is_member(i)) {
  // should be removed
  } else {
  if (i > j) {
- memmove (&codes[pq.code_size * j], &codes[pq.code_size * i], pq.code_size);
+ memmove(&codes[pq.code_size * j],
+ &codes[pq.code_size * i],
+ pq.code_size);
  }
  j++;
  }
@@ -98,53 +93,46 @@ size_t IndexPQ::remove_ids (const IDSelector & sel)
  size_t nremove = ntotal - j;
  if (nremove > 0) {
  ntotal = j;
- codes.resize (ntotal * pq.code_size);
+ codes.resize(ntotal * pq.code_size);
  }
  return nremove;
  }

-
- void IndexPQ::reset()
- {
+ void IndexPQ::reset() {
  codes.clear();
  ntotal = 0;
  }

- void IndexPQ::reconstruct_n (idx_t i0, idx_t ni, float *recons) const
- {
- FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
+ void IndexPQ::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
+ FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
  for (idx_t i = 0; i < ni; i++) {
- const uint8_t * code = &codes[(i0 + i) * pq.code_size];
- pq.decode (code, recons + i * d);
+ const uint8_t* code = &codes[(i0 + i) * pq.code_size];
+ pq.decode(code, recons + i * d);
  }
  }

-
- void IndexPQ::reconstruct (idx_t key, float * recons) const
- {
- FAISS_THROW_IF_NOT (key >= 0 && key < ntotal);
- pq.decode (&codes[key * pq.code_size], recons);
+ void IndexPQ::reconstruct(idx_t key, float* recons) const {
+ FAISS_THROW_IF_NOT(key >= 0 && key < ntotal);
+ pq.decode(&codes[key * pq.code_size], recons);
  }

-
  namespace {

- template<class PQDecoder>
- struct PQDistanceComputer: DistanceComputer {
+ template <class PQDecoder>
+ struct PQDistanceComputer : DistanceComputer {
  size_t d;
  MetricType metric;
  Index::idx_t nb;
- const uint8_t *codes;
+ const uint8_t* codes;
  size_t code_size;
- const ProductQuantizer & pq;
- const float *sdc;
+ const ProductQuantizer& pq;
+ const float* sdc;
  std::vector<float> precomputed_table;
  size_t ndis;

- float operator () (idx_t i) override
- {
- const uint8_t *code = codes + i * code_size;
- const float *dt = precomputed_table.data();
+ float operator()(idx_t i) override {
+ const uint8_t* code = codes + i * code_size;
+ const float* dt = precomputed_table.data();
  PQDecoder decoder(code, pq.nbits);
  float accu = 0;
  for (int j = 0; j < pq.M; j++) {
@@ -155,13 +143,12 @@ struct PQDistanceComputer: DistanceComputer {
  return accu;
  }

- float symmetric_dis(idx_t i, idx_t j) override
- {
+ float symmetric_dis(idx_t i, idx_t j) override {
  FAISS_THROW_IF_NOT(sdc);
- const float * sdci = sdc;
+ const float* sdci = sdc;
  float accu = 0;
- PQDecoder codei (codes + i * code_size, pq.nbits);
- PQDecoder codej (codes + j * code_size, pq.nbits);
+ PQDecoder codei(codes + i * code_size, pq.nbits);
+ PQDecoder codej(codes + j * code_size, pq.nbits);

  for (int l = 0; l < pq.M; l++) {
  accu += sdci[codei.decode() + (codej.decode() << codei.nbits)];
@@ -171,8 +158,7 @@ struct PQDistanceComputer: DistanceComputer {
  return accu;
  }

- explicit PQDistanceComputer(const IndexPQ& storage)
- : pq(storage.pq) {
+ explicit PQDistanceComputer(const IndexPQ& storage) : pq(storage.pq) {
  precomputed_table.resize(pq.M * pq.ksub);
  nb = storage.ntotal;
  d = storage.d;
@@ -187,21 +173,18 @@ struct PQDistanceComputer: DistanceComputer {
  ndis = 0;
  }

- void set_query(const float *x) override {
+ void set_query(const float* x) override {
  if (metric == METRIC_L2) {
  pq.compute_distance_table(x, precomputed_table.data());
  } else {
  pq.compute_inner_prod_table(x, precomputed_table.data());
  }
-
  }
  };

+ } // namespace

- } // namespace
-
-
- DistanceComputer * IndexPQ::get_distance_computer() const {
+ DistanceComputer* IndexPQ::get_distance_computer() const {
  if (pq.nbits == 8) {
  return new PQDistanceComputer<PQDecoder8>(*this);
  } else if (pq.nbits == 16) {
@@ -211,142 +194,142 @@ DistanceComputer * IndexPQ::get_distance_computer() const {
  }
  }

-
  /*****************************************
  * IndexPQ polysemous search routines
  ******************************************/

+ void IndexPQ::search(
+ idx_t n,
+ const float* x,
+ idx_t k,
+ float* distances,
+ idx_t* labels) const {
+ FAISS_THROW_IF_NOT(k > 0);

-
-
-
- void IndexPQ::search (idx_t n, const float *x, idx_t k,
- float *distances, idx_t *labels) const
- {
- FAISS_THROW_IF_NOT (is_trained);
- if (search_type == ST_PQ) { // Simple PQ search
+ FAISS_THROW_IF_NOT(is_trained);
+ if (search_type == ST_PQ) { // Simple PQ search

  if (metric_type == METRIC_L2) {
  float_maxheap_array_t res = {
- size_t(n), size_t(k), labels, distances };
- pq.search (x, n, codes.data(), ntotal, &res, true);
+ size_t(n), size_t(k), labels, distances};
+ pq.search(x, n, codes.data(), ntotal, &res, true);
  } else {
  float_minheap_array_t res = {
- size_t(n), size_t(k), labels, distances };
- pq.search_ip (x, n, codes.data(), ntotal, &res, true);
+ size_t(n), size_t(k), labels, distances};
+ pq.search_ip(x, n, codes.data(), ntotal, &res, true);
  }
  indexPQ_stats.nq += n;
  indexPQ_stats.ncode += n * ntotal;

- } else if (search_type == ST_polysemous ||
- search_type == ST_polysemous_generalize) {
-
- FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
+ } else if (
+ search_type == ST_polysemous ||
+ search_type == ST_polysemous_generalize) {
+ FAISS_THROW_IF_NOT(metric_type == METRIC_L2);

- search_core_polysemous (n, x, k, distances, labels);
+ search_core_polysemous(n, x, k, distances, labels);

  } else { // code-to-code distances

- uint8_t * q_codes = new uint8_t [n * pq.code_size];
- ScopeDeleter<uint8_t> del (q_codes);
-
+ uint8_t* q_codes = new uint8_t[n * pq.code_size];
+ ScopeDeleter<uint8_t> del(q_codes);

  if (!encode_signs) {
- pq.compute_codes (x, q_codes, n);
+ pq.compute_codes(x, q_codes, n);
  } else {
- FAISS_THROW_IF_NOT (d == pq.nbits * pq.M);
- memset (q_codes, 0, n * pq.code_size);
+ FAISS_THROW_IF_NOT(d == pq.nbits * pq.M);
+ memset(q_codes, 0, n * pq.code_size);
  for (size_t i = 0; i < n; i++) {
- const float *xi = x + i * d;
- uint8_t *code = q_codes + i * pq.code_size;
+ const float* xi = x + i * d;
+ uint8_t* code = q_codes + i * pq.code_size;
  for (int j = 0; j < d; j++)
- if (xi[j] > 0) code [j>>3] |= 1 << (j & 7);
+ if (xi[j] > 0)
+ code[j >> 3] |= 1 << (j & 7);
  }
  }

- if (search_type == ST_SDC) {
-
+ if (search_type == ST_SDC) {
  float_maxheap_array_t res = {
- size_t(n), size_t(k), labels, distances};
+ size_t(n), size_t(k), labels, distances};

- pq.search_sdc (q_codes, n, codes.data(), ntotal, &res, true);
+ pq.search_sdc(q_codes, n, codes.data(), ntotal, &res, true);

  } else {
- int * idistances = new int [n * k];
- ScopeDeleter<int> del (idistances);
+ int* idistances = new int[n * k];
+ ScopeDeleter<int> del(idistances);

  int_maxheap_array_t res = {
- size_t (n), size_t (k), labels, idistances};
+ size_t(n), size_t(k), labels, idistances};

  if (search_type == ST_HE) {
-
- hammings_knn_hc (&res, q_codes, codes.data(),
- ntotal, pq.code_size, true);
+ hammings_knn_hc(
+ &res,
+ q_codes,
+ codes.data(),
+ ntotal,
+ pq.code_size,
+ true);

  } else if (search_type == ST_generalized_HE) {
-
- generalized_hammings_knn_hc (&res, q_codes, codes.data(),
- ntotal, pq.code_size, true);
+ generalized_hammings_knn_hc(
+ &res,
+ q_codes,
+ codes.data(),
+ ntotal,
+ pq.code_size,
+ true);
  }

  // convert distances to floats
  for (int i = 0; i < k * n; i++)
  distances[i] = idistances[i];
-
  }

-
  indexPQ_stats.nq += n;
  indexPQ_stats.ncode += n * ntotal;
  }
  }

-
-
-
-
- void IndexPQStats::reset()
- {
+ void IndexPQStats::reset() {
  nq = ncode = n_hamming_pass = 0;
  }

  IndexPQStats indexPQ_stats;

-
  template <class HammingComputer>
- static size_t polysemous_inner_loop (
- const IndexPQ & index,
- const float *dis_table_qi, const uint8_t *q_code,
- size_t k, float *heap_dis, int64_t *heap_ids)
- {
-
+ static size_t polysemous_inner_loop(
+ const IndexPQ& index,
+ const float* dis_table_qi,
+ const uint8_t* q_code,
+ size_t k,
+ float* heap_dis,
+ int64_t* heap_ids) {
  int M = index.pq.M;
  int code_size = index.pq.code_size;
  int ksub = index.pq.ksub;
  size_t ntotal = index.ntotal;
  int ht = index.polysemous_ht;

- const uint8_t *b_code = index.codes.data();
+ const uint8_t* b_code = index.codes.data();

  size_t n_pass_i = 0;

- HammingComputer hc (q_code, code_size);
+ HammingComputer hc(q_code, code_size);

  for (int64_t bi = 0; bi < ntotal; bi++) {
- int hd = hc.hamming (b_code);
+ int hd = hc.hamming(b_code);

  if (hd < ht) {
- n_pass_i ++;
+ n_pass_i++;

  float dis = 0;
- const float * dis_table = dis_table_qi;
+ const float* dis_table = dis_table_qi;
  for (int m = 0; m < M; m++) {
- dis += dis_table [b_code[m]];
+ dis += dis_table[b_code[m]];
  dis_table += ksub;
  }

  if (dis < heap_dis[0]) {
- maxheap_replace_top (k, heap_dis, heap_ids, dis, bi);
+ maxheap_replace_top(k, heap_dis, heap_ids, dis, bi);
  }
  }
  b_code += code_size;
@@ -354,201 +337,204 @@ static size_t polysemous_inner_loop (
354
337
  return n_pass_i;
355
338
  }
356
339
 
340
+ void IndexPQ::search_core_polysemous(
341
+ idx_t n,
342
+ const float* x,
343
+ idx_t k,
344
+ float* distances,
345
+ idx_t* labels) const {
346
+ FAISS_THROW_IF_NOT(k > 0);
357
347
 
358
- void IndexPQ::search_core_polysemous (idx_t n, const float *x, idx_t k,
359
- float *distances, idx_t *labels) const
360
- {
361
- FAISS_THROW_IF_NOT (pq.nbits == 8);
348
+ FAISS_THROW_IF_NOT(pq.nbits == 8);
362
349
 
363
350
  // PQ distance tables
364
- float * dis_tables = new float [n * pq.ksub * pq.M];
365
- ScopeDeleter<float> del (dis_tables);
366
- pq.compute_distance_tables (n, x, dis_tables);
351
+ float* dis_tables = new float[n * pq.ksub * pq.M];
352
+ ScopeDeleter<float> del(dis_tables);
353
+ pq.compute_distance_tables(n, x, dis_tables);
367
354
 
368
355
  // Hamming embedding queries
369
- uint8_t * q_codes = new uint8_t [n * pq.code_size];
370
- ScopeDeleter<uint8_t> del2 (q_codes);
356
+ uint8_t* q_codes = new uint8_t[n * pq.code_size];
357
+ ScopeDeleter<uint8_t> del2(q_codes);
371
358
 
372
359
  if (false) {
373
- pq.compute_codes (x, q_codes, n);
360
+ pq.compute_codes(x, q_codes, n);
374
361
  } else {
375
362
  #pragma omp parallel for
376
363
  for (idx_t qi = 0; qi < n; qi++) {
377
- pq.compute_code_from_distance_table
378
- (dis_tables + qi * pq.M * pq.ksub,
379
- q_codes + qi * pq.code_size);
364
+ pq.compute_code_from_distance_table(
365
+ dis_tables + qi * pq.M * pq.ksub,
366
+ q_codes + qi * pq.code_size);
380
367
  }
381
368
  }
382
369
 
383
370
  size_t n_pass = 0;
384
371
 
385
- #pragma omp parallel for reduction (+: n_pass)
372
+ #pragma omp parallel for reduction(+ : n_pass)
386
373
  for (idx_t qi = 0; qi < n; qi++) {
387
- const uint8_t * q_code = q_codes + qi * pq.code_size;
374
+ const uint8_t* q_code = q_codes + qi * pq.code_size;
388
375
 
389
- const float * dis_table_qi = dis_tables + qi * pq.M * pq.ksub;
376
+ const float* dis_table_qi = dis_tables + qi * pq.M * pq.ksub;
390
377
 
391
- int64_t * heap_ids = labels + qi * k;
392
- float *heap_dis = distances + qi * k;
393
- maxheap_heapify (k, heap_dis, heap_ids);
378
+ int64_t* heap_ids = labels + qi * k;
379
+ float* heap_dis = distances + qi * k;
380
+ maxheap_heapify(k, heap_dis, heap_ids);
394
381
 
395
382
  if (search_type == ST_polysemous) {
396
-
397
383
  switch (pq.code_size) {
398
- case 4:
399
- n_pass += polysemous_inner_loop<HammingComputer4>
400
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
401
- break;
402
- case 8:
403
- n_pass += polysemous_inner_loop<HammingComputer8>
404
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
405
- break;
406
- case 16:
407
- n_pass += polysemous_inner_loop<HammingComputer16>
408
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
409
- break;
410
- case 32:
411
- n_pass += polysemous_inner_loop<HammingComputer32>
412
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
413
- break;
414
- case 20:
415
- n_pass += polysemous_inner_loop<HammingComputer20>
416
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
417
- break;
418
- default:
419
- if (pq.code_size % 8 == 0) {
420
- n_pass += polysemous_inner_loop<HammingComputerM8>
421
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
422
- } else if (pq.code_size % 4 == 0) {
423
- n_pass += polysemous_inner_loop<HammingComputerM4>
424
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
425
- } else {
426
- FAISS_THROW_FMT(
427
- "code size %zd not supported for polysemous",
428
- pq.code_size);
429
- }
430
- break;
384
+ case 4:
385
+ n_pass += polysemous_inner_loop<HammingComputer4>(
386
+ *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
387
+ break;
388
+ case 8:
389
+ n_pass += polysemous_inner_loop<HammingComputer8>(
390
+ *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
391
+ break;
392
+ case 16:
393
+ n_pass += polysemous_inner_loop<HammingComputer16>(
394
+ *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
395
+ break;
396
+ case 32:
397
+ n_pass += polysemous_inner_loop<HammingComputer32>(
398
+ *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
399
+ break;
400
+ case 20:
401
+ n_pass += polysemous_inner_loop<HammingComputer20>(
402
+ *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
403
+ break;
404
+ default:
405
+ if (pq.code_size % 4 == 0) {
406
+ n_pass += polysemous_inner_loop<HammingComputerDefault>(
407
+ *this,
408
+ dis_table_qi,
409
+ q_code,
410
+ k,
411
+ heap_dis,
412
+ heap_ids);
413
+ } else {
414
+ FAISS_THROW_FMT(
415
+ "code size %zd not supported for polysemous",
416
+ pq.code_size);
417
+ }
418
+ break;
431
419
  }
432
420
  } else {
433
421
  switch (pq.code_size) {
434
- case 8:
435
- n_pass += polysemous_inner_loop<GenHammingComputer8>
436
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
437
- break;
438
- case 16:
439
- n_pass += polysemous_inner_loop<GenHammingComputer16>
440
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
441
- break;
442
- case 32:
443
- n_pass += polysemous_inner_loop<GenHammingComputer32>
444
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
445
- break;
446
- default:
447
- if (pq.code_size % 8 == 0) {
448
- n_pass += polysemous_inner_loop<GenHammingComputerM8>
449
- (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
450
- } else {
451
- FAISS_THROW_FMT(
452
- "code size %zd not supported for polysemous",
453
- pq.code_size);
454
- }
455
- break;
422
+ case 8:
423
+ n_pass += polysemous_inner_loop<GenHammingComputer8>(
424
+ *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
425
+ break;
426
+ case 16:
427
+ n_pass += polysemous_inner_loop<GenHammingComputer16>(
428
+ *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
429
+ break;
430
+ case 32:
431
+ n_pass += polysemous_inner_loop<GenHammingComputer32>(
432
+ *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
433
+ break;
434
+ default:
435
+ if (pq.code_size % 8 == 0) {
436
+ n_pass += polysemous_inner_loop<GenHammingComputerM8>(
437
+ *this,
438
+ dis_table_qi,
439
+ q_code,
440
+ k,
441
+ heap_dis,
442
+ heap_ids);
443
+ } else {
444
+ FAISS_THROW_FMT(
445
+ "code size %zd not supported for polysemous",
446
+ pq.code_size);
447
+ }
448
+ break;
456
449
  }
457
450
  }
458
- maxheap_reorder (k, heap_dis, heap_ids);
451
+ maxheap_reorder(k, heap_dis, heap_ids);
459
452
  }
460
453
 
461
454
  indexPQ_stats.nq += n;
462
455
  indexPQ_stats.ncode += n * ntotal;
463
456
  indexPQ_stats.n_hamming_pass += n_pass;
464
-
465
-
466
457
  }
467
458
 
468
-
469
459
  /* The standalone codec interface (just remaps to the PQ functions) */
470
- size_t IndexPQ::sa_code_size () const
471
- {
460
+ size_t IndexPQ::sa_code_size() const {
472
461
  return pq.code_size;
473
462
  }
474
463
 
475
- void IndexPQ::sa_encode (idx_t n, const float *x, uint8_t *bytes) const
476
- {
477
- pq.compute_codes (x, bytes, n);
464
+ void IndexPQ::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
465
+ pq.compute_codes(x, bytes, n);
478
466
  }
479
467
 
480
- void IndexPQ::sa_decode (idx_t n, const uint8_t *bytes, float *x) const
481
- {
482
- pq.decode (bytes, x, n);
468
+ void IndexPQ::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
469
+ pq.decode(bytes, x, n);
483
470
  }
484
471
 
485
-
486
-
487
-
488
472
  /*****************************************
489
473
  * Stats of IndexPQ codes
490
474
  ******************************************/
491
475
 
476
+ void IndexPQ::hamming_distance_table(idx_t n, const float* x, int32_t* dis)
477
+ const {
478
+ uint8_t* q_codes = new uint8_t[n * pq.code_size];
479
+ ScopeDeleter<uint8_t> del(q_codes);
492
480
 
481
+ pq.compute_codes(x, q_codes, n);
493
482
 
494
-
495
- void IndexPQ::hamming_distance_table (idx_t n, const float *x,
496
- int32_t *dis) const
497
- {
498
- uint8_t * q_codes = new uint8_t [n * pq.code_size];
499
- ScopeDeleter<uint8_t> del (q_codes);
500
-
501
- pq.compute_codes (x, q_codes, n);
502
-
503
- hammings (q_codes, codes.data(), n, ntotal, pq.code_size, dis);
483
+ hammings(q_codes, codes.data(), n, ntotal, pq.code_size, dis);
504
484
  }
505
485
 
506
-
507
- void IndexPQ::hamming_distance_histogram (idx_t n, const float *x,
508
- idx_t nb, const float *xb,
509
- int64_t *hist)
510
- {
511
- FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
512
- FAISS_THROW_IF_NOT (pq.code_size % 8 == 0);
513
- FAISS_THROW_IF_NOT (pq.nbits == 8);
486
+ void IndexPQ::hamming_distance_histogram(
487
+ idx_t n,
488
+ const float* x,
489
+ idx_t nb,
490
+ const float* xb,
491
+ int64_t* hist) {
492
+ FAISS_THROW_IF_NOT(metric_type == METRIC_L2);
493
+ FAISS_THROW_IF_NOT(pq.code_size % 8 == 0);
494
+ FAISS_THROW_IF_NOT(pq.nbits == 8);
514
495
 
515
496
  // Hamming embedding queries
516
- uint8_t * q_codes = new uint8_t [n * pq.code_size];
517
- ScopeDeleter <uint8_t> del (q_codes);
518
- pq.compute_codes (x, q_codes, n);
497
+ uint8_t* q_codes = new uint8_t[n * pq.code_size];
498
+ ScopeDeleter<uint8_t> del(q_codes);
499
+ pq.compute_codes(x, q_codes, n);
519
500
 
520
- uint8_t * b_codes ;
521
- ScopeDeleter <uint8_t> del_b_codes;
501
+ uint8_t* b_codes;
502
+ ScopeDeleter<uint8_t> del_b_codes;
522
503
 
523
504
  if (xb) {
524
- b_codes = new uint8_t [nb * pq.code_size];
525
- del_b_codes.set (b_codes);
526
- pq.compute_codes (xb, b_codes, nb);
505
+ b_codes = new uint8_t[nb * pq.code_size];
506
+ del_b_codes.set(b_codes);
507
+ pq.compute_codes(xb, b_codes, nb);
527
508
  } else {
528
509
  nb = ntotal;
529
510
  b_codes = codes.data();
530
511
  }
531
512
  int nbits = pq.M * pq.nbits;
532
- memset (hist, 0, sizeof(*hist) * (nbits + 1));
513
+ memset(hist, 0, sizeof(*hist) * (nbits + 1));
533
514
  size_t bs = 256;
534
515
 
535
516
  #pragma omp parallel
536
517
  {
537
- std::vector<int64_t> histi (nbits + 1);
538
- hamdis_t *distances = new hamdis_t [nb * bs];
539
- ScopeDeleter<hamdis_t> del (distances);
518
+ std::vector<int64_t> histi(nbits + 1);
519
+ hamdis_t* distances = new hamdis_t[nb * bs];
520
+ ScopeDeleter<hamdis_t> del(distances);
540
521
  #pragma omp for
541
522
  for (idx_t q0 = 0; q0 < n; q0 += bs) {
542
523
  // printf ("dis stats: %zd/%zd\n", q0, n);
543
524
  size_t q1 = q0 + bs;
544
- if (q1 > n) q1 = n;
525
+ if (q1 > n)
526
+ q1 = n;
545
527
 
546
- hammings (q_codes + q0 * pq.code_size, b_codes,
547
- q1 - q0, nb,
548
- pq.code_size, distances);
528
+ hammings(
529
+ q_codes + q0 * pq.code_size,
530
+ b_codes,
531
+ q1 - q0,
532
+ nb,
533
+ pq.code_size,
534
+ distances);
549
535
 
550
536
  for (size_t i = 0; i < nb * (q1 - q0); i++)
551
- histi [distances [i]]++;
537
+ histi[distances[i]]++;
552
538
  }
553
539
  #pragma omp critical
554
540
  {
@@ -556,28 +542,8 @@ void IndexPQ::hamming_distance_histogram (idx_t n, const float *x,
556
542
  hist[i] += histi[i];
557
543
  }
558
544
  }
559
-
560
545
  }
561
546
 
562
-
563
-
564
-
565
-
566
-
567
-
568
-
569
-
570
-
571
-
572
-
573
-
574
-
575
-
576
-
577
-
578
-
579
-
580
-
581
547
  /*****************************************
582
548
  * MultiIndexQuantizer
583
549
  ******************************************/
@@ -586,90 +552,87 @@ namespace {
586
552
 
587
553
  template <typename T>
588
554
  struct PreSortedArray {
589
-
590
- const T * x;
555
+ const T* x;
591
556
  int N;
592
557
 
593
- explicit PreSortedArray (int N): N(N) {
594
- }
595
- void init (const T*x) {
558
+ explicit PreSortedArray(int N) : N(N) {}
559
+ void init(const T* x) {
596
560
  this->x = x;
597
561
  }
598
562
  // get smallest value
599
- T get_0 () {
563
+ T get_0() {
600
564
  return x[0];
601
565
  }
602
566
 
603
567
  // get delta between n-smallest and n-1 -smallest
604
- T get_diff (int n) {
568
+ T get_diff(int n) {
605
569
  return x[n] - x[n - 1];
606
570
  }
607
571
 
608
572
  // remap orders counted from smallest to indices in array
609
- int get_ord (int n) {
573
+ int get_ord(int n) {
610
574
  return n;
611
575
  }
612
-
613
576
  };
614
577
 
615
578
  template <typename T>
616
579
  struct ArgSort {
617
- const T * x;
618
- bool operator() (size_t i, size_t j) {
580
+ const T* x;
581
+ bool operator()(size_t i, size_t j) {
619
582
  return x[i] < x[j];
620
583
  }
621
584
  };
622
585
 
623
-
624
586
  /** Array that maintains a permutation of its elements so that the
625
587
  * array's elements are sorted
626
588
  */
627
589
  template <typename T>
628
590
  struct SortedArray {
629
- const T * x;
591
+ const T* x;
630
592
  int N;
631
593
  std::vector<int> perm;
632
594
 
633
- explicit SortedArray (int N) {
595
+ explicit SortedArray(int N) {
634
596
  this->N = N;
635
- perm.resize (N);
597
+ perm.resize(N);
636
598
  }
637
599
 
638
- void init (const T*x) {
600
+ void init(const T* x) {
639
601
  this->x = x;
640
602
  for (int n = 0; n < N; n++)
641
603
  perm[n] = n;
642
- ArgSort<T> cmp = {x };
643
- std::sort (perm.begin(), perm.end(), cmp);
604
+ ArgSort<T> cmp = {x};
605
+ std::sort(perm.begin(), perm.end(), cmp);
644
606
  }
645
607
 
646
608
  // get smallest value
647
- T get_0 () {
609
+ T get_0() {
648
610
  return x[perm[0]];
649
611
  }
650
612
 
651
613
  // get delta between n-smallest and n-1 -smallest
652
- T get_diff (int n) {
614
+ T get_diff(int n) {
653
615
  return x[perm[n]] - x[perm[n - 1]];
654
616
  }
655
617
 
656
618
  // remap orders counted from smallest to indices in array
657
- int get_ord (int n) {
619
+ int get_ord(int n) {
658
620
  return perm[n];
659
621
  }
660
622
  };
661
623
 
662
-
663
-
664
624
  /** Array has n values. Sort the k first ones and copy the other ones
665
625
  * into elements k..n-1
666
626
  */
667
627
  template <class C>
668
- void partial_sort (int k, int n,
669
- const typename C::T * vals, typename C::TI * perm) {
628
+ void partial_sort(
629
+ int k,
630
+ int n,
631
+ const typename C::T* vals,
632
+ typename C::TI* perm) {
670
633
  // insert first k elts in heap
671
634
  for (int i = 1; i < k; i++) {
672
- indirect_heap_push<C> (i + 1, vals, perm, perm[i]);
635
+ indirect_heap_push<C>(i + 1, vals, perm, perm[i]);
673
636
  }
674
637
 
675
638
  // insert next n - k elts in heap
@@ -678,8 +641,8 @@ void partial_sort (int k, int n,
678
641
  typename C::TI top = perm[0];
679
642
 
680
643
  if (C::cmp(vals[top], vals[id])) {
681
- indirect_heap_pop<C> (k, vals, perm);
682
- indirect_heap_push<C> (k, vals, perm, id);
644
+ indirect_heap_pop<C>(k, vals, perm);
645
+ indirect_heap_push<C>(k, vals, perm, id);
683
646
  perm[i] = top;
684
647
  } else {
685
648
  // nothing, elt at i is good where it is.
@@ -689,7 +652,7 @@ void partial_sort (int k, int n,
689
652
  // order the k first elements in heap
690
653
  for (int i = k - 1; i > 0; i--) {
691
654
  typename C::TI top = perm[0];
692
- indirect_heap_pop<C> (i + 1, vals, perm);
655
+ indirect_heap_pop<C>(i + 1, vals, perm);
693
656
  perm[i] = top;
694
657
  }
695
658
  }
@@ -697,69 +660,67 @@ void partial_sort (int k, int n,
697
660
  /** same as SortedArray, but only the k first elements are sorted */
698
661
  template <typename T>
699
662
  struct SemiSortedArray {
700
- const T * x;
663
+ const T* x;
701
664
  int N;
702
665
 
703
666
  // type of the heap: CMax = sort ascending
704
667
  typedef CMax<T, int> HC;
705
668
  std::vector<int> perm;
706
669
 
707
- int k; // k elements are sorted
670
+ int k; // k elements are sorted
708
671
 
709
672
  int initial_k, k_factor;
710
673
 
711
- explicit SemiSortedArray (int N) {
674
+ explicit SemiSortedArray(int N) {
712
675
  this->N = N;
713
- perm.resize (N);
714
- perm.resize (N);
676
+ perm.resize(N);
677
+ perm.resize(N);
715
678
  initial_k = 3;
716
679
  k_factor = 4;
717
680
  }
718
681
 
719
- void init (const T*x) {
682
+ void init(const T* x) {
720
683
  this->x = x;
721
684
  for (int n = 0; n < N; n++)
722
685
  perm[n] = n;
723
686
  k = 0;
724
- grow (initial_k);
687
+ grow(initial_k);
725
688
  }
726
689
 
727
690
  /// grow the sorted part of the array to size next_k
728
- void grow (int next_k) {
691
+ void grow(int next_k) {
729
692
  if (next_k < N) {
730
- partial_sort<HC> (next_k - k, N - k, x, &perm[k]);
693
+ partial_sort<HC>(next_k - k, N - k, x, &perm[k]);
731
694
  k = next_k;
732
695
  } else { // full sort of remainder of array
733
- ArgSort<T> cmp = {x };
734
- std::sort (perm.begin() + k, perm.end(), cmp);
696
+ ArgSort<T> cmp = {x};
697
+ std::sort(perm.begin() + k, perm.end(), cmp);
735
698
  k = N;
736
699
  }
737
700
  }
738
701
 
739
702
  // get smallest value
740
- T get_0 () {
703
+ T get_0() {
741
704
  return x[perm[0]];
742
705
  }
743
706
 
744
707
  // get delta between n-smallest and n-1 -smallest
745
- T get_diff (int n) {
708
+ T get_diff(int n) {
746
709
  if (n >= k) {
747
710
  // want to keep powers of 2 - 1
748
711
  int next_k = (k + 1) * k_factor - 1;
749
- grow (next_k);
712
+ grow(next_k);
750
713
  }
751
714
  return x[perm[n]] - x[perm[n - 1]];
752
715
  }
753
716
 
754
717
  // remap orders counted from smallest to indices in array
755
- int get_ord (int n) {
756
- assert (n < k);
718
+ int get_ord(int n) {
719
+ assert(n < k);
757
720
  return perm[n];
758
721
  }
759
722
  };
760
723
 
761
-
762
-
763
724
  /*****************************************
764
725
  * Find the k smallest sums of M terms, where each term is taken in a
765
726
  * table x of n values.
@@ -779,19 +740,19 @@ struct SemiSortedArray {
779
740
  * occasionally several t's are returned.
780
741
  *
781
742
  * @param x size M * n, values to add up
782
- * @parms k nb of results to retrieve
743
+ * @param k nb of results to retrieve
783
744
  * @param M nb of terms
784
745
  * @param n nb of distinct values
785
746
  * @param sums output, size k, sorted
786
- * @prarm terms output, size k, with encoding as above
747
+ * @param terms output, size k, with encoding as above
787
748
  *
788
749
  ******************************************/
789
750
  template <typename T, class SSA, bool use_seen>
790
751
  struct MinSumK {
791
- int K; ///< nb of sums to return
792
- int M; ///< nb of elements to sum up
752
+ int K; ///< nb of sums to return
753
+ int M; ///< nb of elements to sum up
793
754
  int nbit; ///< nb of bits to encode one entry
794
- int N; ///< nb of possible elements for each of the M terms
755
+ int N; ///< nb of possible elements for each of the M terms
795
756
 
796
757
  /** the heap.
797
758
  * We use a heap to maintain a queue of sums, with the associated
@@ -799,21 +760,20 @@ struct MinSumK {
799
760
  */
800
761
  typedef CMin<T, int64_t> HC;
801
762
  size_t heap_capacity, heap_size;
802
- T *bh_val;
803
- int64_t *bh_ids;
763
+ T* bh_val;
764
+ int64_t* bh_ids;
804
765
 
805
- std::vector <SSA> ssx;
766
+ std::vector<SSA> ssx;
806
767
 
807
768
  // all results get pushed several times. When there are ties, they
808
769
  // are popped interleaved with others, so it is not easy to
809
770
  // identify them. Therefore, this bit array just marks elements
810
771
  // that were seen before.
811
- std::vector <uint8_t> seen;
772
+ std::vector<uint8_t> seen;
812
773
 
813
- MinSumK (int K, int M, int nbit, int N):
814
- K(K), M(M), nbit(nbit), N(N) {
774
+ MinSumK(int K, int M, int nbit, int N) : K(K), M(M), nbit(nbit), N(N) {
815
775
  heap_capacity = K * M;
816
- assert (N <= (1 << nbit));
776
+ assert(N <= (1 << nbit));
817
777
 
818
778
  // we'll do k steps, each step pushes at most M vals
819
779
  bh_val = new T[heap_capacity];
@@ -821,29 +781,27 @@ struct MinSumK {
821
781
 
822
782
  if (use_seen) {
823
783
  int64_t n_ids = weight(M);
824
- seen.resize ((n_ids + 7) / 8);
784
+ seen.resize((n_ids + 7) / 8);
825
785
  }
826
786
 
827
787
  for (int m = 0; m < M; m++)
828
- ssx.push_back (SSA(N));
829
-
788
+ ssx.push_back(SSA(N));
830
789
  }
831
790
 
832
- int64_t weight (int i) {
791
+ int64_t weight(int i) {
833
792
  return 1 << (i * nbit);
834
793
  }
835
794
 
836
- bool is_seen (int64_t i) {
795
+ bool is_seen(int64_t i) {
837
796
  return (seen[i >> 3] >> (i & 7)) & 1;
838
797
  }
839
798
 
840
- void mark_seen (int64_t i) {
799
+ void mark_seen(int64_t i) {
841
800
  if (use_seen)
842
- seen [i >> 3] |= 1 << (i & 7);
801
+ seen[i >> 3] |= 1 << (i & 7);
843
802
  }
844
803
 
845
- void run (const T *x, int64_t ldx,
846
- T * sums, int64_t * terms) {
804
+ void run(const T* x, int64_t ldx, T* sums, int64_t* terms) {
847
805
  heap_size = 0;
848
806
 
849
807
  for (int m = 0; m < M; m++) {
@@ -854,38 +812,41 @@ struct MinSumK {
854
812
  { // initial result: take min for all elements
855
813
  T sum = 0;
856
814
  terms[0] = 0;
857
- mark_seen (0);
815
+ mark_seen(0);
858
816
  for (int m = 0; m < M; m++) {
859
817
  sum += ssx[m].get_0();
860
818
  }
861
819
  sums[0] = sum;
862
820
  for (int m = 0; m < M; m++) {
863
- heap_push<HC> (++heap_size, bh_val, bh_ids,
864
- sum + ssx[m].get_diff(1),
865
- weight(m));
821
+ heap_push<HC>(
822
+ ++heap_size,
823
+ bh_val,
824
+ bh_ids,
825
+ sum + ssx[m].get_diff(1),
826
+ weight(m));
866
827
  }
867
828
  }
868
829
 
869
830
  for (int k = 1; k < K; k++) {
870
831
  // pop smallest value from heap
871
- if (use_seen) {// skip already seen elements
872
- while (is_seen (bh_ids[0])) {
873
- assert (heap_size > 0);
874
- heap_pop<HC> (heap_size--, bh_val, bh_ids);
832
+ if (use_seen) { // skip already seen elements
833
+ while (is_seen(bh_ids[0])) {
834
+ assert(heap_size > 0);
835
+ heap_pop<HC>(heap_size--, bh_val, bh_ids);
875
836
  }
876
837
  }
877
- assert (heap_size > 0);
838
+ assert(heap_size > 0);
878
839
 
879
840
  T sum = sums[k] = bh_val[0];
880
841
  int64_t ti = terms[k] = bh_ids[0];
881
842
 
882
843
  if (use_seen) {
883
- mark_seen (ti);
884
- heap_pop<HC> (heap_size--, bh_val, bh_ids);
844
+ mark_seen(ti);
845
+ heap_pop<HC>(heap_size--, bh_val, bh_ids);
885
846
  } else {
886
847
  do {
887
- heap_pop<HC> (heap_size--, bh_val, bh_ids);
888
- } while (heap_size > 0 && bh_ids[0] == ti);
848
+ heap_pop<HC>(heap_size--, bh_val, bh_ids);
849
+ } while (heap_size > 0 && bh_ids[0] == ti);
889
850
  }
890
851
 
891
852
  // enqueue followers
@@ -893,9 +854,10 @@ struct MinSumK {
893
854
  for (int m = 0; m < M; m++) {
894
855
  int64_t n = ii & ((1L << nbit) - 1);
895
856
  ii >>= nbit;
896
- if (n + 1 >= N) continue;
857
+ if (n + 1 >= N)
858
+ continue;
897
859
 
898
- enqueue_follower (ti, m, n, sum);
860
+ enqueue_follower(ti, m, n, sum);
899
861
  }
900
862
  }
901
863
 
@@ -922,37 +884,29 @@ struct MinSumK {
922
884
  }
923
885
  }
924
886
 
925
-
926
- void enqueue_follower (int64_t ti, int m, int n, T sum) {
887
+ void enqueue_follower(int64_t ti, int m, int n, T sum) {
927
888
  T next_sum = sum + ssx[m].get_diff(n + 1);
928
889
  int64_t next_ti = ti + weight(m);
929
- heap_push<HC> (++heap_size, bh_val, bh_ids, next_sum, next_ti);
890
+ heap_push<HC>(++heap_size, bh_val, bh_ids, next_sum, next_ti);
930
891
  }
931
892
 
932
- ~MinSumK () {
933
- delete [] bh_ids;
934
- delete [] bh_val;
893
+ ~MinSumK() {
894
+ delete[] bh_ids;
895
+ delete[] bh_val;
935
896
  }
936
897
  };
937
898
 
938
899
  } // anonymous namespace
939
900
 
940
-
941
- MultiIndexQuantizer::MultiIndexQuantizer (int d,
942
- size_t M,
943
- size_t nbits):
944
- Index(d, METRIC_L2), pq(d, M, nbits)
945
- {
901
+ MultiIndexQuantizer::MultiIndexQuantizer(int d, size_t M, size_t nbits)
902
+ : Index(d, METRIC_L2), pq(d, M, nbits) {
946
903
  is_trained = false;
947
904
  pq.verbose = verbose;
948
905
  }
949
906
 
950
-
951
-
952
- void MultiIndexQuantizer::train(idx_t n, const float *x)
953
- {
907
+ void MultiIndexQuantizer::train(idx_t n, const float* x) {
954
908
  pq.verbose = verbose;
955
- pq.train (n, x);
909
+ pq.train(n, x);
956
910
  is_trained = true;
957
911
  // count virtual elements in index
958
912
  ntotal = 1;
@@ -960,10 +914,16 @@ void MultiIndexQuantizer::train(idx_t n, const float *x)
960
914
  ntotal *= pq.ksub;
961
915
  }
962
916
 
917
+ void MultiIndexQuantizer::search(
918
+ idx_t n,
919
+ const float* x,
920
+ idx_t k,
921
+ float* distances,
922
+ idx_t* labels) const {
923
+ if (n == 0)
924
+ return;
963
925
 
964
- void MultiIndexQuantizer::search (idx_t n, const float *x, idx_t k,
965
- float *distances, idx_t *labels) const {
966
- if (n == 0) return;
926
+ FAISS_THROW_IF_NOT(k > 0);
967
927
 
968
928
  // the allocation just below can be severe...
969
929
  idx_t bs = 32768;
@@ -971,27 +931,28 @@ void MultiIndexQuantizer::search (idx_t n, const float *x, idx_t k,
971
931
  for (idx_t i0 = 0; i0 < n; i0 += bs) {
972
932
  idx_t i1 = std::min(i0 + bs, n);
973
933
  if (verbose) {
974
- printf("MultiIndexQuantizer::search: %" PRId64 ":%" PRId64 " / %" PRId64 "\n",
975
- i0, i1, n);
934
+ printf("MultiIndexQuantizer::search: %" PRId64 ":%" PRId64
935
+ " / %" PRId64 "\n",
936
+ i0,
937
+ i1,
938
+ n);
976
939
  }
977
- search (i1 - i0, x + i0 * d, k,
978
- distances + i0 * k,
979
- labels + i0 * k);
940
+ search(i1 - i0, x + i0 * d, k, distances + i0 * k, labels + i0 * k);
980
941
  }
981
942
  return;
982
943
  }
983
944
 
984
- float * dis_tables = new float [n * pq.ksub * pq.M];
985
- ScopeDeleter<float> del (dis_tables);
945
+ float* dis_tables = new float[n * pq.ksub * pq.M];
946
+ ScopeDeleter<float> del(dis_tables);
986
947
 
987
- pq.compute_distance_tables (n, x, dis_tables);
948
+ pq.compute_distance_tables(n, x, dis_tables);
988
949
 
989
950
  if (k == 1) {
990
951
  // simple version that just finds the min in each table
991
952
 
992
953
  #pragma omp parallel for
993
954
  for (int i = 0; i < n; i++) {
994
- const float * dis_table = dis_tables + i * pq.ksub * pq.M;
955
+ const float* dis_table = dis_tables + i * pq.ksub * pq.M;
995
956
  float dis = 0;
996
957
  idx_t label = 0;
997
958
 
@@ -1010,32 +971,27 @@ void MultiIndexQuantizer::search (idx_t n, const float *x, idx_t k,
1010
971
  dis_table += pq.ksub;
1011
972
  }
1012
973
 
1013
- distances [i] = dis;
1014
- labels [i] = label;
974
+ distances[i] = dis;
975
+ labels[i] = label;
1015
976
  }
1016
977
 
1017
-
1018
978
  } else {
1019
-
1020
- #pragma omp parallel if(n > 1)
979
+ #pragma omp parallel if (n > 1)
1021
980
  {
1022
- MinSumK <float, SemiSortedArray<float>, false>
1023
- msk(k, pq.M, pq.nbits, pq.ksub);
981
+ MinSumK<float, SemiSortedArray<float>, false> msk(
982
+ k, pq.M, pq.nbits, pq.ksub);
1024
983
  #pragma omp for
1025
984
  for (int i = 0; i < n; i++) {
1026
- msk.run (dis_tables + i * pq.ksub * pq.M, pq.ksub,
1027
- distances + i * k, labels + i * k);
1028
-
985
+ msk.run(dis_tables + i * pq.ksub * pq.M,
986
+ pq.ksub,
987
+ distances + i * k,
988
+ labels + i * k);
1029
989
  }
1030
990
  }
1031
991
  }
1032
-
1033
992
  }
1034
993
 
1035
-
1036
- void MultiIndexQuantizer::reconstruct (idx_t key, float * recons) const
1037
- {
1038
-
994
+ void MultiIndexQuantizer::reconstruct(idx_t key, float* recons) const {
1039
995
  int64_t jj = key;
1040
996
  for (int m = 0; m < pq.M; m++) {
1041
997
  int64_t n = jj & ((1L << pq.nbits) - 1);
@@ -1046,65 +1002,53 @@ void MultiIndexQuantizer::reconstruct (idx_t key, float * recons) const
1046
1002
  }
1047
1003
 
1048
1004
  void MultiIndexQuantizer::add(idx_t /*n*/, const float* /*x*/) {
1049
- FAISS_THROW_MSG(
1050
- "This index has virtual elements, "
1051
- "it does not support add");
1005
+ FAISS_THROW_MSG(
1006
+ "This index has virtual elements, "
1007
+ "it does not support add");
1052
1008
  }
1053
1009
 
1054
- void MultiIndexQuantizer::reset ()
1055
- {
1056
- FAISS_THROW_MSG ( "This index has virtual elements, "
1057
- "it does not support reset");
1010
+ void MultiIndexQuantizer::reset() {
1011
+ FAISS_THROW_MSG(
1012
+ "This index has virtual elements, "
1013
+ "it does not support reset");
1058
1014
  }
1059
1015
 
1060
-
1061
-
1062
-
1063
-
1064
-
1065
-
1066
-
1067
-
1068
-
1069
1016
  /*****************************************
1070
1017
  * MultiIndexQuantizer2
1071
1018
  ******************************************/
1072
1019
 
1073
-
1074
-
1075
- MultiIndexQuantizer2::MultiIndexQuantizer2 (
1076
- int d, size_t M, size_t nbits,
1077
- Index **indexes):
1078
- MultiIndexQuantizer (d, M, nbits)
1079
- {
1080
- assign_indexes.resize (M);
1020
+ MultiIndexQuantizer2::MultiIndexQuantizer2(
1021
+ int d,
1022
+ size_t M,
1023
+ size_t nbits,
1024
+ Index** indexes)
1025
+ : MultiIndexQuantizer(d, M, nbits) {
1026
+ assign_indexes.resize(M);
1081
1027
  for (int i = 0; i < M; i++) {
1082
1028
  FAISS_THROW_IF_NOT_MSG(
1083
- indexes[i]->d == pq.dsub,
1084
- "Provided sub-index has incorrect size");
1029
+ indexes[i]->d == pq.dsub,
1030
+ "Provided sub-index has incorrect size");
1085
1031
  assign_indexes[i] = indexes[i];
1086
1032
  }
1087
1033
  own_fields = false;
1088
1034
  }
1089
1035
 
1090
- MultiIndexQuantizer2::MultiIndexQuantizer2 (
1091
- int d, size_t nbits,
1092
- Index *assign_index_0,
1093
- Index *assign_index_1):
1094
- MultiIndexQuantizer (d, 2, nbits)
1095
- {
1036
+ MultiIndexQuantizer2::MultiIndexQuantizer2(
1037
+ int d,
1038
+ size_t nbits,
1039
+ Index* assign_index_0,
1040
+ Index* assign_index_1)
1041
+ : MultiIndexQuantizer(d, 2, nbits) {
1096
1042
  FAISS_THROW_IF_NOT_MSG(
1097
- assign_index_0->d == pq.dsub &&
1098
- assign_index_1->d == pq.dsub,
1043
+ assign_index_0->d == pq.dsub && assign_index_1->d == pq.dsub,
1099
1044
  "Provided sub-index has incorrect size");
1100
- assign_indexes.resize (2);
1101
- assign_indexes [0] = assign_index_0;
1102
- assign_indexes [1] = assign_index_1;
1045
+ assign_indexes.resize(2);
1046
+ assign_indexes[0] = assign_index_0;
1047
+ assign_indexes[1] = assign_index_1;
1103
1048
  own_fields = false;
1104
1049
  }
1105
1050
 
1106
- void MultiIndexQuantizer2::train(idx_t n, const float* x)
1107
- {
1051
+ void MultiIndexQuantizer2::train(idx_t n, const float* x) {
1108
1052
  MultiIndexQuantizer::train(n, x);
1109
1053
  // add centroids to sub-indexes
1110
1054
  for (int i = 0; i < pq.M; i++) {
@@ -1112,15 +1056,17 @@ void MultiIndexQuantizer2::train(idx_t n, const float* x)
1112
1056
  }
1113
1057
  }
1114
1058
 
1115
-
1116
1059
  void MultiIndexQuantizer2::search(
1117
- idx_t n, const float* x, idx_t K,
1118
- float* distances, idx_t* labels) const
1119
- {
1120
-
1121
- if (n == 0) return;
1060
+ idx_t n,
1061
+ const float* x,
1062
+ idx_t K,
1063
+ float* distances,
1064
+ idx_t* labels) const {
1065
+ if (n == 0)
1066
+ return;
1122
1067
 
1123
1068
  int k2 = std::min(K, int64_t(pq.ksub));
1069
+ FAISS_THROW_IF_NOT(k2);
1124
1070
 
1125
1071
  int64_t M = pq.M;
1126
1072
  int64_t dsub = pq.dsub, ksub = pq.ksub;
@@ -1131,8 +1077,8 @@ void MultiIndexQuantizer2::search(
1131
1077
  std::vector<float> xsub(n * dsub);
1132
1078
 
1133
1079
  for (int m = 0; m < M; m++) {
1134
- float *xdest = xsub.data();
1135
- const float *xsrc = x + m * dsub;
1080
+ float* xdest = xsub.data();
1081
+ const float* xsrc = x + m * dsub;
1136
1082
  for (int j = 0; j < n; j++) {
1137
1083
  memcpy(xdest, xsrc, dsub * sizeof(xdest[0]));
1138
1084
  xsrc += d;
@@ -1140,14 +1086,12 @@ void MultiIndexQuantizer2::search(
1140
1086
  }
1141
1087
 
1142
1088
  assign_indexes[m]->search(
1143
- n, xsub.data(), k2,
1144
- &sub_dis[k2 * n * m],
1145
- &sub_ids[k2 * n * m]);
1089
+ n, xsub.data(), k2, &sub_dis[k2 * n * m], &sub_ids[k2 * n * m]);
1146
1090
  }
1147
1091
 
1148
1092
  if (K == 1) {
1149
1093
  // simple version that just finds the min in each table
1150
- assert (k2 == 1);
1094
+ assert(k2 == 1);
1151
1095
 
1152
1096
  for (int i = 0; i < n; i++) {
1153
1097
  float dis = 0;
@@ -1159,30 +1103,28 @@ void MultiIndexQuantizer2::search(
1159
1103
  dis += vmin;
1160
1104
  label |= lmin << (m * pq.nbits);
1161
1105
  }
1162
- distances [i] = dis;
1163
- labels [i] = label;
1106
+ distances[i] = dis;
1107
+ labels[i] = label;
1164
1108
  }
1165
1109
 
1166
1110
  } else {
1167
-
1168
- #pragma omp parallel if(n > 1)
1111
+ #pragma omp parallel if (n > 1)
1169
1112
  {
1170
- MinSumK <float, PreSortedArray<float>, false>
1171
- msk(K, pq.M, pq.nbits, k2);
1113
+ MinSumK<float, PreSortedArray<float>, false> msk(
1114
+ K, pq.M, pq.nbits, k2);
1172
1115
  #pragma omp for
1173
1116
  for (int i = 0; i < n; i++) {
1174
- idx_t *li = labels + i * K;
1175
- msk.run (&sub_dis[i * k2], k2 * n,
1176
- distances + i * K, li);
1117
+ idx_t* li = labels + i * K;
1118
+ msk.run(&sub_dis[i * k2], k2 * n, distances + i * K, li);
1177
1119
 
1178
1120
  // remap ids
1179
1121
 
1180
- const idx_t *idmap0 = sub_ids.data() + i * k2;
1122
+ const idx_t* idmap0 = sub_ids.data() + i * k2;
1181
1123
  int64_t ld_idmap = k2 * n;
1182
1124
  int64_t mask1 = ksub - 1L;
1183
1125
 
1184
1126
  for (int k = 0; k < K; k++) {
1185
- const idx_t *idmap = idmap0;
1127
+ const idx_t* idmap = idmap0;
1186
1128
  int64_t vin = li[k];
1187
1129
  int64_t vout = 0;
1188
1130
  int bs = 0;
@@ -1200,5 +1142,4 @@
  }
  }

-
  } // namespace faiss
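For readers of the MinSumK section above (finding the k smallest sums of M terms, one term per row of an M x n table, with each result encoded as M indices packed nbit bits apart), a brute-force reference computation can help validate the heap-based version on tiny inputs. This is an illustrative sketch only; the function below is not part of FAISS or this gem:

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Brute-force counterpart of MinSumK: enumerate all n^M combinations and
// keep the k smallest sums. Term indices are packed nbit bits apart,
// matching the encoding described in the MinSumK comment. Assumes k <= n^M
// and that n^M is small enough to enumerate.
void min_sum_k_bruteforce(
        const std::vector<float>& x, // size M * n, row m holds the n values for term m
        int M,
        int n,
        int nbit,
        int k,
        std::vector<float>& sums,     // output, size k, sorted ascending
        std::vector<int64_t>& terms) { // output, size k, packed indices
    std::vector<std::pair<float, int64_t>> all;
    int64_t total = 1;
    for (int m = 0; m < M; m++) {
        total *= n;
    }
    for (int64_t code = 0; code < total; code++) {
        int64_t c = code;
        int64_t packed = 0;
        float s = 0;
        for (int m = 0; m < M; m++) {
            int64_t idx = c % n;
            c /= n;
            s += x[m * n + idx];
            packed |= idx << (m * nbit);
        }
        all.emplace_back(s, packed);
    }
    std::partial_sort(all.begin(), all.begin() + k, all.end());
    sums.resize(k);
    terms.resize(k);
    for (int i = 0; i < k; i++) {
        sums[i] = all[i].first;
        terms[i] = all[i].second;
    }
}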