faiss 0.1.4 → 0.2.1

Files changed (219)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +26 -1
  3. data/README.md +15 -3
  4. data/ext/faiss/ext.cpp +12 -308
  5. data/ext/faiss/extconf.rb +5 -2
  6. data/ext/faiss/index.cpp +189 -0
  7. data/ext/faiss/index_binary.cpp +75 -0
  8. data/ext/faiss/kmeans.cpp +40 -0
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +33 -0
  11. data/ext/faiss/product_quantizer.cpp +53 -0
  12. data/ext/faiss/utils.cpp +13 -0
  13. data/ext/faiss/utils.h +5 -0
  14. data/lib/faiss.rb +0 -5
  15. data/lib/faiss/version.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +31 -10
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -0,0 +1,85 @@
+ /**
+  * Copyright (c) Facebook, Inc. and its affiliates.
+  *
+  * This source code is licensed under the MIT license found in the
+  * LICENSE file in the root directory of this source tree.
+  */
+ 
+ // -*- c++ -*-
+ 
+ #pragma once
+ 
+ #include <vector>
+ 
+ #include <faiss/IndexFlat.h>
+ #include <faiss/IndexNNDescent.h>
+ #include <faiss/impl/NSG.h>
+ #include <faiss/utils/utils.h>
+ 
+ namespace faiss {
+ 
+ /** The NSG index is a normal random-access index with a NSG
+  * link structure built on top */
+ 
+ struct IndexNSG : Index {
+     /// the link strcuture
+     NSG nsg;
+ 
+     /// the sequential storage
+     bool own_fields;
+     Index* storage;
+ 
+     /// the index is built or not
+     bool is_built;
+ 
+     /// K of KNN graph for building
+     int GK;
+ 
+     /// indicate how to build a knn graph
+     /// - 0: build NSG with brute force search
+     /// - 1: build NSG with NNDescent
+     char build_type;
+ 
+     /// parameters for nndescent
+     int nndescent_S;
+     int nndescent_R;
+     int nndescent_L;
+     int nndescent_iter;
+ 
+     explicit IndexNSG(int d = 0, int R = 32, MetricType metric = METRIC_L2);
+     explicit IndexNSG(Index* storage, int R = 32);
+ 
+     ~IndexNSG() override;
+ 
+     void build(idx_t n, const float* x, idx_t* knn_graph, int GK);
+ 
+     void add(idx_t n, const float* x) override;
+ 
+     /// Trains the storage if needed
+     void train(idx_t n, const float* x) override;
+ 
+     /// entry point for search
+     void search(
+             idx_t n,
+             const float* x,
+             idx_t k,
+             float* distances,
+             idx_t* labels) const override;
+ 
+     void reconstruct(idx_t key, float* recons) const override;
+ 
+     void reset() override;
+ 
+     void check_knn_graph(const idx_t* knn_graph, idx_t n, int K) const;
+ };
+ 
+ /** Flat index topped with with a NSG structure to access elements
+  * more efficiently.
+  */
+ 
+ struct IndexNSGFlat : IndexNSG {
+     IndexNSGFlat();
+     IndexNSGFlat(int d, int R, MetricType metric = METRIC_L2);
+ };
+ 
+ } // namespace faiss
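The new header above only declares the interface. A minimal C++ usage sketch, based solely on the declarations shown: the dimension, graph degree R, dataset size, and the assumption that the NSG graph is built when the vectors are added are all illustrative choices, not part of the gem.

#include <faiss/IndexNSG.h>

#include <random>
#include <vector>

int main() {
    int d = 64; // vector dimension (arbitrary)
    int R = 32; // NSG graph degree (arbitrary)
    faiss::IndexNSGFlat index(d, R, faiss::METRIC_L2);
    index.build_type = 1; // per the header: 1 = build the kNN graph with NNDescent

    // random database vectors
    std::mt19937 rng(123);
    std::uniform_real_distribution<float> dis(0, 1);
    std::vector<float> xb(10000 * d);
    for (auto& v : xb) v = dis(rng);

    index.train(10000, xb.data()); // "Trains the storage if needed"
    index.add(10000, xb.data());   // assumed to build the NSG link structure

    // search the 5 nearest neighbors of the first 3 vectors
    int k = 5, nq = 3;
    std::vector<float> distances(nq * k);
    std::vector<faiss::Index::idx_t> labels(nq * k);
    index.search(nq, xb.data(), k, distances.data(), labels.data());
    return 0;
}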
@@ -10,15 +10,15 @@
  #include <faiss/IndexPQ.h>
 
  #include <cinttypes>
+ #include <cmath>
  #include <cstddef>
- #include <cstring>
  #include <cstdio>
- #include <cmath>
+ #include <cstring>
 
  #include <algorithm>
 
- #include <faiss/impl/FaissAssert.h>
  #include <faiss/impl/AuxIndexStructures.h>
+ #include <faiss/impl/FaissAssert.h>
  #include <faiss/utils/hamming.h>
 
  namespace faiss {
@@ -27,10 +27,8 @@ namespace faiss {
  * IndexPQ implementation
  ********************************************************/
 
- 
- IndexPQ::IndexPQ (int d, size_t M, size_t nbits, MetricType metric):
-     Index(d, metric), pq(d, M, nbits)
- {
+ IndexPQ::IndexPQ(int d, size_t M, size_t nbits, MetricType metric)
+         : Index(d, metric), pq(d, M, nbits) {
      is_trained = false;
      do_polysemous_training = false;
      polysemous_ht = nbits * M + 1;
@@ -38,8 +36,7 @@ IndexPQ::IndexPQ (int d, size_t M, size_t nbits, MetricType metric):
      encode_signs = false;
  }
 
- IndexPQ::IndexPQ ()
- {
+ IndexPQ::IndexPQ() {
      metric_type = METRIC_L2;
      is_trained = false;
      do_polysemous_training = false;
@@ -48,10 +45,8 @@ IndexPQ::IndexPQ ()
      encode_signs = false;
  }
 
- 
- void IndexPQ::train (idx_t n, const float *x)
- {
-     if (!do_polysemous_training) { // standard training
+ void IndexPQ::train(idx_t n, const float* x) {
+     if (!do_polysemous_training) { // standard training
          pq.train(n, x);
      } else {
          idx_t ntrain_perm = polysemous_training.ntrain_permutation;
@@ -59,38 +54,38 @@ void IndexPQ::train (idx_t n, const float *x)
          if (ntrain_perm > n / 4)
              ntrain_perm = n / 4;
          if (verbose) {
-             printf ("PQ training on %" PRId64 " points, remains %" PRId64 " points: "
-                     "training polysemous on %s\n",
-                     n - ntrain_perm, ntrain_perm,
-                     ntrain_perm == 0 ? "centroids" : "these");
+             printf("PQ training on %" PRId64 " points, remains %" PRId64
+                    " points: "
+                    "training polysemous on %s\n",
+                    n - ntrain_perm,
+                    ntrain_perm,
+                    ntrain_perm == 0 ? "centroids" : "these");
          }
          pq.train(n - ntrain_perm, x);
 
-         polysemous_training.optimize_pq_for_hamming (
-             pq, ntrain_perm, x + (n - ntrain_perm) * d);
+         polysemous_training.optimize_pq_for_hamming(
+                 pq, ntrain_perm, x + (n - ntrain_perm) * d);
      }
      is_trained = true;
  }
 
- 
- void IndexPQ::add (idx_t n, const float *x)
- {
-     FAISS_THROW_IF_NOT (is_trained);
-     codes.resize ((n + ntotal) * pq.code_size);
-     pq.compute_codes (x, &codes[ntotal * pq.code_size], n);
+ void IndexPQ::add(idx_t n, const float* x) {
+     FAISS_THROW_IF_NOT(is_trained);
+     codes.resize((n + ntotal) * pq.code_size);
+     pq.compute_codes(x, &codes[ntotal * pq.code_size], n);
      ntotal += n;
  }
 
- 
- size_t IndexPQ::remove_ids (const IDSelector & sel)
- {
+ size_t IndexPQ::remove_ids(const IDSelector& sel) {
      idx_t j = 0;
      for (idx_t i = 0; i < ntotal; i++) {
-         if (sel.is_member (i)) {
+         if (sel.is_member(i)) {
              // should be removed
          } else {
              if (i > j) {
-                 memmove (&codes[pq.code_size * j], &codes[pq.code_size * i], pq.code_size);
+                 memmove(&codes[pq.code_size * j],
+                         &codes[pq.code_size * i],
+                         pq.code_size);
              }
              j++;
          }
@@ -98,53 +93,46 @@ size_t IndexPQ::remove_ids (const IDSelector & sel)
      size_t nremove = ntotal - j;
      if (nremove > 0) {
          ntotal = j;
-         codes.resize (ntotal * pq.code_size);
+         codes.resize(ntotal * pq.code_size);
      }
      return nremove;
  }
 
- 
- void IndexPQ::reset()
- {
+ void IndexPQ::reset() {
      codes.clear();
      ntotal = 0;
  }
 
- void IndexPQ::reconstruct_n (idx_t i0, idx_t ni, float *recons) const
- {
-     FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
+ void IndexPQ::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
+     FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
      for (idx_t i = 0; i < ni; i++) {
-         const uint8_t * code = &codes[(i0 + i) * pq.code_size];
-         pq.decode (code, recons + i * d);
+         const uint8_t* code = &codes[(i0 + i) * pq.code_size];
+         pq.decode(code, recons + i * d);
      }
  }
 
- 
- void IndexPQ::reconstruct (idx_t key, float * recons) const
- {
-     FAISS_THROW_IF_NOT (key >= 0 && key < ntotal);
-     pq.decode (&codes[key * pq.code_size], recons);
+ void IndexPQ::reconstruct(idx_t key, float* recons) const {
+     FAISS_THROW_IF_NOT(key >= 0 && key < ntotal);
+     pq.decode(&codes[key * pq.code_size], recons);
  }
 
- 
  namespace {
 
- template<class PQDecoder>
- struct PQDistanceComputer: DistanceComputer {
+ template <class PQDecoder>
+ struct PQDistanceComputer : DistanceComputer {
      size_t d;
      MetricType metric;
      Index::idx_t nb;
-     const uint8_t *codes;
+     const uint8_t* codes;
      size_t code_size;
-     const ProductQuantizer & pq;
-     const float *sdc;
+     const ProductQuantizer& pq;
+     const float* sdc;
      std::vector<float> precomputed_table;
      size_t ndis;
 
-     float operator () (idx_t i) override
-     {
-         const uint8_t *code = codes + i * code_size;
-         const float *dt = precomputed_table.data();
+     float operator()(idx_t i) override {
+         const uint8_t* code = codes + i * code_size;
+         const float* dt = precomputed_table.data();
          PQDecoder decoder(code, pq.nbits);
         float accu = 0;
         for (int j = 0; j < pq.M; j++) {
@@ -155,13 +143,12 @@ struct PQDistanceComputer: DistanceComputer {
          return accu;
      }
 
-     float symmetric_dis(idx_t i, idx_t j) override
-     {
+     float symmetric_dis(idx_t i, idx_t j) override {
          FAISS_THROW_IF_NOT(sdc);
-         const float * sdci = sdc;
+         const float* sdci = sdc;
          float accu = 0;
-         PQDecoder codei (codes + i * code_size, pq.nbits);
-         PQDecoder codej (codes + j * code_size, pq.nbits);
+         PQDecoder codei(codes + i * code_size, pq.nbits);
+         PQDecoder codej(codes + j * code_size, pq.nbits);
 
          for (int l = 0; l < pq.M; l++) {
              accu += sdci[codei.decode() + (codej.decode() << codei.nbits)];
@@ -171,8 +158,7 @@ struct PQDistanceComputer: DistanceComputer {
          return accu;
      }
 
-     explicit PQDistanceComputer(const IndexPQ& storage)
-         : pq(storage.pq) {
+     explicit PQDistanceComputer(const IndexPQ& storage) : pq(storage.pq) {
          precomputed_table.resize(pq.M * pq.ksub);
          nb = storage.ntotal;
          d = storage.d;
@@ -187,21 +173,18 @@ struct PQDistanceComputer: DistanceComputer {
          ndis = 0;
      }
 
-     void set_query(const float *x) override {
+     void set_query(const float* x) override {
          if (metric == METRIC_L2) {
              pq.compute_distance_table(x, precomputed_table.data());
          } else {
              pq.compute_inner_prod_table(x, precomputed_table.data());
          }
- 
      }
  };
 
+ } // namespace
 
- } // namespace
- 
- 
- DistanceComputer * IndexPQ::get_distance_computer() const {
+ DistanceComputer* IndexPQ::get_distance_computer() const {
      if (pq.nbits == 8) {
          return new PQDistanceComputer<PQDecoder8>(*this);
      } else if (pq.nbits == 16) {
@@ -211,142 +194,142 @@ DistanceComputer * IndexPQ::get_distance_computer() const {
      }
  }
 
- 
  /*****************************************
  * IndexPQ polysemous search routines
  ******************************************/
 
+ void IndexPQ::search(
+         idx_t n,
+         const float* x,
+         idx_t k,
+         float* distances,
+         idx_t* labels) const {
+     FAISS_THROW_IF_NOT(k > 0);
 
- 
- 
- 
- void IndexPQ::search (idx_t n, const float *x, idx_t k,
-                       float *distances, idx_t *labels) const
- {
-     FAISS_THROW_IF_NOT (is_trained);
-     if (search_type == ST_PQ) { // Simple PQ search
+     FAISS_THROW_IF_NOT(is_trained);
+     if (search_type == ST_PQ) { // Simple PQ search
 
          if (metric_type == METRIC_L2) {
              float_maxheap_array_t res = {
-                 size_t(n), size_t(k), labels, distances };
-             pq.search (x, n, codes.data(), ntotal, &res, true);
+                     size_t(n), size_t(k), labels, distances};
+             pq.search(x, n, codes.data(), ntotal, &res, true);
          } else {
              float_minheap_array_t res = {
-                 size_t(n), size_t(k), labels, distances };
-             pq.search_ip (x, n, codes.data(), ntotal, &res, true);
+                     size_t(n), size_t(k), labels, distances};
+             pq.search_ip(x, n, codes.data(), ntotal, &res, true);
          }
          indexPQ_stats.nq += n;
          indexPQ_stats.ncode += n * ntotal;
 
-     } else if (search_type == ST_polysemous ||
-                search_type == ST_polysemous_generalize) {
- 
-         FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
+     } else if (
+             search_type == ST_polysemous ||
+             search_type == ST_polysemous_generalize) {
+         FAISS_THROW_IF_NOT(metric_type == METRIC_L2);
 
-         search_core_polysemous (n, x, k, distances, labels);
+         search_core_polysemous(n, x, k, distances, labels);
 
      } else { // code-to-code distances
 
-         uint8_t * q_codes = new uint8_t [n * pq.code_size];
-         ScopeDeleter<uint8_t> del (q_codes);
- 
+         uint8_t* q_codes = new uint8_t[n * pq.code_size];
+         ScopeDeleter<uint8_t> del(q_codes);
 
          if (!encode_signs) {
-             pq.compute_codes (x, q_codes, n);
+             pq.compute_codes(x, q_codes, n);
          } else {
-             FAISS_THROW_IF_NOT (d == pq.nbits * pq.M);
-             memset (q_codes, 0, n * pq.code_size);
+             FAISS_THROW_IF_NOT(d == pq.nbits * pq.M);
+             memset(q_codes, 0, n * pq.code_size);
              for (size_t i = 0; i < n; i++) {
-                 const float *xi = x + i * d;
-                 uint8_t *code = q_codes + i * pq.code_size;
+                 const float* xi = x + i * d;
+                 uint8_t* code = q_codes + i * pq.code_size;
                  for (int j = 0; j < d; j++)
-                     if (xi[j] > 0) code [j>>3] |= 1 << (j & 7);
+                     if (xi[j] > 0)
+                         code[j >> 3] |= 1 << (j & 7);
              }
          }
 
-         if (search_type == ST_SDC) {
- 
+         if (search_type == ST_SDC) {
              float_maxheap_array_t res = {
-                 size_t(n), size_t(k), labels, distances};
+                     size_t(n), size_t(k), labels, distances};
 
-             pq.search_sdc (q_codes, n, codes.data(), ntotal, &res, true);
+             pq.search_sdc(q_codes, n, codes.data(), ntotal, &res, true);
 
          } else {
-             int * idistances = new int [n * k];
-             ScopeDeleter<int> del (idistances);
+             int* idistances = new int[n * k];
+             ScopeDeleter<int> del(idistances);
 
              int_maxheap_array_t res = {
-                 size_t (n), size_t (k), labels, idistances};
+                     size_t(n), size_t(k), labels, idistances};
 
              if (search_type == ST_HE) {
- 
-                 hammings_knn_hc (&res, q_codes, codes.data(),
-                                  ntotal, pq.code_size, true);
+                 hammings_knn_hc(
+                         &res,
+                         q_codes,
+                         codes.data(),
+                         ntotal,
+                         pq.code_size,
+                         true);
 
              } else if (search_type == ST_generalized_HE) {
- 
-                 generalized_hammings_knn_hc (&res, q_codes, codes.data(),
-                                              ntotal, pq.code_size, true);
+                 generalized_hammings_knn_hc(
+                         &res,
+                         q_codes,
+                         codes.data(),
+                         ntotal,
+                         pq.code_size,
+                         true);
              }
 
              // convert distances to floats
             for (int i = 0; i < k * n; i++)
                 distances[i] = idistances[i];
- 
         }
 
- 
         indexPQ_stats.nq += n;
         indexPQ_stats.ncode += n * ntotal;
     }
 }
 
- 
- 
- 
- 
- void IndexPQStats::reset()
- {
+ void IndexPQStats::reset() {
      nq = ncode = n_hamming_pass = 0;
  }
 
  IndexPQStats indexPQ_stats;
 
- 
  template <class HammingComputer>
- static size_t polysemous_inner_loop (
-     const IndexPQ & index,
-     const float *dis_table_qi, const uint8_t *q_code,
-     size_t k, float *heap_dis, int64_t *heap_ids)
- {
- 
+ static size_t polysemous_inner_loop(
+         const IndexPQ& index,
+         const float* dis_table_qi,
+         const uint8_t* q_code,
+         size_t k,
+         float* heap_dis,
+         int64_t* heap_ids) {
      int M = index.pq.M;
      int code_size = index.pq.code_size;
      int ksub = index.pq.ksub;
      size_t ntotal = index.ntotal;
      int ht = index.polysemous_ht;
 
-     const uint8_t *b_code = index.codes.data();
+     const uint8_t* b_code = index.codes.data();
 
      size_t n_pass_i = 0;
 
-     HammingComputer hc (q_code, code_size);
+     HammingComputer hc(q_code, code_size);
 
      for (int64_t bi = 0; bi < ntotal; bi++) {
-         int hd = hc.hamming (b_code);
+         int hd = hc.hamming(b_code);
 
          if (hd < ht) {
-             n_pass_i ++;
+             n_pass_i++;
 
              float dis = 0;
-             const float * dis_table = dis_table_qi;
+             const float* dis_table = dis_table_qi;
              for (int m = 0; m < M; m++) {
-                 dis += dis_table [b_code[m]];
+                 dis += dis_table[b_code[m]];
                  dis_table += ksub;
             }
 
             if (dis < heap_dis[0]) {
-                 maxheap_replace_top (k, heap_dis, heap_ids, dis, bi);
+                 maxheap_replace_top(k, heap_dis, heap_ids, dis, bi);
             }
         }
         b_code += code_size;
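The polysemous_inner_loop above is a filter-then-refine scan: the Hamming distance between the query's PQ code and each database code is compared against the polysemous_ht threshold, and only the codes that pass pay for the full table-based PQ distance and a heap update. A standalone sketch of the same pattern, using illustrative helper names rather than faiss internals:

#include <algorithm>
#include <bitset>
#include <cstdint>
#include <vector>

// asymmetric PQ distance from a per-query distance table (M x ksub entries)
float pq_distance(const std::vector<float>& dis_table,
                  const std::vector<uint8_t>& code,
                  int ksub) {
    float d = 0;
    for (size_t m = 0; m < code.size(); m++)
        d += dis_table[m * ksub + code[m]];
    return d;
}

int hamming8(const std::vector<uint8_t>& a, const std::vector<uint8_t>& b) {
    int h = 0;
    for (size_t i = 0; i < a.size(); i++)
        h += std::bitset<8>(a[i] ^ b[i]).count();
    return h;
}

// returns how many database codes passed the Hamming filter
size_t filter_then_refine(const std::vector<std::vector<uint8_t>>& db_codes,
                          const std::vector<uint8_t>& q_code,
                          const std::vector<float>& dis_table,
                          int ksub, int ht, float& best) {
    size_t n_pass = 0;
    for (const auto& c : db_codes) {
        if (hamming8(c, q_code) >= ht)
            continue; // filtered out: no PQ distance computed
        n_pass++;
        best = std::min(best, pq_distance(dis_table, c, ksub));
    }
    return n_pass;
}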
@@ -354,201 +337,204 @@ static size_t polysemous_inner_loop (
      return n_pass_i;
  }
 
+ void IndexPQ::search_core_polysemous(
+         idx_t n,
+         const float* x,
+         idx_t k,
+         float* distances,
+         idx_t* labels) const {
+     FAISS_THROW_IF_NOT(k > 0);
 
- void IndexPQ::search_core_polysemous (idx_t n, const float *x, idx_t k,
-                                       float *distances, idx_t *labels) const
- {
-     FAISS_THROW_IF_NOT (pq.nbits == 8);
+     FAISS_THROW_IF_NOT(pq.nbits == 8);
 
      // PQ distance tables
-     float * dis_tables = new float [n * pq.ksub * pq.M];
-     ScopeDeleter<float> del (dis_tables);
-     pq.compute_distance_tables (n, x, dis_tables);
+     float* dis_tables = new float[n * pq.ksub * pq.M];
+     ScopeDeleter<float> del(dis_tables);
+     pq.compute_distance_tables(n, x, dis_tables);
 
      // Hamming embedding queries
-     uint8_t * q_codes = new uint8_t [n * pq.code_size];
-     ScopeDeleter<uint8_t> del2 (q_codes);
+     uint8_t* q_codes = new uint8_t[n * pq.code_size];
+     ScopeDeleter<uint8_t> del2(q_codes);
 
      if (false) {
-         pq.compute_codes (x, q_codes, n);
+         pq.compute_codes(x, q_codes, n);
      } else {
  #pragma omp parallel for
          for (idx_t qi = 0; qi < n; qi++) {
-             pq.compute_code_from_distance_table
-                 (dis_tables + qi * pq.M * pq.ksub,
-                  q_codes + qi * pq.code_size);
+             pq.compute_code_from_distance_table(
+                     dis_tables + qi * pq.M * pq.ksub,
+                     q_codes + qi * pq.code_size);
          }
      }
 
      size_t n_pass = 0;
 
- #pragma omp parallel for reduction (+: n_pass)
+ #pragma omp parallel for reduction(+ : n_pass)
      for (idx_t qi = 0; qi < n; qi++) {
-         const uint8_t * q_code = q_codes + qi * pq.code_size;
+         const uint8_t* q_code = q_codes + qi * pq.code_size;
 
-         const float * dis_table_qi = dis_tables + qi * pq.M * pq.ksub;
+         const float* dis_table_qi = dis_tables + qi * pq.M * pq.ksub;
 
-         int64_t * heap_ids = labels + qi * k;
-         float *heap_dis = distances + qi * k;
-         maxheap_heapify (k, heap_dis, heap_ids);
+         int64_t* heap_ids = labels + qi * k;
+         float* heap_dis = distances + qi * k;
+         maxheap_heapify(k, heap_dis, heap_ids);
 
          if (search_type == ST_polysemous) {
- 
              switch (pq.code_size) {
-             case 4:
-                 n_pass += polysemous_inner_loop<HammingComputer4>
-                     (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
-                 break;
-             case 8:
-                 n_pass += polysemous_inner_loop<HammingComputer8>
-                     (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
-                 break;
-             case 16:
-                 n_pass += polysemous_inner_loop<HammingComputer16>
-                     (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
-                 break;
-             case 32:
-                 n_pass += polysemous_inner_loop<HammingComputer32>
-                     (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
-                 break;
-             case 20:
-                 n_pass += polysemous_inner_loop<HammingComputer20>
-                     (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
-                 break;
-             default:
-                 if (pq.code_size % 8 == 0) {
-                     n_pass += polysemous_inner_loop<HammingComputerM8>
-                         (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
-                 } else if (pq.code_size % 4 == 0) {
-                     n_pass += polysemous_inner_loop<HammingComputerM4>
-                         (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
-                 } else {
-                     FAISS_THROW_FMT(
-                         "code size %zd not supported for polysemous",
-                         pq.code_size);
-                 }
-                 break;
+                 case 4:
+                     n_pass += polysemous_inner_loop<HammingComputer4>(
+                             *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
+                     break;
+                 case 8:
+                     n_pass += polysemous_inner_loop<HammingComputer8>(
+                             *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
+                     break;
+                 case 16:
+                     n_pass += polysemous_inner_loop<HammingComputer16>(
+                             *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
+                     break;
+                 case 32:
+                     n_pass += polysemous_inner_loop<HammingComputer32>(
+                             *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
+                     break;
+                 case 20:
+                     n_pass += polysemous_inner_loop<HammingComputer20>(
+                             *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
+                     break;
+                 default:
+                     if (pq.code_size % 4 == 0) {
+                         n_pass += polysemous_inner_loop<HammingComputerDefault>(
+                                 *this,
+                                 dis_table_qi,
+                                 q_code,
+                                 k,
+                                 heap_dis,
+                                 heap_ids);
+                     } else {
+                         FAISS_THROW_FMT(
+                                 "code size %zd not supported for polysemous",
+                                 pq.code_size);
+                     }
+                     break;
             }
         } else {
             switch (pq.code_size) {
-             case 8:
-                 n_pass += polysemous_inner_loop<GenHammingComputer8>
-                     (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
-                 break;
-             case 16:
-                 n_pass += polysemous_inner_loop<GenHammingComputer16>
-                     (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
-                 break;
-             case 32:
-                 n_pass += polysemous_inner_loop<GenHammingComputer32>
-                     (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
-                 break;
-             default:
-                 if (pq.code_size % 8 == 0) {
-                     n_pass += polysemous_inner_loop<GenHammingComputerM8>
-                         (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
-                 } else {
-                     FAISS_THROW_FMT(
-                         "code size %zd not supported for polysemous",
-                         pq.code_size);
-                 }
-                 break;
+                 case 8:
+                     n_pass += polysemous_inner_loop<GenHammingComputer8>(
+                             *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
+                     break;
+                 case 16:
+                     n_pass += polysemous_inner_loop<GenHammingComputer16>(
+                             *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
+                     break;
+                 case 32:
+                     n_pass += polysemous_inner_loop<GenHammingComputer32>(
+                             *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
+                     break;
+                 default:
+                     if (pq.code_size % 8 == 0) {
+                         n_pass += polysemous_inner_loop<GenHammingComputerM8>(
+                                 *this,
+                                 dis_table_qi,
+                                 q_code,
+                                 k,
+                                 heap_dis,
+                                 heap_ids);
+                     } else {
+                         FAISS_THROW_FMT(
+                                 "code size %zd not supported for polysemous",
+                                 pq.code_size);
+                     }
+                     break;
             }
         }
-         maxheap_reorder (k, heap_dis, heap_ids);
+         maxheap_reorder(k, heap_dis, heap_ids);
     }
 
     indexPQ_stats.nq += n;
     indexPQ_stats.ncode += n * ntotal;
     indexPQ_stats.n_hamming_pass += n_pass;
- 
- 
  }
 
- 
  /* The standalone codec interface (just remaps to the PQ functions) */
- size_t IndexPQ::sa_code_size () const
- {
+ size_t IndexPQ::sa_code_size() const {
      return pq.code_size;
  }
 
- void IndexPQ::sa_encode (idx_t n, const float *x, uint8_t *bytes) const
- {
-     pq.compute_codes (x, bytes, n);
+ void IndexPQ::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
+     pq.compute_codes(x, bytes, n);
  }
 
- void IndexPQ::sa_decode (idx_t n, const uint8_t *bytes, float *x) const
- {
-     pq.decode (bytes, x, n);
+ void IndexPQ::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
+     pq.decode(bytes, x, n);
  }
 
- 
- 
- 
  /*****************************************
  * Stats of IndexPQ codes
  ******************************************/
 
+ void IndexPQ::hamming_distance_table(idx_t n, const float* x, int32_t* dis)
+         const {
+     uint8_t* q_codes = new uint8_t[n * pq.code_size];
+     ScopeDeleter<uint8_t> del(q_codes);
 
+     pq.compute_codes(x, q_codes, n);
 
- 
- void IndexPQ::hamming_distance_table (idx_t n, const float *x,
-                                       int32_t *dis) const
- {
-     uint8_t * q_codes = new uint8_t [n * pq.code_size];
-     ScopeDeleter<uint8_t> del (q_codes);
- 
-     pq.compute_codes (x, q_codes, n);
- 
-     hammings (q_codes, codes.data(), n, ntotal, pq.code_size, dis);
+     hammings(q_codes, codes.data(), n, ntotal, pq.code_size, dis);
  }
 
- 
- void IndexPQ::hamming_distance_histogram (idx_t n, const float *x,
-                                           idx_t nb, const float *xb,
-                                           int64_t *hist)
- {
-     FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
-     FAISS_THROW_IF_NOT (pq.code_size % 8 == 0);
-     FAISS_THROW_IF_NOT (pq.nbits == 8);
+ void IndexPQ::hamming_distance_histogram(
+         idx_t n,
+         const float* x,
+         idx_t nb,
+         const float* xb,
+         int64_t* hist) {
+     FAISS_THROW_IF_NOT(metric_type == METRIC_L2);
+     FAISS_THROW_IF_NOT(pq.code_size % 8 == 0);
+     FAISS_THROW_IF_NOT(pq.nbits == 8);
 
      // Hamming embedding queries
-     uint8_t * q_codes = new uint8_t [n * pq.code_size];
-     ScopeDeleter <uint8_t> del (q_codes);
-     pq.compute_codes (x, q_codes, n);
+     uint8_t* q_codes = new uint8_t[n * pq.code_size];
+     ScopeDeleter<uint8_t> del(q_codes);
+     pq.compute_codes(x, q_codes, n);
 
-     uint8_t * b_codes ;
-     ScopeDeleter <uint8_t> del_b_codes;
+     uint8_t* b_codes;
+     ScopeDeleter<uint8_t> del_b_codes;
 
      if (xb) {
-         b_codes = new uint8_t [nb * pq.code_size];
-         del_b_codes.set (b_codes);
-         pq.compute_codes (xb, b_codes, nb);
+         b_codes = new uint8_t[nb * pq.code_size];
+         del_b_codes.set(b_codes);
+         pq.compute_codes(xb, b_codes, nb);
      } else {
          nb = ntotal;
          b_codes = codes.data();
      }
      int nbits = pq.M * pq.nbits;
-     memset (hist, 0, sizeof(*hist) * (nbits + 1));
+     memset(hist, 0, sizeof(*hist) * (nbits + 1));
      size_t bs = 256;
 
  #pragma omp parallel
      {
-         std::vector<int64_t> histi (nbits + 1);
-         hamdis_t *distances = new hamdis_t [nb * bs];
-         ScopeDeleter<hamdis_t> del (distances);
+         std::vector<int64_t> histi(nbits + 1);
+         hamdis_t* distances = new hamdis_t[nb * bs];
+         ScopeDeleter<hamdis_t> del(distances);
  #pragma omp for
          for (idx_t q0 = 0; q0 < n; q0 += bs) {
              // printf ("dis stats: %zd/%zd\n", q0, n);
              size_t q1 = q0 + bs;
-             if (q1 > n) q1 = n;
+             if (q1 > n)
+                 q1 = n;
 
-             hammings (q_codes + q0 * pq.code_size, b_codes,
-                       q1 - q0, nb,
-                       pq.code_size, distances);
+             hammings(
+                     q_codes + q0 * pq.code_size,
+                     b_codes,
+                     q1 - q0,
+                     nb,
+                     pq.code_size,
+                     distances);
 
              for (size_t i = 0; i < nb * (q1 - q0); i++)
-                 histi [distances [i]]++;
+                 histi[distances[i]]++;
         }
 #pragma omp critical
         {
@@ -556,28 +542,8 @@ void IndexPQ::hamming_distance_histogram (idx_t n, const float *x,
                 hist[i] += histi[i];
             }
         }
- 
     }
 
- 
- 
- 
- 
- 
- 
- 
- 
- 
- 
- 
- 
- 
- 
- 
- 
- 
- 
- 
  /*****************************************
  * MultiIndexQuantizer
  ******************************************/
@@ -586,90 +552,87 @@ namespace {
 
  template <typename T>
  struct PreSortedArray {
- 
-     const T * x;
+     const T* x;
      int N;
 
-     explicit PreSortedArray (int N): N(N) {
-     }
-     void init (const T*x) {
+     explicit PreSortedArray(int N) : N(N) {}
+     void init(const T* x) {
          this->x = x;
      }
      // get smallest value
-     T get_0 () {
+     T get_0() {
          return x[0];
      }
 
      // get delta between n-smallest and n-1 -smallest
-     T get_diff (int n) {
+     T get_diff(int n) {
          return x[n] - x[n - 1];
      }
 
      // remap orders counted from smallest to indices in array
-     int get_ord (int n) {
+     int get_ord(int n) {
          return n;
      }
- 
  };
 
  template <typename T>
  struct ArgSort {
-     const T * x;
-     bool operator() (size_t i, size_t j) {
+     const T* x;
+     bool operator()(size_t i, size_t j) {
          return x[i] < x[j];
      }
  };
 
- 
  /** Array that maintains a permutation of its elements so that the
  * array's elements are sorted
  */
  template <typename T>
  struct SortedArray {
-     const T * x;
+     const T* x;
      int N;
      std::vector<int> perm;
 
-     explicit SortedArray (int N) {
+     explicit SortedArray(int N) {
          this->N = N;
-         perm.resize (N);
+         perm.resize(N);
      }
 
-     void init (const T*x) {
+     void init(const T* x) {
          this->x = x;
          for (int n = 0; n < N; n++)
              perm[n] = n;
-         ArgSort<T> cmp = {x };
-         std::sort (perm.begin(), perm.end(), cmp);
+         ArgSort<T> cmp = {x};
+         std::sort(perm.begin(), perm.end(), cmp);
      }
 
      // get smallest value
-     T get_0 () {
+     T get_0() {
          return x[perm[0]];
      }
 
      // get delta between n-smallest and n-1 -smallest
-     T get_diff (int n) {
+     T get_diff(int n) {
          return x[perm[n]] - x[perm[n - 1]];
      }
 
      // remap orders counted from smallest to indices in array
-     int get_ord (int n) {
+     int get_ord(int n) {
          return perm[n];
      }
  };
 
- 
- 
  /** Array has n values. Sort the k first ones and copy the other ones
  * into elements k..n-1
  */
  template <class C>
- void partial_sort (int k, int n,
-                    const typename C::T * vals, typename C::TI * perm) {
+ void partial_sort(
+         int k,
+         int n,
+         const typename C::T* vals,
+         typename C::TI* perm) {
      // insert first k elts in heap
      for (int i = 1; i < k; i++) {
-         indirect_heap_push<C> (i + 1, vals, perm, perm[i]);
+         indirect_heap_push<C>(i + 1, vals, perm, perm[i]);
      }
 
      // insert next n - k elts in heap
@@ -678,8 +641,8 @@ void partial_sort (int k, int n,
          typename C::TI top = perm[0];
 
          if (C::cmp(vals[top], vals[id])) {
-             indirect_heap_pop<C> (k, vals, perm);
-             indirect_heap_push<C> (k, vals, perm, id);
+             indirect_heap_pop<C>(k, vals, perm);
+             indirect_heap_push<C>(k, vals, perm, id);
              perm[i] = top;
          } else {
              // nothing, elt at i is good where it is.
@@ -689,7 +652,7 @@ void partial_sort (int k, int n,
      // order the k first elements in heap
      for (int i = k - 1; i > 0; i--) {
          typename C::TI top = perm[0];
-         indirect_heap_pop<C> (i + 1, vals, perm);
+         indirect_heap_pop<C>(i + 1, vals, perm);
          perm[i] = top;
      }
  }
@@ -697,69 +660,67 @@ void partial_sort (int k, int n,
  /** same as SortedArray, but only the k first elements are sorted */
  template <typename T>
  struct SemiSortedArray {
-     const T * x;
+     const T* x;
      int N;
 
      // type of the heap: CMax = sort ascending
      typedef CMax<T, int> HC;
      std::vector<int> perm;
 
-     int k; // k elements are sorted
+     int k; // k elements are sorted
 
      int initial_k, k_factor;
 
-     explicit SemiSortedArray (int N) {
+     explicit SemiSortedArray(int N) {
          this->N = N;
-         perm.resize (N);
-         perm.resize (N);
+         perm.resize(N);
+         perm.resize(N);
          initial_k = 3;
          k_factor = 4;
      }
 
-     void init (const T*x) {
+     void init(const T* x) {
          this->x = x;
          for (int n = 0; n < N; n++)
              perm[n] = n;
          k = 0;
-         grow (initial_k);
+         grow(initial_k);
      }
 
      /// grow the sorted part of the array to size next_k
-     void grow (int next_k) {
+     void grow(int next_k) {
          if (next_k < N) {
-             partial_sort<HC> (next_k - k, N - k, x, &perm[k]);
+             partial_sort<HC>(next_k - k, N - k, x, &perm[k]);
              k = next_k;
          } else { // full sort of remainder of array
-             ArgSort<T> cmp = {x };
-             std::sort (perm.begin() + k, perm.end(), cmp);
+             ArgSort<T> cmp = {x};
+             std::sort(perm.begin() + k, perm.end(), cmp);
              k = N;
          }
      }
 
      // get smallest value
-     T get_0 () {
+     T get_0() {
          return x[perm[0]];
      }
 
      // get delta between n-smallest and n-1 -smallest
-     T get_diff (int n) {
+     T get_diff(int n) {
          if (n >= k) {
              // want to keep powers of 2 - 1
              int next_k = (k + 1) * k_factor - 1;
-             grow (next_k);
+             grow(next_k);
          }
          return x[perm[n]] - x[perm[n - 1]];
      }
 
      // remap orders counted from smallest to indices in array
-     int get_ord (int n) {
-         assert (n < k);
+     int get_ord(int n) {
+         assert(n < k);
          return perm[n];
      }
  };
 
- 
- 
  /*****************************************
  * Find the k smallest sums of M terms, where each term is taken in a
  * table x of n values.
@@ -779,19 +740,19 @@ struct SemiSortedArray {
  * occasionally several t's are returned.
  *
  * @param x      size M * n, values to add up
- * @parms k      nb of results to retrieve
+ * @param k      nb of results to retrieve
  * @param M      nb of terms
  * @param n      nb of distinct values
  * @param sums   output, size k, sorted
- * @prarm terms  output, size k, with encoding as above
+ * @param terms  output, size k, with encoding as above
  *
  ******************************************/
  template <typename T, class SSA, bool use_seen>
  struct MinSumK {
-     int K;    ///< nb of sums to return
-     int M;    ///< nb of elements to sum up
+     int K;    ///< nb of sums to return
+     int M;    ///< nb of elements to sum up
      int nbit; ///< nb of bits to encode one entry
-     int N;    ///< nb of possible elements for each of the M terms
+     int N;    ///< nb of possible elements for each of the M terms
 
      /** the heap.
      * We use a heap to maintain a queue of sums, with the associated
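For intuition, a brute-force reference of the problem stated in the comment above, restricted to M = 2: enumerate every pair of terms, keep the k smallest sums, and pack each result as i + (j << nbit), matching the "terms" encoding. This is illustrative only; MinSumK avoids the O(n^M) enumeration with the heap-based scheme shown in the following hunks.

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

std::vector<std::pair<float, int64_t>> min_sum_k_bruteforce(
        const std::vector<float>& x0, // first table, size n
        const std::vector<float>& x1, // second table, size n
        int k,                        // assumed <= x0.size() * x1.size()
        int nbit) {
    std::vector<std::pair<float, int64_t>> all;
    for (size_t i = 0; i < x0.size(); i++)
        for (size_t j = 0; j < x1.size(); j++)
            all.emplace_back(x0[i] + x1[j], int64_t(i) + (int64_t(j) << nbit));
    std::partial_sort(all.begin(), all.begin() + k, all.end());
    all.resize(k);
    return all; // sorted (sum, packed term) pairs
}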
@@ -799,21 +760,20 @@ struct MinSumK {
      */
      typedef CMin<T, int64_t> HC;
      size_t heap_capacity, heap_size;
-     T *bh_val;
-     int64_t *bh_ids;
+     T* bh_val;
+     int64_t* bh_ids;
 
-     std::vector <SSA> ssx;
+     std::vector<SSA> ssx;
 
      // all results get pushed several times. When there are ties, they
      // are popped interleaved with others, so it is not easy to
      // identify them. Therefore, this bit array just marks elements
      // that were seen before.
-     std::vector <uint8_t> seen;
+     std::vector<uint8_t> seen;
 
-     MinSumK (int K, int M, int nbit, int N):
-         K(K), M(M), nbit(nbit), N(N) {
+     MinSumK(int K, int M, int nbit, int N) : K(K), M(M), nbit(nbit), N(N) {
          heap_capacity = K * M;
-         assert (N <= (1 << nbit));
+         assert(N <= (1 << nbit));
 
          // we'll do k steps, each step pushes at most M vals
          bh_val = new T[heap_capacity];
@@ -821,29 +781,27 @@ struct MinSumK {
 
          if (use_seen) {
              int64_t n_ids = weight(M);
-             seen.resize ((n_ids + 7) / 8);
+             seen.resize((n_ids + 7) / 8);
          }
 
          for (int m = 0; m < M; m++)
-             ssx.push_back (SSA(N));
- 
+             ssx.push_back(SSA(N));
      }
 
-     int64_t weight (int i) {
+     int64_t weight(int i) {
          return 1 << (i * nbit);
      }
 
-     bool is_seen (int64_t i) {
+     bool is_seen(int64_t i) {
          return (seen[i >> 3] >> (i & 7)) & 1;
      }
 
-     void mark_seen (int64_t i) {
+     void mark_seen(int64_t i) {
          if (use_seen)
-             seen [i >> 3] |= 1 << (i & 7);
+             seen[i >> 3] |= 1 << (i & 7);
      }
 
-     void run (const T *x, int64_t ldx,
-               T * sums, int64_t * terms) {
+     void run(const T* x, int64_t ldx, T* sums, int64_t* terms) {
          heap_size = 0;
 
          for (int m = 0; m < M; m++) {
@@ -854,38 +812,41 @@ struct MinSumK {
          { // initial result: take min for all elements
              T sum = 0;
              terms[0] = 0;
-             mark_seen (0);
+             mark_seen(0);
              for (int m = 0; m < M; m++) {
                  sum += ssx[m].get_0();
              }
              sums[0] = sum;
              for (int m = 0; m < M; m++) {
-                 heap_push<HC> (++heap_size, bh_val, bh_ids,
-                                sum + ssx[m].get_diff(1),
-                                weight(m));
+                 heap_push<HC>(
+                         ++heap_size,
+                         bh_val,
+                         bh_ids,
+                         sum + ssx[m].get_diff(1),
+                         weight(m));
             }
         }
 
         for (int k = 1; k < K; k++) {
             // pop smallest value from heap
-             if (use_seen) {// skip already seen elements
-                 while (is_seen (bh_ids[0])) {
-                     assert (heap_size > 0);
-                     heap_pop<HC> (heap_size--, bh_val, bh_ids);
+             if (use_seen) { // skip already seen elements
+                 while (is_seen(bh_ids[0])) {
+                     assert(heap_size > 0);
+                     heap_pop<HC>(heap_size--, bh_val, bh_ids);
                 }
             }
-             assert (heap_size > 0);
+             assert(heap_size > 0);
 
             T sum = sums[k] = bh_val[0];
             int64_t ti = terms[k] = bh_ids[0];
 
             if (use_seen) {
-                 mark_seen (ti);
-                 heap_pop<HC> (heap_size--, bh_val, bh_ids);
+                 mark_seen(ti);
+                 heap_pop<HC>(heap_size--, bh_val, bh_ids);
             } else {
                 do {
-                     heap_pop<HC> (heap_size--, bh_val, bh_ids);
-                 } while (heap_size > 0 && bh_ids[0] == ti);
+                     heap_pop<HC>(heap_size--, bh_val, bh_ids);
+                 } while (heap_size > 0 && bh_ids[0] == ti);
             }
 
             // enqueue followers
@@ -893,9 +854,10 @@ struct MinSumK {
             for (int m = 0; m < M; m++) {
                 int64_t n = ii & ((1L << nbit) - 1);
                 ii >>= nbit;
-                 if (n + 1 >= N) continue;
+                 if (n + 1 >= N)
+                     continue;
 
-                 enqueue_follower (ti, m, n, sum);
+                 enqueue_follower(ti, m, n, sum);
             }
         }
 
@@ -922,37 +884,29 @@ struct MinSumK {
         }
     }
 
- 
-     void enqueue_follower (int64_t ti, int m, int n, T sum) {
+     void enqueue_follower(int64_t ti, int m, int n, T sum) {
          T next_sum = sum + ssx[m].get_diff(n + 1);
          int64_t next_ti = ti + weight(m);
-         heap_push<HC> (++heap_size, bh_val, bh_ids, next_sum, next_ti);
+         heap_push<HC>(++heap_size, bh_val, bh_ids, next_sum, next_ti);
      }
 
-     ~MinSumK () {
-         delete [] bh_ids;
-         delete [] bh_val;
+     ~MinSumK() {
+         delete[] bh_ids;
+         delete[] bh_val;
      }
  };
 
  } // anonymous namespace
 
- 
- MultiIndexQuantizer::MultiIndexQuantizer (int d,
-                                           size_t M,
-                                           size_t nbits):
-     Index(d, METRIC_L2), pq(d, M, nbits)
- {
+ MultiIndexQuantizer::MultiIndexQuantizer(int d, size_t M, size_t nbits)
+         : Index(d, METRIC_L2), pq(d, M, nbits) {
      is_trained = false;
      pq.verbose = verbose;
  }
 
- 
- 
- void MultiIndexQuantizer::train(idx_t n, const float *x)
- {
+ void MultiIndexQuantizer::train(idx_t n, const float* x) {
      pq.verbose = verbose;
-     pq.train (n, x);
+     pq.train(n, x);
      is_trained = true;
      // count virtual elements in index
      ntotal = 1;
@@ -960,10 +914,16 @@ void MultiIndexQuantizer::train(idx_t n, const float *x)
          ntotal *= pq.ksub;
      }
 
+ void MultiIndexQuantizer::search(
+         idx_t n,
+         const float* x,
+         idx_t k,
+         float* distances,
+         idx_t* labels) const {
+     if (n == 0)
+         return;
 
- void MultiIndexQuantizer::search (idx_t n, const float *x, idx_t k,
-                                   float *distances, idx_t *labels) const {
-     if (n == 0) return;
+     FAISS_THROW_IF_NOT(k > 0);
 
      // the allocation just below can be severe...
      idx_t bs = 32768;
@@ -971,27 +931,28 @@ void MultiIndexQuantizer::search (idx_t n, const float *x, idx_t k,
      for (idx_t i0 = 0; i0 < n; i0 += bs) {
          idx_t i1 = std::min(i0 + bs, n);
          if (verbose) {
-             printf("MultiIndexQuantizer::search: %" PRId64 ":%" PRId64 " / %" PRId64 "\n",
-                    i0, i1, n);
+             printf("MultiIndexQuantizer::search: %" PRId64 ":%" PRId64
+                    " / %" PRId64 "\n",
+                    i0,
+                    i1,
+                    n);
          }
-         search (i1 - i0, x + i0 * d, k,
-                 distances + i0 * k,
-                 labels + i0 * k);
+         search(i1 - i0, x + i0 * d, k, distances + i0 * k, labels + i0 * k);
      }
      return;
  }
 
-     float * dis_tables = new float [n * pq.ksub * pq.M];
-     ScopeDeleter<float> del (dis_tables);
+     float* dis_tables = new float[n * pq.ksub * pq.M];
+     ScopeDeleter<float> del(dis_tables);
 
-     pq.compute_distance_tables (n, x, dis_tables);
+     pq.compute_distance_tables(n, x, dis_tables);
 
      if (k == 1) {
          // simple version that just finds the min in each table
 
  #pragma omp parallel for
          for (int i = 0; i < n; i++) {
-             const float * dis_table = dis_tables + i * pq.ksub * pq.M;
+             const float* dis_table = dis_tables + i * pq.ksub * pq.M;
              float dis = 0;
              idx_t label = 0;
 
@@ -1010,32 +971,27 @@ void MultiIndexQuantizer::search (idx_t n, const float *x, idx_t k,
                  dis_table += pq.ksub;
              }
 
-             distances [i] = dis;
-             labels [i] = label;
+             distances[i] = dis;
+             labels[i] = label;
          }
 
- 
      } else {
- 
- #pragma omp parallel if(n > 1)
+ #pragma omp parallel if (n > 1)
          {
-             MinSumK <float, SemiSortedArray<float>, false>
-                 msk(k, pq.M, pq.nbits, pq.ksub);
+             MinSumK<float, SemiSortedArray<float>, false> msk(
+                     k, pq.M, pq.nbits, pq.ksub);
  #pragma omp for
              for (int i = 0; i < n; i++) {
-                 msk.run (dis_tables + i * pq.ksub * pq.M, pq.ksub,
-                          distances + i * k, labels + i * k);
- 
+                 msk.run(dis_tables + i * pq.ksub * pq.M,
+                         pq.ksub,
+                         distances + i * k,
+                         labels + i * k);
             }
         }
     }
- 
 }
 
- 
- void MultiIndexQuantizer::reconstruct (idx_t key, float * recons) const
- {
- 
+ void MultiIndexQuantizer::reconstruct(idx_t key, float* recons) const {
      int64_t jj = key;
      for (int m = 0; m < pq.M; m++) {
          int64_t n = jj & ((1L << pq.nbits) - 1);
@@ -1046,65 +1002,53 @@ void MultiIndexQuantizer::reconstruct (idx_t key, float * recons) const
      }
  }
 
  void MultiIndexQuantizer::add(idx_t /*n*/, const float* /*x*/) {
-     FAISS_THROW_MSG(
-         "This index has virtual elements, "
-         "it does not support add");
+     FAISS_THROW_MSG(
+             "This index has virtual elements, "
+             "it does not support add");
  }
 
- void MultiIndexQuantizer::reset ()
- {
-     FAISS_THROW_MSG ( "This index has virtual elements, "
-                       "it does not support reset");
+ void MultiIndexQuantizer::reset() {
+     FAISS_THROW_MSG(
+             "This index has virtual elements, "
+             "it does not support reset");
  }
 
- 
- 
- 
- 
- 
- 
- 
- 
- 
  /*****************************************
  * MultiIndexQuantizer2
  ******************************************/
 
- 
- 
- MultiIndexQuantizer2::MultiIndexQuantizer2 (
-     int d, size_t M, size_t nbits,
-     Index **indexes):
-     MultiIndexQuantizer (d, M, nbits)
- {
-     assign_indexes.resize (M);
+ MultiIndexQuantizer2::MultiIndexQuantizer2(
+         int d,
+         size_t M,
+         size_t nbits,
+         Index** indexes)
+         : MultiIndexQuantizer(d, M, nbits) {
+     assign_indexes.resize(M);
      for (int i = 0; i < M; i++) {
          FAISS_THROW_IF_NOT_MSG(
-             indexes[i]->d == pq.dsub,
-             "Provided sub-index has incorrect size");
+                 indexes[i]->d == pq.dsub,
+                 "Provided sub-index has incorrect size");
          assign_indexes[i] = indexes[i];
      }
      own_fields = false;
  }
 
- MultiIndexQuantizer2::MultiIndexQuantizer2 (
-     int d, size_t nbits,
-     Index *assign_index_0,
-     Index *assign_index_1):
-     MultiIndexQuantizer (d, 2, nbits)
- {
+ MultiIndexQuantizer2::MultiIndexQuantizer2(
+         int d,
+         size_t nbits,
+         Index* assign_index_0,
+         Index* assign_index_1)
+         : MultiIndexQuantizer(d, 2, nbits) {
      FAISS_THROW_IF_NOT_MSG(
-         assign_index_0->d == pq.dsub &&
-         assign_index_1->d == pq.dsub,
+             assign_index_0->d == pq.dsub && assign_index_1->d == pq.dsub,
          "Provided sub-index has incorrect size");
-     assign_indexes.resize (2);
-     assign_indexes [0] = assign_index_0;
-     assign_indexes [1] = assign_index_1;
+     assign_indexes.resize(2);
+     assign_indexes[0] = assign_index_0;
+     assign_indexes[1] = assign_index_1;
      own_fields = false;
  }
 
- void MultiIndexQuantizer2::train(idx_t n, const float* x)
- {
+ void MultiIndexQuantizer2::train(idx_t n, const float* x) {
      MultiIndexQuantizer::train(n, x);
      // add centroids to sub-indexes
      for (int i = 0; i < pq.M; i++) {
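MultiIndexQuantizer::reconstruct above treats a label as M sub-quantizer indices packed into nbits-wide fields, least-significant sub-quantizer first. A small illustrative helper (not part of the faiss API) that unpacks such a label the same way:

#include <cstdint>
#include <vector>

// key: packed MultiIndexQuantizer label; M: number of sub-quantizers;
// nbits: bits per sub-quantizer index
std::vector<int> unpack_miq_label(int64_t key, int M, int nbits) {
    std::vector<int> sub_indices(M);
    int64_t jj = key;
    for (int m = 0; m < M; m++) {
        // extract the low nbits, then shift them out for the next sub-quantizer
        sub_indices[m] = int(jj & ((int64_t(1) << nbits) - 1));
        jj >>= nbits;
    }
    return sub_indices;
}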
@@ -1112,15 +1056,17 @@ void MultiIndexQuantizer2::train(idx_t n, const float* x)
      }
  }
 
- 
  void MultiIndexQuantizer2::search(
-     idx_t n, const float* x, idx_t K,
-     float* distances, idx_t* labels) const
- {
- 
-     if (n == 0) return;
+         idx_t n,
+         const float* x,
+         idx_t K,
+         float* distances,
+         idx_t* labels) const {
+     if (n == 0)
+         return;
 
      int k2 = std::min(K, int64_t(pq.ksub));
+     FAISS_THROW_IF_NOT(k2);
 
      int64_t M = pq.M;
      int64_t dsub = pq.dsub, ksub = pq.ksub;
 
@@ -1131,8 +1077,8 @@ void MultiIndexQuantizer2::search(
      std::vector<float> xsub(n * dsub);
 
      for (int m = 0; m < M; m++) {
-         float *xdest = xsub.data();
-         const float *xsrc = x + m * dsub;
+         float* xdest = xsub.data();
+         const float* xsrc = x + m * dsub;
          for (int j = 0; j < n; j++) {
              memcpy(xdest, xsrc, dsub * sizeof(xdest[0]));
              xsrc += d;
@@ -1140,14 +1086,12 @@ void MultiIndexQuantizer2::search(
          }
 
          assign_indexes[m]->search(
-             n, xsub.data(), k2,
-             &sub_dis[k2 * n * m],
-             &sub_ids[k2 * n * m]);
+                 n, xsub.data(), k2, &sub_dis[k2 * n * m], &sub_ids[k2 * n * m]);
      }
 
      if (K == 1) {
          // simple version that just finds the min in each table
-         assert (k2 == 1);
+         assert(k2 == 1);
 
          for (int i = 0; i < n; i++) {
              float dis = 0;
@@ -1159,30 +1103,28 @@ void MultiIndexQuantizer2::search(
              dis += vmin;
              label |= lmin << (m * pq.nbits);
          }
-             distances [i] = dis;
-             labels [i] = label;
+             distances[i] = dis;
+             labels[i] = label;
          }
 
      } else {
- 
- #pragma omp parallel if(n > 1)
+ #pragma omp parallel if (n > 1)
          {
-             MinSumK <float, PreSortedArray<float>, false>
-                 msk(K, pq.M, pq.nbits, k2);
+             MinSumK<float, PreSortedArray<float>, false> msk(
+                     K, pq.M, pq.nbits, k2);
  #pragma omp for
              for (int i = 0; i < n; i++) {
-                 idx_t *li = labels + i * K;
-                 msk.run (&sub_dis[i * k2], k2 * n,
-                          distances + i * K, li);
+                 idx_t* li = labels + i * K;
+                 msk.run(&sub_dis[i * k2], k2 * n, distances + i * K, li);
 
                  // remap ids
 
-                 const idx_t *idmap0 = sub_ids.data() + i * k2;
+                 const idx_t* idmap0 = sub_ids.data() + i * k2;
                  int64_t ld_idmap = k2 * n;
                  int64_t mask1 = ksub - 1L;
 
                  for (int k = 0; k < K; k++) {
-                     const idx_t *idmap = idmap0;
+                     const idx_t* idmap = idmap0;
                      int64_t vin = li[k];
                      int64_t vout = 0;
                      int bs = 0;
@@ -1200,5 +1142,4 @@ void MultiIndexQuantizer2::search(
          }
      }
 
- 
  } // namespace faiss