faiss 0.1.0 → 0.1.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +103 -3
  4. data/ext/faiss/ext.cpp +99 -32
  5. data/ext/faiss/extconf.rb +12 -2
  6. data/lib/faiss/ext.bundle +0 -0
  7. data/lib/faiss/index.rb +3 -3
  8. data/lib/faiss/index_binary.rb +3 -3
  9. data/lib/faiss/kmeans.rb +1 -1
  10. data/lib/faiss/pca_matrix.rb +2 -2
  11. data/lib/faiss/product_quantizer.rb +3 -3
  12. data/lib/faiss/version.rb +1 -1
  13. data/vendor/faiss/AutoTune.cpp +719 -0
  14. data/vendor/faiss/AutoTune.h +212 -0
  15. data/vendor/faiss/Clustering.cpp +261 -0
  16. data/vendor/faiss/Clustering.h +101 -0
  17. data/vendor/faiss/IVFlib.cpp +339 -0
  18. data/vendor/faiss/IVFlib.h +132 -0
  19. data/vendor/faiss/Index.cpp +171 -0
  20. data/vendor/faiss/Index.h +261 -0
  21. data/vendor/faiss/Index2Layer.cpp +437 -0
  22. data/vendor/faiss/Index2Layer.h +85 -0
  23. data/vendor/faiss/IndexBinary.cpp +77 -0
  24. data/vendor/faiss/IndexBinary.h +163 -0
  25. data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
  26. data/vendor/faiss/IndexBinaryFlat.h +54 -0
  27. data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
  28. data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
  29. data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
  30. data/vendor/faiss/IndexBinaryHNSW.h +56 -0
  31. data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
  32. data/vendor/faiss/IndexBinaryIVF.h +211 -0
  33. data/vendor/faiss/IndexFlat.cpp +508 -0
  34. data/vendor/faiss/IndexFlat.h +175 -0
  35. data/vendor/faiss/IndexHNSW.cpp +1090 -0
  36. data/vendor/faiss/IndexHNSW.h +170 -0
  37. data/vendor/faiss/IndexIVF.cpp +909 -0
  38. data/vendor/faiss/IndexIVF.h +353 -0
  39. data/vendor/faiss/IndexIVFFlat.cpp +502 -0
  40. data/vendor/faiss/IndexIVFFlat.h +118 -0
  41. data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
  42. data/vendor/faiss/IndexIVFPQ.h +161 -0
  43. data/vendor/faiss/IndexIVFPQR.cpp +219 -0
  44. data/vendor/faiss/IndexIVFPQR.h +65 -0
  45. data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
  46. data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
  47. data/vendor/faiss/IndexLSH.cpp +225 -0
  48. data/vendor/faiss/IndexLSH.h +87 -0
  49. data/vendor/faiss/IndexLattice.cpp +143 -0
  50. data/vendor/faiss/IndexLattice.h +68 -0
  51. data/vendor/faiss/IndexPQ.cpp +1188 -0
  52. data/vendor/faiss/IndexPQ.h +199 -0
  53. data/vendor/faiss/IndexPreTransform.cpp +288 -0
  54. data/vendor/faiss/IndexPreTransform.h +91 -0
  55. data/vendor/faiss/IndexReplicas.cpp +123 -0
  56. data/vendor/faiss/IndexReplicas.h +76 -0
  57. data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
  58. data/vendor/faiss/IndexScalarQuantizer.h +127 -0
  59. data/vendor/faiss/IndexShards.cpp +317 -0
  60. data/vendor/faiss/IndexShards.h +100 -0
  61. data/vendor/faiss/InvertedLists.cpp +623 -0
  62. data/vendor/faiss/InvertedLists.h +334 -0
  63. data/vendor/faiss/LICENSE +21 -0
  64. data/vendor/faiss/MatrixStats.cpp +252 -0
  65. data/vendor/faiss/MatrixStats.h +62 -0
  66. data/vendor/faiss/MetaIndexes.cpp +351 -0
  67. data/vendor/faiss/MetaIndexes.h +126 -0
  68. data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
  69. data/vendor/faiss/OnDiskInvertedLists.h +127 -0
  70. data/vendor/faiss/VectorTransform.cpp +1157 -0
  71. data/vendor/faiss/VectorTransform.h +322 -0
  72. data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
  73. data/vendor/faiss/c_api/AutoTune_c.h +64 -0
  74. data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
  75. data/vendor/faiss/c_api/Clustering_c.h +117 -0
  76. data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
  77. data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
  78. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
  79. data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
  80. data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
  81. data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
  82. data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
  83. data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
  84. data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
  85. data/vendor/faiss/c_api/IndexShards_c.h +42 -0
  86. data/vendor/faiss/c_api/Index_c.cpp +105 -0
  87. data/vendor/faiss/c_api/Index_c.h +183 -0
  88. data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
  89. data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
  90. data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
  91. data/vendor/faiss/c_api/clone_index_c.h +32 -0
  92. data/vendor/faiss/c_api/error_c.h +42 -0
  93. data/vendor/faiss/c_api/error_impl.cpp +27 -0
  94. data/vendor/faiss/c_api/error_impl.h +16 -0
  95. data/vendor/faiss/c_api/faiss_c.h +58 -0
  96. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
  97. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
  98. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
  99. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
  100. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
  101. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
  102. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
  103. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
  104. data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
  105. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
  106. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
  107. data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
  108. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
  109. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
  110. data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
  111. data/vendor/faiss/c_api/index_factory_c.h +30 -0
  112. data/vendor/faiss/c_api/index_io_c.cpp +42 -0
  113. data/vendor/faiss/c_api/index_io_c.h +50 -0
  114. data/vendor/faiss/c_api/macros_impl.h +110 -0
  115. data/vendor/faiss/clone_index.cpp +147 -0
  116. data/vendor/faiss/clone_index.h +38 -0
  117. data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
  118. data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
  119. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
  120. data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
  121. data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
  122. data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
  123. data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
  124. data/vendor/faiss/gpu/GpuCloner.h +82 -0
  125. data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
  126. data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
  127. data/vendor/faiss/gpu/GpuDistance.h +52 -0
  128. data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
  129. data/vendor/faiss/gpu/GpuIndex.h +148 -0
  130. data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
  131. data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
  132. data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
  133. data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
  134. data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
  135. data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
  136. data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
  137. data/vendor/faiss/gpu/GpuResources.cpp +52 -0
  138. data/vendor/faiss/gpu/GpuResources.h +73 -0
  139. data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
  140. data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
  141. data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
  142. data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
  143. data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
  144. data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
  145. data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
  146. data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
  147. data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
  148. data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
  149. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
  150. data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
  151. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
  152. data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
  153. data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
  154. data/vendor/faiss/gpu/test/TestUtils.h +93 -0
  155. data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
  156. data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
  157. data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
  158. data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
  159. data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
  160. data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
  161. data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
  162. data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
  163. data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
  164. data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
  165. data/vendor/faiss/gpu/utils/Timer.h +52 -0
  166. data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
  167. data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
  168. data/vendor/faiss/impl/FaissAssert.h +95 -0
  169. data/vendor/faiss/impl/FaissException.cpp +66 -0
  170. data/vendor/faiss/impl/FaissException.h +71 -0
  171. data/vendor/faiss/impl/HNSW.cpp +818 -0
  172. data/vendor/faiss/impl/HNSW.h +275 -0
  173. data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
  174. data/vendor/faiss/impl/PolysemousTraining.h +158 -0
  175. data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
  176. data/vendor/faiss/impl/ProductQuantizer.h +242 -0
  177. data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
  178. data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
  179. data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
  180. data/vendor/faiss/impl/ThreadedIndex.h +80 -0
  181. data/vendor/faiss/impl/index_read.cpp +793 -0
  182. data/vendor/faiss/impl/index_write.cpp +558 -0
  183. data/vendor/faiss/impl/io.cpp +142 -0
  184. data/vendor/faiss/impl/io.h +98 -0
  185. data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
  186. data/vendor/faiss/impl/lattice_Zn.h +199 -0
  187. data/vendor/faiss/index_factory.cpp +392 -0
  188. data/vendor/faiss/index_factory.h +25 -0
  189. data/vendor/faiss/index_io.h +75 -0
  190. data/vendor/faiss/misc/test_blas.cpp +84 -0
  191. data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
  192. data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
  193. data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
  194. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
  195. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
  196. data/vendor/faiss/tests/test_merge.cpp +258 -0
  197. data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
  198. data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
  199. data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
  200. data/vendor/faiss/tests/test_params_override.cpp +231 -0
  201. data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
  202. data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
  203. data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
  204. data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
  205. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
  206. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
  207. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
  208. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
  209. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
  210. data/vendor/faiss/utils/Heap.cpp +122 -0
  211. data/vendor/faiss/utils/Heap.h +495 -0
  212. data/vendor/faiss/utils/WorkerThread.cpp +126 -0
  213. data/vendor/faiss/utils/WorkerThread.h +61 -0
  214. data/vendor/faiss/utils/distances.cpp +765 -0
  215. data/vendor/faiss/utils/distances.h +243 -0
  216. data/vendor/faiss/utils/distances_simd.cpp +809 -0
  217. data/vendor/faiss/utils/extra_distances.cpp +336 -0
  218. data/vendor/faiss/utils/extra_distances.h +54 -0
  219. data/vendor/faiss/utils/hamming-inl.h +472 -0
  220. data/vendor/faiss/utils/hamming.cpp +792 -0
  221. data/vendor/faiss/utils/hamming.h +220 -0
  222. data/vendor/faiss/utils/random.cpp +192 -0
  223. data/vendor/faiss/utils/random.h +60 -0
  224. data/vendor/faiss/utils/utils.cpp +783 -0
  225. data/vendor/faiss/utils/utils.h +181 -0
  226. metadata +216 -2
@@ -0,0 +1,68 @@
1
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

// -*- c++ -*-

#ifndef FAISS_INDEX_LATTICE_H
#define FAISS_INDEX_LATTICE_H


#include <vector>

#include <faiss/IndexIVF.h>
#include <faiss/impl/lattice_Zn.h>

namespace faiss {


/** Index that encodes a vector with a series of Zn lattice quantizers
 *
 * The vector is split into nsq sub-vectors of dimension dsq; each
 * sub-vector is coded as a scale (scale_nbit bits) plus a point of
 * the Zn sphere codec (lattice_nbit bits).
 */
struct IndexLattice: Index {

    /// number of sub-vectors
    int nsq;
    /// dimension of sub-vectors
    size_t dsq;

    /// the lattice quantizer
    ZnSphereCodecAlt zn_sphere_codec;

    /// nb bits used to encode the scale, per subvector
    int scale_nbit, lattice_nbit;
    /// total, in bytes
    size_t code_size;

    /// mins and maxes of the vector norms, per subquantizer
    std::vector<float> trained;

    IndexLattice (idx_t d, int nsq, int scale_nbit, int r2);

    /// presumably fills `trained` with per-subquantizer norm bounds —
    /// confirm in IndexLattice.cpp
    void train(idx_t n, const float* x) override;

    /* The standalone codec interface */

    /// size of one code, in bytes (== code_size)
    size_t sa_code_size () const override;

    void sa_encode (idx_t n, const float *x,
                    uint8_t *bytes) const override;

    void sa_decode (idx_t n, const uint8_t *bytes,
                    float *x) const override;

    /// not implemented
    void add(idx_t n, const float* x) override;
    void search(idx_t n, const float* x, idx_t k,
                float* distances, idx_t* labels) const override;
    void reset() override;

};

} // namespace faiss

#endif
@@ -0,0 +1,1188 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
#include <faiss/IndexPQ.h>


#include <cstddef>
#include <cstring>
#include <cstdio>
#include <cmath>

#include <algorithm>
#include <vector>

#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/utils/hamming.h>
23
+
24
+ namespace faiss {
25
+
26
+ /*********************************************************
27
+ * IndexPQ implementation
28
+ ********************************************************/
29
+
30
+
31
+ IndexPQ::IndexPQ (int d, size_t M, size_t nbits, MetricType metric):
32
+ Index(d, metric), pq(d, M, nbits)
33
+ {
34
+ is_trained = false;
35
+ do_polysemous_training = false;
36
+ polysemous_ht = nbits * M + 1;
37
+ search_type = ST_PQ;
38
+ encode_signs = false;
39
+ }
40
+
41
+ IndexPQ::IndexPQ ()
42
+ {
43
+ metric_type = METRIC_L2;
44
+ is_trained = false;
45
+ do_polysemous_training = false;
46
+ polysemous_ht = pq.nbits * pq.M + 1;
47
+ search_type = ST_PQ;
48
+ encode_signs = false;
49
+ }
50
+
51
+
52
+ void IndexPQ::train (idx_t n, const float *x)
53
+ {
54
+ if (!do_polysemous_training) { // standard training
55
+ pq.train(n, x);
56
+ } else {
57
+ idx_t ntrain_perm = polysemous_training.ntrain_permutation;
58
+
59
+ if (ntrain_perm > n / 4)
60
+ ntrain_perm = n / 4;
61
+ if (verbose) {
62
+ printf ("PQ training on %ld points, remains %ld points: "
63
+ "training polysemous on %s\n",
64
+ n - ntrain_perm, ntrain_perm,
65
+ ntrain_perm == 0 ? "centroids" : "these");
66
+ }
67
+ pq.train(n - ntrain_perm, x);
68
+
69
+ polysemous_training.optimize_pq_for_hamming (
70
+ pq, ntrain_perm, x + (n - ntrain_perm) * d);
71
+ }
72
+ is_trained = true;
73
+ }
74
+
75
+
76
+ void IndexPQ::add (idx_t n, const float *x)
77
+ {
78
+ FAISS_THROW_IF_NOT (is_trained);
79
+ codes.resize ((n + ntotal) * pq.code_size);
80
+ pq.compute_codes (x, &codes[ntotal * pq.code_size], n);
81
+ ntotal += n;
82
+ }
83
+
84
+
85
+ size_t IndexPQ::remove_ids (const IDSelector & sel)
86
+ {
87
+ idx_t j = 0;
88
+ for (idx_t i = 0; i < ntotal; i++) {
89
+ if (sel.is_member (i)) {
90
+ // should be removed
91
+ } else {
92
+ if (i > j) {
93
+ memmove (&codes[pq.code_size * j], &codes[pq.code_size * i], pq.code_size);
94
+ }
95
+ j++;
96
+ }
97
+ }
98
+ size_t nremove = ntotal - j;
99
+ if (nremove > 0) {
100
+ ntotal = j;
101
+ codes.resize (ntotal * pq.code_size);
102
+ }
103
+ return nremove;
104
+ }
105
+
106
+
107
+ void IndexPQ::reset()
108
+ {
109
+ codes.clear();
110
+ ntotal = 0;
111
+ }
112
+
113
+ void IndexPQ::reconstruct_n (idx_t i0, idx_t ni, float *recons) const
114
+ {
115
+ FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
116
+ for (idx_t i = 0; i < ni; i++) {
117
+ const uint8_t * code = &codes[(i0 + i) * pq.code_size];
118
+ pq.decode (code, recons + i * d);
119
+ }
120
+ }
121
+
122
+
123
+ void IndexPQ::reconstruct (idx_t key, float * recons) const
124
+ {
125
+ FAISS_THROW_IF_NOT (key >= 0 && key < ntotal);
126
+ pq.decode (&codes[key * pq.code_size], recons);
127
+ }
128
+
129
+
130
namespace {


/** DistanceComputer over 8-bit PQ codes.
 *
 * Query-to-code distances are looked up in a per-query table of
 * M * ksub partial distances (one 256-entry sub-table per
 * sub-quantizer); code-to-code distances use the precomputed
 * symmetric (SDC) table of the product quantizer.
 */
struct PQDis: DistanceComputer {
    size_t d;
    Index::idx_t nb;
    const uint8_t *codes;      // not owned: points into the index's storage
    size_t code_size;
    const ProductQuantizer & pq;
    const float *sdc;          // symmetric table: ksub*ksub entries per sub-quantizer
    std::vector<float> precomputed_table;  // filled by set_query()
    size_t ndis;               // stats: nb of asymmetric distances computed

    // asymmetric distance: current query (see set_query) to code i
    float operator () (idx_t i) override
    {
        const uint8_t *code = codes + i * code_size;
        const float *dt = precomputed_table.data();
        float accu = 0;
        for (int j = 0; j < pq.M; j++) {
            // each code byte indexes one of the 256 partial distances
            accu += dt[*code++];
            dt += 256;
        }
        ndis++;
        return accu;
    }

    // symmetric distance between stored codes i and j (SDC lookup)
    float symmetric_dis(idx_t i, idx_t j) override
    {
        const float * sdci = sdc;
        float accu = 0;
        const uint8_t *codei = codes + i * code_size;
        const uint8_t *codej = codes + j * code_size;

        for (int l = 0; l < pq.M; l++) {
            // 2-D lookup in the l-th 256x256 sub-table
            accu += sdci[(*codei++) + (*codej++) * 256];
            sdci += 256 * 256;
        }
        return accu;
    }

    explicit PQDis(const IndexPQ& storage, const float* /*q*/ = nullptr)
        : pq(storage.pq) {
        precomputed_table.resize(pq.M * pq.ksub);
        nb = storage.ntotal;
        d = storage.d;
        codes = storage.codes.data();
        code_size = pq.code_size;
        // the lookup loops above hard-code 256 entries per sub-quantizer
        FAISS_ASSERT(pq.ksub == 256);
        // the SDC table must already be populated (size is checked here)
        FAISS_ASSERT(pq.sdc_table.size() == pq.ksub * pq.ksub * pq.M);
        sdc = pq.sdc_table.data();
        ndis = 0;
    }

    // compute the per-sub-quantizer distance table for query x
    void set_query(const float *x) override {
        pq.compute_distance_table(x, precomputed_table.data());
    }
};


}  // namespace
190
+
191
+
192
/// Caller owns the returned object. PQDis hard-codes 256 centroids
/// per sub-quantizer, hence the 8-bit requirement.
DistanceComputer * IndexPQ::get_distance_computer() const {
    FAISS_THROW_IF_NOT(pq.nbits == 8);
    return new PQDis(*this);
}
196
+
197
+
198
+ /*****************************************
199
+ * IndexPQ polysemous search routines
200
+ ******************************************/
201
+
202
+
203
+
204
+
205
+
206
+ void IndexPQ::search (idx_t n, const float *x, idx_t k,
207
+ float *distances, idx_t *labels) const
208
+ {
209
+ FAISS_THROW_IF_NOT (is_trained);
210
+ if (search_type == ST_PQ) { // Simple PQ search
211
+
212
+ if (metric_type == METRIC_L2) {
213
+ float_maxheap_array_t res = {
214
+ size_t(n), size_t(k), labels, distances };
215
+ pq.search (x, n, codes.data(), ntotal, &res, true);
216
+ } else {
217
+ float_minheap_array_t res = {
218
+ size_t(n), size_t(k), labels, distances };
219
+ pq.search_ip (x, n, codes.data(), ntotal, &res, true);
220
+ }
221
+ indexPQ_stats.nq += n;
222
+ indexPQ_stats.ncode += n * ntotal;
223
+
224
+ } else if (search_type == ST_polysemous ||
225
+ search_type == ST_polysemous_generalize) {
226
+
227
+ FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
228
+
229
+ search_core_polysemous (n, x, k, distances, labels);
230
+
231
+ } else { // code-to-code distances
232
+
233
+ uint8_t * q_codes = new uint8_t [n * pq.code_size];
234
+ ScopeDeleter<uint8_t> del (q_codes);
235
+
236
+
237
+ if (!encode_signs) {
238
+ pq.compute_codes (x, q_codes, n);
239
+ } else {
240
+ FAISS_THROW_IF_NOT (d == pq.nbits * pq.M);
241
+ memset (q_codes, 0, n * pq.code_size);
242
+ for (size_t i = 0; i < n; i++) {
243
+ const float *xi = x + i * d;
244
+ uint8_t *code = q_codes + i * pq.code_size;
245
+ for (int j = 0; j < d; j++)
246
+ if (xi[j] > 0) code [j>>3] |= 1 << (j & 7);
247
+ }
248
+ }
249
+
250
+ if (search_type == ST_SDC) {
251
+
252
+ float_maxheap_array_t res = {
253
+ size_t(n), size_t(k), labels, distances};
254
+
255
+ pq.search_sdc (q_codes, n, codes.data(), ntotal, &res, true);
256
+
257
+ } else {
258
+ int * idistances = new int [n * k];
259
+ ScopeDeleter<int> del (idistances);
260
+
261
+ int_maxheap_array_t res = {
262
+ size_t (n), size_t (k), labels, idistances};
263
+
264
+ if (search_type == ST_HE) {
265
+
266
+ hammings_knn_hc (&res, q_codes, codes.data(),
267
+ ntotal, pq.code_size, true);
268
+
269
+ } else if (search_type == ST_generalized_HE) {
270
+
271
+ generalized_hammings_knn_hc (&res, q_codes, codes.data(),
272
+ ntotal, pq.code_size, true);
273
+ }
274
+
275
+ // convert distances to floats
276
+ for (int i = 0; i < k * n; i++)
277
+ distances[i] = idistances[i];
278
+
279
+ }
280
+
281
+
282
+ indexPQ_stats.nq += n;
283
+ indexPQ_stats.ncode += n * ntotal;
284
+ }
285
+ }
286
+
287
+
288
+
289
+
290
+
291
+ void IndexPQStats::reset()
292
+ {
293
+ nq = ncode = n_hamming_pass = 0;
294
+ }
295
+
296
// single global stats instance, updated by the search routines above
IndexPQStats indexPQ_stats;
297
+
298
+
299
/** Scan all database codes for one query.
 *
 * Codes whose Hamming distance to the query code is strictly below
 * the polysemous threshold are fully evaluated with the PQ distance
 * table and pushed onto the per-query result max-heap.
 *
 * @param dis_table_qi  this query's PQ distance table (M * ksub floats)
 * @param q_code        the query encoded as a PQ code
 * @param heap_dis/heap_ids  max-heap of size k, already heapified
 * @return number of codes that passed the Hamming filter
 */
template <class HammingComputer>
static size_t polysemous_inner_loop (
        const IndexPQ & index,
        const float *dis_table_qi, const uint8_t *q_code,
        size_t k, float *heap_dis, int64_t *heap_ids)
{

    int M = index.pq.M;
    int code_size = index.pq.code_size;
    int ksub = index.pq.ksub;
    size_t ntotal = index.ntotal;
    int ht = index.polysemous_ht;

    const uint8_t *b_code = index.codes.data();

    size_t n_pass_i = 0;

    HammingComputer hc (q_code, code_size);

    for (int64_t bi = 0; bi < ntotal; bi++) {
        int hd = hc.hamming (b_code);

        if (hd < ht) {
            // passed the cheap Hamming filter: compute the real PQ
            // distance by table lookup
            n_pass_i ++;

            float dis = 0;
            const float * dis_table = dis_table_qi;
            for (int m = 0; m < M; m++) {
                dis += dis_table [b_code[m]];
                dis_table += ksub;
            }

            if (dis < heap_dis[0]) {
                // better than the current k-th best: replace heap top
                maxheap_pop (k, heap_dis, heap_ids);
                maxheap_push (k, heap_dis, heap_ids, dis, bi);
            }
        }
        b_code += code_size;
    }
    return n_pass_i;
}
340
+
341
+
342
+ void IndexPQ::search_core_polysemous (idx_t n, const float *x, idx_t k,
343
+ float *distances, idx_t *labels) const
344
+ {
345
+ FAISS_THROW_IF_NOT (pq.nbits == 8);
346
+
347
+ // PQ distance tables
348
+ float * dis_tables = new float [n * pq.ksub * pq.M];
349
+ ScopeDeleter<float> del (dis_tables);
350
+ pq.compute_distance_tables (n, x, dis_tables);
351
+
352
+ // Hamming embedding queries
353
+ uint8_t * q_codes = new uint8_t [n * pq.code_size];
354
+ ScopeDeleter<uint8_t> del2 (q_codes);
355
+
356
+ if (false) {
357
+ pq.compute_codes (x, q_codes, n);
358
+ } else {
359
+ #pragma omp parallel for
360
+ for (idx_t qi = 0; qi < n; qi++) {
361
+ pq.compute_code_from_distance_table
362
+ (dis_tables + qi * pq.M * pq.ksub,
363
+ q_codes + qi * pq.code_size);
364
+ }
365
+ }
366
+
367
+ size_t n_pass = 0;
368
+
369
+ #pragma omp parallel for reduction (+: n_pass)
370
+ for (idx_t qi = 0; qi < n; qi++) {
371
+ const uint8_t * q_code = q_codes + qi * pq.code_size;
372
+
373
+ const float * dis_table_qi = dis_tables + qi * pq.M * pq.ksub;
374
+
375
+ int64_t * heap_ids = labels + qi * k;
376
+ float *heap_dis = distances + qi * k;
377
+ maxheap_heapify (k, heap_dis, heap_ids);
378
+
379
+ if (search_type == ST_polysemous) {
380
+
381
+ switch (pq.code_size) {
382
+ case 4:
383
+ n_pass += polysemous_inner_loop<HammingComputer4>
384
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
385
+ break;
386
+ case 8:
387
+ n_pass += polysemous_inner_loop<HammingComputer8>
388
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
389
+ break;
390
+ case 16:
391
+ n_pass += polysemous_inner_loop<HammingComputer16>
392
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
393
+ break;
394
+ case 32:
395
+ n_pass += polysemous_inner_loop<HammingComputer32>
396
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
397
+ break;
398
+ case 20:
399
+ n_pass += polysemous_inner_loop<HammingComputer20>
400
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
401
+ break;
402
+ default:
403
+ if (pq.code_size % 8 == 0) {
404
+ n_pass += polysemous_inner_loop<HammingComputerM8>
405
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
406
+ } else if (pq.code_size % 4 == 0) {
407
+ n_pass += polysemous_inner_loop<HammingComputerM4>
408
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
409
+ } else {
410
+ FAISS_THROW_FMT(
411
+ "code size %zd not supported for polysemous",
412
+ pq.code_size);
413
+ }
414
+ break;
415
+ }
416
+ } else {
417
+ switch (pq.code_size) {
418
+ case 8:
419
+ n_pass += polysemous_inner_loop<GenHammingComputer8>
420
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
421
+ break;
422
+ case 16:
423
+ n_pass += polysemous_inner_loop<GenHammingComputer16>
424
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
425
+ break;
426
+ case 32:
427
+ n_pass += polysemous_inner_loop<GenHammingComputer32>
428
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
429
+ break;
430
+ default:
431
+ if (pq.code_size % 8 == 0) {
432
+ n_pass += polysemous_inner_loop<GenHammingComputerM8>
433
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
434
+ } else {
435
+ FAISS_THROW_FMT(
436
+ "code size %zd not supported for polysemous",
437
+ pq.code_size);
438
+ }
439
+ break;
440
+ }
441
+ }
442
+ maxheap_reorder (k, heap_dis, heap_ids);
443
+ }
444
+
445
+ indexPQ_stats.nq += n;
446
+ indexPQ_stats.ncode += n * ntotal;
447
+ indexPQ_stats.n_hamming_pass += n_pass;
448
+
449
+
450
+ }
451
+
452
+
453
+ /* The standalone codec interface (just remaps to the PQ functions) */
454
+ size_t IndexPQ::sa_code_size () const
455
+ {
456
+ return pq.code_size;
457
+ }
458
+
459
+ void IndexPQ::sa_encode (idx_t n, const float *x, uint8_t *bytes) const
460
+ {
461
+ pq.compute_codes (x, bytes, n);
462
+ }
463
+
464
+ void IndexPQ::sa_decode (idx_t n, const uint8_t *bytes, float *x) const
465
+ {
466
+ pq.decode (bytes, x, n);
467
+ }
468
+
469
+
470
+
471
+
472
+ /*****************************************
473
+ * Stats of IndexPQ codes
474
+ ******************************************/
475
+
476
+
477
+
478
+
479
+ void IndexPQ::hamming_distance_table (idx_t n, const float *x,
480
+ int32_t *dis) const
481
+ {
482
+ uint8_t * q_codes = new uint8_t [n * pq.code_size];
483
+ ScopeDeleter<uint8_t> del (q_codes);
484
+
485
+ pq.compute_codes (x, q_codes, n);
486
+
487
+ hammings (q_codes, codes.data(), n, ntotal, pq.code_size, dis);
488
+ }
489
+
490
+
491
/** Histogram of Hamming distances between n encoded query vectors
 * and nb database vectors (xb, or the stored codes when xb == NULL).
 *
 * @param hist  output, size pq.M * pq.nbits + 1; hist[h] counts the
 *              (query, database) pairs at Hamming distance exactly h
 */
void IndexPQ::hamming_distance_histogram (idx_t n, const float *x,
                                          idx_t nb, const float *xb,
                                          int64_t *hist)
{
    FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
    FAISS_THROW_IF_NOT (pq.code_size % 8 == 0);
    FAISS_THROW_IF_NOT (pq.nbits == 8);

    // Hamming embedding queries
    uint8_t * q_codes = new uint8_t [n * pq.code_size];
    ScopeDeleter <uint8_t> del (q_codes);
    pq.compute_codes (x, q_codes, n);

    uint8_t * b_codes ;
    ScopeDeleter <uint8_t> del_b_codes;  // owns b_codes only when xb != NULL

    if (xb) {
        // encode the provided database vectors
        b_codes = new uint8_t [nb * pq.code_size];
        del_b_codes.set (b_codes);
        pq.compute_codes (xb, b_codes, nb);
    } else {
        // default to the codes already stored in the index (not owned)
        nb = ntotal;
        b_codes = codes.data();
    }
    int nbits = pq.M * pq.nbits;   // max possible Hamming distance
    memset (hist, 0, sizeof(*hist) * (nbits + 1));
    size_t bs = 256;   // query block size, bounds the distance buffer

#pragma omp parallel
    {
        // per-thread histogram, merged in the critical section below
        std::vector<int64_t> histi (nbits + 1);
        hamdis_t *distances = new hamdis_t [nb * bs];
        ScopeDeleter<hamdis_t> del (distances);
#pragma omp for
        for (size_t q0 = 0; q0 < n; q0 += bs) {
            // printf ("dis stats: %ld/%ld\n", q0, n);
            size_t q1 = q0 + bs;
            if (q1 > n) q1 = n;

            // all-pairs Hamming distances for this block of queries
            hammings (q_codes + q0 * pq.code_size, b_codes,
                      q1 - q0, nb,
                      pq.code_size, distances);

            for (size_t i = 0; i < nb * (q1 - q0); i++)
                histi [distances [i]]++;
        }
#pragma omp critical
        {
            for (int i = 0; i <= nbits; i++)
                hist[i] += histi[i];
        }
    }

}
545
+
546
+
547
+
548
+
549
+
550
+
551
+
552
+
553
+
554
+
555
+
556
+
557
+
558
+
559
+
560
+
561
+
562
+
563
+
564
+
565
+ /*****************************************
566
+ * MultiIndexQuantizer
567
+ ******************************************/
568
+
569
+ namespace {
570
+
571
/** Adapter over an array that is already sorted ascending: the
 * identity permutation answers every rank query directly.
 */
template <typename T>
struct PreSortedArray {

    const T * x;   // borrowed, not owned
    int N;

    explicit PreSortedArray (int N): N(N) {
    }

    void init (const T *data) {
        x = data;
    }

    /// smallest value of the array
    T get_0 () {
        return *x;
    }

    /// difference between the n-th and (n-1)-th smallest values
    T get_diff (int n) {
        return x[n] - x[n - 1];
    }

    /// rank-to-index mapping (identity: the array is already sorted)
    int get_ord (int n) {
        return n;
    }

};
598
+
599
/** Comparator for std::sort that orders indices i, j by the values
 * x[i], x[j] they refer to (indirect / argsort ordering). */
template <typename T>
struct ArgSort {
    const T * x;
    bool operator() (size_t i, size_t j) {
        return x[i] < x[j];
    }
};
606
+
607
+
608
/** Array that maintains a permutation of its elements so that the
 * array's elements are sorted
 */
template <typename T>
struct SortedArray {
    const T * x;
    int N;
    std::vector<int> perm;   // perm[r] = index of the r-th smallest value

    explicit SortedArray (int N) {
        this->N = N;
        perm.resize (N);
    }

    /// attach the data and fully sort the permutation (O(N log N))
    void init (const T *data) {
        x = data;
        for (int i = 0; i < N; i++) {
            perm[i] = i;
        }
        std::sort (perm.begin(), perm.end(),
                   [data] (int i, int j) { return data[i] < data[j]; });
    }

    /// smallest value
    T get_0 () {
        return x[perm[0]];
    }

    /// difference between the n-th and (n-1)-th smallest values
    T get_diff (int n) {
        return x[perm[n]] - x[perm[n - 1]];
    }

    /// index (in the original array) of the n-th smallest element
    int get_ord (int n) {
        return perm[n];
    }
};
645
+
646
+
647
+
648
/** Array has n values. Sort the k first ones and copy the other ones
 * into elements k..n-1
 *
 * Works on the permutation `perm` indirectly (vals is never moved):
 * a k-element indirect heap selects the k smallest values, then the
 * heap is unwound to leave perm[0..k) sorted.
 */
template <class C>
void partial_sort (int k, int n,
                   const typename C::T * vals, typename C::TI * perm) {
    // insert first k elts in heap
    for (int i = 1; i < k; i++) {
        indirect_heap_push<C> (i + 1, vals, perm, perm[i]);
    }

    // insert next n - k elts in heap
    for (int i = k; i < n; i++) {
        typename C::TI id = perm[i];
        typename C::TI top = perm[0];

        if (C::cmp(vals[top], vals[id])) {
            // elt i beats the current heap top: swap it into the heap
            // and park the evicted top at position i
            indirect_heap_pop<C> (k, vals, perm);
            indirect_heap_push<C> (k, vals, perm, id);
            perm[i] = top;
        } else {
            // nothing, elt at i is good where it is.
        }
    }

    // order the k first elements in heap
    // (repeatedly pop the heap top into the last free slot)
    for (int i = k - 1; i > 0; i--) {
        typename C::TI top = perm[0];
        indirect_heap_pop<C> (i + 1, vals, perm);
        perm[i] = top;
    }
}
680
+
681
+ /** same as SortedArray, but only the k first elements are sorted */
682
+ template <typename T>
683
+ struct SemiSortedArray {
684
+ const T * x;
685
+ int N;
686
+
687
+ // type of the heap: CMax = sort ascending
688
+ typedef CMax<T, int> HC;
689
+ std::vector<int> perm;
690
+
691
+ int k; // k elements are sorted
692
+
693
+ int initial_k, k_factor;
694
+
695
+ explicit SemiSortedArray (int N) {
696
+ this->N = N;
697
+ perm.resize (N);
698
+ perm.resize (N);
699
+ initial_k = 3;
700
+ k_factor = 4;
701
+ }
702
+
703
+ void init (const T*x) {
704
+ this->x = x;
705
+ for (int n = 0; n < N; n++)
706
+ perm[n] = n;
707
+ k = 0;
708
+ grow (initial_k);
709
+ }
710
+
711
+ /// grow the sorted part of the array to size next_k
712
+ void grow (int next_k) {
713
+ if (next_k < N) {
714
+ partial_sort<HC> (next_k - k, N - k, x, &perm[k]);
715
+ k = next_k;
716
+ } else { // full sort of remainder of array
717
+ ArgSort<T> cmp = {x };
718
+ std::sort (perm.begin() + k, perm.end(), cmp);
719
+ k = N;
720
+ }
721
+ }
722
+
723
+ // get smallest value
724
+ T get_0 () {
725
+ return x[perm[0]];
726
+ }
727
+
728
+ // get delta between n-smallest and n-1 -smallest
729
+ T get_diff (int n) {
730
+ if (n >= k) {
731
+ // want to keep powers of 2 - 1
732
+ int next_k = (k + 1) * k_factor - 1;
733
+ grow (next_k);
734
+ }
735
+ return x[perm[n]] - x[perm[n - 1]];
736
+ }
737
+
738
+ // remap orders counted from smallest to indices in array
739
+ int get_ord (int n) {
740
+ assert (n < k);
741
+ return perm[n];
742
+ }
743
+ };
744
+
745
+
746
+
747
+ /*****************************************
748
+ * Find the k smallest sums of M terms, where each term is taken in a
749
+ * table x of n values.
750
+ *
751
+ * A combination of terms is encoded as a scalar 0 <= t < n^M. The
752
+ * combination t0 ... t(M-1) that correspond to the sum
753
+ *
754
+ * sum = x[0, t0] + x[1, t1] + .... + x[M-1, t(M-1)]
755
+ *
756
+ * is encoded as
757
+ *
758
+ * t = t0 + t1 * n + t2 * n^2 + ... + t(M-1) * n^(M-1)
759
+ *
760
+ * MinSumK is an object rather than a function, so that storage can be
761
+ * re-used over several computations with the same sizes. use_seen is
762
+ * good when there may be ties in the x array and it is a concern if
763
+ * occasionally several t's are returned.
764
+ *
765
+ * @param x size M * n, values to add up
766
+ * @param x size M * n, values to add up
767
+ * @param M nb of terms
768
+ * @param n nb of distinct values
769
+ * @param sums output, size k, sorted
770
+ * @prarm terms output, size k, with encoding as above
771
+ *
772
+ ******************************************/
773
+ template <typename T, class SSA, bool use_seen>
774
+ struct MinSumK {
775
+ int K; ///< nb of sums to return
776
+ int M; ///< nb of elements to sum up
777
+ int nbit; ///< nb of bits to encode one entry
778
+ int N; ///< nb of possible elements for each of the M terms
779
+
780
+ /** the heap.
781
+ * We use a heap to maintain a queue of sums, with the associated
782
+ * terms involved in the sum.
783
+ */
784
+ typedef CMin<T, int64_t> HC;
785
+ size_t heap_capacity, heap_size;
786
+ T *bh_val;
787
+ int64_t *bh_ids;
788
+
789
+ std::vector <SSA> ssx;
790
+
791
+ // all results get pushed several times. When there are ties, they
792
+ // are popped interleaved with others, so it is not easy to
793
+ // identify them. Therefore, this bit array just marks elements
794
+ // that were seen before.
795
+ std::vector <uint8_t> seen;
796
+
797
+ MinSumK (int K, int M, int nbit, int N):
798
+ K(K), M(M), nbit(nbit), N(N) {
799
+ heap_capacity = K * M;
800
+ assert (N <= (1 << nbit));
801
+
802
+ // we'll do k steps, each step pushes at most M vals
803
+ bh_val = new T[heap_capacity];
804
+ bh_ids = new int64_t[heap_capacity];
805
+
806
+ if (use_seen) {
807
+ int64_t n_ids = weight(M);
808
+ seen.resize ((n_ids + 7) / 8);
809
+ }
810
+
811
+ for (int m = 0; m < M; m++)
812
+ ssx.push_back (SSA(N));
813
+
814
+ }
815
+
816
+ int64_t weight (int i) {
817
+ return 1 << (i * nbit);
818
+ }
819
+
820
+ bool is_seen (int64_t i) {
821
+ return (seen[i >> 3] >> (i & 7)) & 1;
822
+ }
823
+
824
+ void mark_seen (int64_t i) {
825
+ if (use_seen)
826
+ seen [i >> 3] |= 1 << (i & 7);
827
+ }
828
+
829
+ void run (const T *x, int64_t ldx,
830
+ T * sums, int64_t * terms) {
831
+ heap_size = 0;
832
+
833
+ for (int m = 0; m < M; m++) {
834
+ ssx[m].init(x);
835
+ x += ldx;
836
+ }
837
+
838
+ { // intial result: take min for all elements
839
+ T sum = 0;
840
+ terms[0] = 0;
841
+ mark_seen (0);
842
+ for (int m = 0; m < M; m++) {
843
+ sum += ssx[m].get_0();
844
+ }
845
+ sums[0] = sum;
846
+ for (int m = 0; m < M; m++) {
847
+ heap_push<HC> (++heap_size, bh_val, bh_ids,
848
+ sum + ssx[m].get_diff(1),
849
+ weight(m));
850
+ }
851
+ }
852
+
853
+ for (int k = 1; k < K; k++) {
854
+ // pop smallest value from heap
855
+ if (use_seen) {// skip already seen elements
856
+ while (is_seen (bh_ids[0])) {
857
+ assert (heap_size > 0);
858
+ heap_pop<HC> (heap_size--, bh_val, bh_ids);
859
+ }
860
+ }
861
+ assert (heap_size > 0);
862
+
863
+ T sum = sums[k] = bh_val[0];
864
+ int64_t ti = terms[k] = bh_ids[0];
865
+
866
+ if (use_seen) {
867
+ mark_seen (ti);
868
+ heap_pop<HC> (heap_size--, bh_val, bh_ids);
869
+ } else {
870
+ do {
871
+ heap_pop<HC> (heap_size--, bh_val, bh_ids);
872
+ } while (heap_size > 0 && bh_ids[0] == ti);
873
+ }
874
+
875
+ // enqueue followers
876
+ int64_t ii = ti;
877
+ for (int m = 0; m < M; m++) {
878
+ int64_t n = ii & ((1L << nbit) - 1);
879
+ ii >>= nbit;
880
+ if (n + 1 >= N) continue;
881
+
882
+ enqueue_follower (ti, m, n, sum);
883
+ }
884
+ }
885
+
886
+ /*
887
+ for (int k = 0; k < K; k++)
888
+ for (int l = k + 1; l < K; l++)
889
+ assert (terms[k] != terms[l]);
890
+ */
891
+
892
+ // convert indices by applying permutation
893
+ for (int k = 0; k < K; k++) {
894
+ int64_t ii = terms[k];
895
+ if (use_seen) {
896
+ // clear seen for reuse at next loop
897
+ seen[ii >> 3] = 0;
898
+ }
899
+ int64_t ti = 0;
900
+ for (int m = 0; m < M; m++) {
901
+ int64_t n = ii & ((1L << nbit) - 1);
902
+ ti += int64_t(ssx[m].get_ord(n)) << (nbit * m);
903
+ ii >>= nbit;
904
+ }
905
+ terms[k] = ti;
906
+ }
907
+ }
908
+
909
+
910
+ void enqueue_follower (int64_t ti, int m, int n, T sum) {
911
+ T next_sum = sum + ssx[m].get_diff(n + 1);
912
+ int64_t next_ti = ti + weight(m);
913
+ heap_push<HC> (++heap_size, bh_val, bh_ids, next_sum, next_ti);
914
+ }
915
+
916
+ ~MinSumK () {
917
+ delete [] bh_ids;
918
+ delete [] bh_val;
919
+ }
920
+ };
921
+
922
+ } // anonymous namespace
923
+
924
+
925
+ MultiIndexQuantizer::MultiIndexQuantizer (int d,
926
+ size_t M,
927
+ size_t nbits):
928
+ Index(d, METRIC_L2), pq(d, M, nbits)
929
+ {
930
+ is_trained = false;
931
+ pq.verbose = verbose;
932
+ }
933
+
934
+
935
+
936
+ void MultiIndexQuantizer::train(idx_t n, const float *x)
937
+ {
938
+ pq.verbose = verbose;
939
+ pq.train (n, x);
940
+ is_trained = true;
941
+ // count virtual elements in index
942
+ ntotal = 1;
943
+ for (int m = 0; m < pq.M; m++)
944
+ ntotal *= pq.ksub;
945
+ }
946
+
947
+
948
+ void MultiIndexQuantizer::search (idx_t n, const float *x, idx_t k,
949
+ float *distances, idx_t *labels) const {
950
+ if (n == 0) return;
951
+
952
+ // the allocation just below can be severe...
953
+ idx_t bs = 32768;
954
+ if (n > bs) {
955
+ for (idx_t i0 = 0; i0 < n; i0 += bs) {
956
+ idx_t i1 = std::min(i0 + bs, n);
957
+ if (verbose) {
958
+ printf("MultiIndexQuantizer::search: %ld:%ld / %ld\n",
959
+ i0, i1, n);
960
+ }
961
+ search (i1 - i0, x + i0 * d, k,
962
+ distances + i0 * k,
963
+ labels + i0 * k);
964
+ }
965
+ return;
966
+ }
967
+
968
+ float * dis_tables = new float [n * pq.ksub * pq.M];
969
+ ScopeDeleter<float> del (dis_tables);
970
+
971
+ pq.compute_distance_tables (n, x, dis_tables);
972
+
973
+ if (k == 1) {
974
+ // simple version that just finds the min in each table
975
+
976
+ #pragma omp parallel for
977
+ for (int i = 0; i < n; i++) {
978
+ const float * dis_table = dis_tables + i * pq.ksub * pq.M;
979
+ float dis = 0;
980
+ idx_t label = 0;
981
+
982
+ for (int s = 0; s < pq.M; s++) {
983
+ float vmin = HUGE_VALF;
984
+ idx_t lmin = -1;
985
+
986
+ for (idx_t j = 0; j < pq.ksub; j++) {
987
+ if (dis_table[j] < vmin) {
988
+ vmin = dis_table[j];
989
+ lmin = j;
990
+ }
991
+ }
992
+ dis += vmin;
993
+ label |= lmin << (s * pq.nbits);
994
+ dis_table += pq.ksub;
995
+ }
996
+
997
+ distances [i] = dis;
998
+ labels [i] = label;
999
+ }
1000
+
1001
+
1002
+ } else {
1003
+
1004
+ #pragma omp parallel if(n > 1)
1005
+ {
1006
+ MinSumK <float, SemiSortedArray<float>, false>
1007
+ msk(k, pq.M, pq.nbits, pq.ksub);
1008
+ #pragma omp for
1009
+ for (int i = 0; i < n; i++) {
1010
+ msk.run (dis_tables + i * pq.ksub * pq.M, pq.ksub,
1011
+ distances + i * k, labels + i * k);
1012
+
1013
+ }
1014
+ }
1015
+ }
1016
+
1017
+ }
1018
+
1019
+
1020
+ void MultiIndexQuantizer::reconstruct (idx_t key, float * recons) const
1021
+ {
1022
+
1023
+ int64_t jj = key;
1024
+ for (int m = 0; m < pq.M; m++) {
1025
+ int64_t n = jj & ((1L << pq.nbits) - 1);
1026
+ jj >>= pq.nbits;
1027
+ memcpy(recons, pq.get_centroids(m, n), sizeof(recons[0]) * pq.dsub);
1028
+ recons += pq.dsub;
1029
+ }
1030
+ }
1031
+
1032
+ void MultiIndexQuantizer::add(idx_t /*n*/, const float* /*x*/) {
1033
+ FAISS_THROW_MSG(
1034
+ "This index has virtual elements, "
1035
+ "it does not support add");
1036
+ }
1037
+
1038
+ void MultiIndexQuantizer::reset ()
1039
+ {
1040
+ FAISS_THROW_MSG ( "This index has virtual elements, "
1041
+ "it does not support reset");
1042
+ }
1043
+
1044
+
1045
+
1046
+
1047
+
1048
+
1049
+
1050
+
1051
+
1052
+
1053
+ /*****************************************
1054
+ * MultiIndexQuantizer2
1055
+ ******************************************/
1056
+
1057
+
1058
+
1059
+ MultiIndexQuantizer2::MultiIndexQuantizer2 (
1060
+ int d, size_t M, size_t nbits,
1061
+ Index **indexes):
1062
+ MultiIndexQuantizer (d, M, nbits)
1063
+ {
1064
+ assign_indexes.resize (M);
1065
+ for (int i = 0; i < M; i++) {
1066
+ FAISS_THROW_IF_NOT_MSG(
1067
+ indexes[i]->d == pq.dsub,
1068
+ "Provided sub-index has incorrect size");
1069
+ assign_indexes[i] = indexes[i];
1070
+ }
1071
+ own_fields = false;
1072
+ }
1073
+
1074
+ MultiIndexQuantizer2::MultiIndexQuantizer2 (
1075
+ int d, size_t nbits,
1076
+ Index *assign_index_0,
1077
+ Index *assign_index_1):
1078
+ MultiIndexQuantizer (d, 2, nbits)
1079
+ {
1080
+ FAISS_THROW_IF_NOT_MSG(
1081
+ assign_index_0->d == pq.dsub &&
1082
+ assign_index_1->d == pq.dsub,
1083
+ "Provided sub-index has incorrect size");
1084
+ assign_indexes.resize (2);
1085
+ assign_indexes [0] = assign_index_0;
1086
+ assign_indexes [1] = assign_index_1;
1087
+ own_fields = false;
1088
+ }
1089
+
1090
+ void MultiIndexQuantizer2::train(idx_t n, const float* x)
1091
+ {
1092
+ MultiIndexQuantizer::train(n, x);
1093
+ // add centroids to sub-indexes
1094
+ for (int i = 0; i < pq.M; i++) {
1095
+ assign_indexes[i]->add(pq.ksub, pq.get_centroids(i, 0));
1096
+ }
1097
+ }
1098
+
1099
+
1100
+ void MultiIndexQuantizer2::search(
1101
+ idx_t n, const float* x, idx_t K,
1102
+ float* distances, idx_t* labels) const
1103
+ {
1104
+
1105
+ if (n == 0) return;
1106
+
1107
+ int k2 = std::min(K, int64_t(pq.ksub));
1108
+
1109
+ int64_t M = pq.M;
1110
+ int64_t dsub = pq.dsub, ksub = pq.ksub;
1111
+
1112
+ // size (M, n, k2)
1113
+ std::vector<idx_t> sub_ids(n * M * k2);
1114
+ std::vector<float> sub_dis(n * M * k2);
1115
+ std::vector<float> xsub(n * dsub);
1116
+
1117
+ for (int m = 0; m < M; m++) {
1118
+ float *xdest = xsub.data();
1119
+ const float *xsrc = x + m * dsub;
1120
+ for (int j = 0; j < n; j++) {
1121
+ memcpy(xdest, xsrc, dsub * sizeof(xdest[0]));
1122
+ xsrc += d;
1123
+ xdest += dsub;
1124
+ }
1125
+
1126
+ assign_indexes[m]->search(
1127
+ n, xsub.data(), k2,
1128
+ &sub_dis[k2 * n * m],
1129
+ &sub_ids[k2 * n * m]);
1130
+ }
1131
+
1132
+ if (K == 1) {
1133
+ // simple version that just finds the min in each table
1134
+ assert (k2 == 1);
1135
+
1136
+ for (int i = 0; i < n; i++) {
1137
+ float dis = 0;
1138
+ idx_t label = 0;
1139
+
1140
+ for (int m = 0; m < M; m++) {
1141
+ float vmin = sub_dis[i + m * n];
1142
+ idx_t lmin = sub_ids[i + m * n];
1143
+ dis += vmin;
1144
+ label |= lmin << (m * pq.nbits);
1145
+ }
1146
+ distances [i] = dis;
1147
+ labels [i] = label;
1148
+ }
1149
+
1150
+ } else {
1151
+
1152
+ #pragma omp parallel if(n > 1)
1153
+ {
1154
+ MinSumK <float, PreSortedArray<float>, false>
1155
+ msk(K, pq.M, pq.nbits, k2);
1156
+ #pragma omp for
1157
+ for (int i = 0; i < n; i++) {
1158
+ idx_t *li = labels + i * K;
1159
+ msk.run (&sub_dis[i * k2], k2 * n,
1160
+ distances + i * K, li);
1161
+
1162
+ // remap ids
1163
+
1164
+ const idx_t *idmap0 = sub_ids.data() + i * k2;
1165
+ int64_t ld_idmap = k2 * n;
1166
+ int64_t mask1 = ksub - 1L;
1167
+
1168
+ for (int k = 0; k < K; k++) {
1169
+ const idx_t *idmap = idmap0;
1170
+ int64_t vin = li[k];
1171
+ int64_t vout = 0;
1172
+ int bs = 0;
1173
+ for (int m = 0; m < M; m++) {
1174
+ int64_t s = vin & mask1;
1175
+ vin >>= pq.nbits;
1176
+ vout |= idmap[s] << bs;
1177
+ bs += pq.nbits;
1178
+ idmap += ld_idmap;
1179
+ }
1180
+ li[k] = vout;
1181
+ }
1182
+ }
1183
+ }
1184
+ }
1185
+ }
1186
+
1187
+
1188
+ } // namespace faiss