faiss 0.1.0 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +103 -3
  4. data/ext/faiss/ext.cpp +99 -32
  5. data/ext/faiss/extconf.rb +12 -2
  6. data/lib/faiss/ext.bundle +0 -0
  7. data/lib/faiss/index.rb +3 -3
  8. data/lib/faiss/index_binary.rb +3 -3
  9. data/lib/faiss/kmeans.rb +1 -1
  10. data/lib/faiss/pca_matrix.rb +2 -2
  11. data/lib/faiss/product_quantizer.rb +3 -3
  12. data/lib/faiss/version.rb +1 -1
  13. data/vendor/faiss/AutoTune.cpp +719 -0
  14. data/vendor/faiss/AutoTune.h +212 -0
  15. data/vendor/faiss/Clustering.cpp +261 -0
  16. data/vendor/faiss/Clustering.h +101 -0
  17. data/vendor/faiss/IVFlib.cpp +339 -0
  18. data/vendor/faiss/IVFlib.h +132 -0
  19. data/vendor/faiss/Index.cpp +171 -0
  20. data/vendor/faiss/Index.h +261 -0
  21. data/vendor/faiss/Index2Layer.cpp +437 -0
  22. data/vendor/faiss/Index2Layer.h +85 -0
  23. data/vendor/faiss/IndexBinary.cpp +77 -0
  24. data/vendor/faiss/IndexBinary.h +163 -0
  25. data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
  26. data/vendor/faiss/IndexBinaryFlat.h +54 -0
  27. data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
  28. data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
  29. data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
  30. data/vendor/faiss/IndexBinaryHNSW.h +56 -0
  31. data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
  32. data/vendor/faiss/IndexBinaryIVF.h +211 -0
  33. data/vendor/faiss/IndexFlat.cpp +508 -0
  34. data/vendor/faiss/IndexFlat.h +175 -0
  35. data/vendor/faiss/IndexHNSW.cpp +1090 -0
  36. data/vendor/faiss/IndexHNSW.h +170 -0
  37. data/vendor/faiss/IndexIVF.cpp +909 -0
  38. data/vendor/faiss/IndexIVF.h +353 -0
  39. data/vendor/faiss/IndexIVFFlat.cpp +502 -0
  40. data/vendor/faiss/IndexIVFFlat.h +118 -0
  41. data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
  42. data/vendor/faiss/IndexIVFPQ.h +161 -0
  43. data/vendor/faiss/IndexIVFPQR.cpp +219 -0
  44. data/vendor/faiss/IndexIVFPQR.h +65 -0
  45. data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
  46. data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
  47. data/vendor/faiss/IndexLSH.cpp +225 -0
  48. data/vendor/faiss/IndexLSH.h +87 -0
  49. data/vendor/faiss/IndexLattice.cpp +143 -0
  50. data/vendor/faiss/IndexLattice.h +68 -0
  51. data/vendor/faiss/IndexPQ.cpp +1188 -0
  52. data/vendor/faiss/IndexPQ.h +199 -0
  53. data/vendor/faiss/IndexPreTransform.cpp +288 -0
  54. data/vendor/faiss/IndexPreTransform.h +91 -0
  55. data/vendor/faiss/IndexReplicas.cpp +123 -0
  56. data/vendor/faiss/IndexReplicas.h +76 -0
  57. data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
  58. data/vendor/faiss/IndexScalarQuantizer.h +127 -0
  59. data/vendor/faiss/IndexShards.cpp +317 -0
  60. data/vendor/faiss/IndexShards.h +100 -0
  61. data/vendor/faiss/InvertedLists.cpp +623 -0
  62. data/vendor/faiss/InvertedLists.h +334 -0
  63. data/vendor/faiss/LICENSE +21 -0
  64. data/vendor/faiss/MatrixStats.cpp +252 -0
  65. data/vendor/faiss/MatrixStats.h +62 -0
  66. data/vendor/faiss/MetaIndexes.cpp +351 -0
  67. data/vendor/faiss/MetaIndexes.h +126 -0
  68. data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
  69. data/vendor/faiss/OnDiskInvertedLists.h +127 -0
  70. data/vendor/faiss/VectorTransform.cpp +1157 -0
  71. data/vendor/faiss/VectorTransform.h +322 -0
  72. data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
  73. data/vendor/faiss/c_api/AutoTune_c.h +64 -0
  74. data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
  75. data/vendor/faiss/c_api/Clustering_c.h +117 -0
  76. data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
  77. data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
  78. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
  79. data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
  80. data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
  81. data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
  82. data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
  83. data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
  84. data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
  85. data/vendor/faiss/c_api/IndexShards_c.h +42 -0
  86. data/vendor/faiss/c_api/Index_c.cpp +105 -0
  87. data/vendor/faiss/c_api/Index_c.h +183 -0
  88. data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
  89. data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
  90. data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
  91. data/vendor/faiss/c_api/clone_index_c.h +32 -0
  92. data/vendor/faiss/c_api/error_c.h +42 -0
  93. data/vendor/faiss/c_api/error_impl.cpp +27 -0
  94. data/vendor/faiss/c_api/error_impl.h +16 -0
  95. data/vendor/faiss/c_api/faiss_c.h +58 -0
  96. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
  97. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
  98. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
  99. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
  100. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
  101. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
  102. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
  103. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
  104. data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
  105. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
  106. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
  107. data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
  108. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
  109. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
  110. data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
  111. data/vendor/faiss/c_api/index_factory_c.h +30 -0
  112. data/vendor/faiss/c_api/index_io_c.cpp +42 -0
  113. data/vendor/faiss/c_api/index_io_c.h +50 -0
  114. data/vendor/faiss/c_api/macros_impl.h +110 -0
  115. data/vendor/faiss/clone_index.cpp +147 -0
  116. data/vendor/faiss/clone_index.h +38 -0
  117. data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
  118. data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
  119. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
  120. data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
  121. data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
  122. data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
  123. data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
  124. data/vendor/faiss/gpu/GpuCloner.h +82 -0
  125. data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
  126. data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
  127. data/vendor/faiss/gpu/GpuDistance.h +52 -0
  128. data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
  129. data/vendor/faiss/gpu/GpuIndex.h +148 -0
  130. data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
  131. data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
  132. data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
  133. data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
  134. data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
  135. data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
  136. data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
  137. data/vendor/faiss/gpu/GpuResources.cpp +52 -0
  138. data/vendor/faiss/gpu/GpuResources.h +73 -0
  139. data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
  140. data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
  141. data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
  142. data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
  143. data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
  144. data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
  145. data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
  146. data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
  147. data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
  148. data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
  149. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
  150. data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
  151. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
  152. data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
  153. data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
  154. data/vendor/faiss/gpu/test/TestUtils.h +93 -0
  155. data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
  156. data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
  157. data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
  158. data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
  159. data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
  160. data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
  161. data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
  162. data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
  163. data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
  164. data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
  165. data/vendor/faiss/gpu/utils/Timer.h +52 -0
  166. data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
  167. data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
  168. data/vendor/faiss/impl/FaissAssert.h +95 -0
  169. data/vendor/faiss/impl/FaissException.cpp +66 -0
  170. data/vendor/faiss/impl/FaissException.h +71 -0
  171. data/vendor/faiss/impl/HNSW.cpp +818 -0
  172. data/vendor/faiss/impl/HNSW.h +275 -0
  173. data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
  174. data/vendor/faiss/impl/PolysemousTraining.h +158 -0
  175. data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
  176. data/vendor/faiss/impl/ProductQuantizer.h +242 -0
  177. data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
  178. data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
  179. data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
  180. data/vendor/faiss/impl/ThreadedIndex.h +80 -0
  181. data/vendor/faiss/impl/index_read.cpp +793 -0
  182. data/vendor/faiss/impl/index_write.cpp +558 -0
  183. data/vendor/faiss/impl/io.cpp +142 -0
  184. data/vendor/faiss/impl/io.h +98 -0
  185. data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
  186. data/vendor/faiss/impl/lattice_Zn.h +199 -0
  187. data/vendor/faiss/index_factory.cpp +392 -0
  188. data/vendor/faiss/index_factory.h +25 -0
  189. data/vendor/faiss/index_io.h +75 -0
  190. data/vendor/faiss/misc/test_blas.cpp +84 -0
  191. data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
  192. data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
  193. data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
  194. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
  195. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
  196. data/vendor/faiss/tests/test_merge.cpp +258 -0
  197. data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
  198. data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
  199. data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
  200. data/vendor/faiss/tests/test_params_override.cpp +231 -0
  201. data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
  202. data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
  203. data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
  204. data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
  205. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
  206. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
  207. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
  208. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
  209. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
  210. data/vendor/faiss/utils/Heap.cpp +122 -0
  211. data/vendor/faiss/utils/Heap.h +495 -0
  212. data/vendor/faiss/utils/WorkerThread.cpp +126 -0
  213. data/vendor/faiss/utils/WorkerThread.h +61 -0
  214. data/vendor/faiss/utils/distances.cpp +765 -0
  215. data/vendor/faiss/utils/distances.h +243 -0
  216. data/vendor/faiss/utils/distances_simd.cpp +809 -0
  217. data/vendor/faiss/utils/extra_distances.cpp +336 -0
  218. data/vendor/faiss/utils/extra_distances.h +54 -0
  219. data/vendor/faiss/utils/hamming-inl.h +472 -0
  220. data/vendor/faiss/utils/hamming.cpp +792 -0
  221. data/vendor/faiss/utils/hamming.h +220 -0
  222. data/vendor/faiss/utils/random.cpp +192 -0
  223. data/vendor/faiss/utils/random.h +60 -0
  224. data/vendor/faiss/utils/utils.cpp +783 -0
  225. data/vendor/faiss/utils/utils.h +181 -0
  226. metadata +216 -2
@@ -0,0 +1,68 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #ifndef FAISS_INDEX_LATTICE_H
11
+ #define FAISS_INDEX_LATTICE_H
12
+
13
+
14
+ #include <vector>
15
+
16
+ #include <faiss/IndexIVF.h>
17
+ #include <faiss/impl/lattice_Zn.h>
18
+
19
+ namespace faiss {
20
+
21
+
22
+
23
+
24
+
25
+ /** Index that encodes a vector with a series of Zn lattice quantizers
26
+ */
27
+ struct IndexLattice: Index {
28
+
29
+ /// number of sub-vectors
30
+ int nsq;
31
+ /// dimension of sub-vectors
32
+ size_t dsq;
33
+
34
+ /// the lattice quantizer
35
+ ZnSphereCodecAlt zn_sphere_codec;
36
+
37
+ /// nb bits used to encode the scale, per subvector
38
+ int scale_nbit, lattice_nbit;
39
+ /// total, in bytes
40
+ size_t code_size;
41
+
42
+ /// mins and maxes of the vector norms, per subquantizer
43
+ std::vector<float> trained;
44
+
45
+ IndexLattice (idx_t d, int nsq, int scale_nbit, int r2);
46
+
47
+ void train(idx_t n, const float* x) override;
48
+
49
+ /* The standalone codec interface */
50
+ size_t sa_code_size () const override;
51
+
52
+ void sa_encode (idx_t n, const float *x,
53
+ uint8_t *bytes) const override;
54
+
55
+ void sa_decode (idx_t n, const uint8_t *bytes,
56
+ float *x) const override;
57
+
58
+ /// not implemented
59
+ void add(idx_t n, const float* x) override;
60
+ void search(idx_t n, const float* x, idx_t k,
61
+ float* distances, idx_t* labels) const override;
62
+ void reset() override;
63
+
64
+ };
65
+
66
+ } // namespace faiss
67
+
68
+ #endif
@@ -0,0 +1,1188 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/IndexPQ.h>
11
+
12
+
13
+ #include <cstddef>
14
+ #include <cstring>
15
+ #include <cstdio>
16
+ #include <cmath>
17
+
18
+ #include <algorithm>
19
+
20
+ #include <faiss/impl/FaissAssert.h>
21
+ #include <faiss/impl/AuxIndexStructures.h>
22
+ #include <faiss/utils/hamming.h>
23
+
24
+ namespace faiss {
25
+
26
+ /*********************************************************
27
+ * IndexPQ implementation
28
+ ********************************************************/
29
+
30
+
31
+ IndexPQ::IndexPQ (int d, size_t M, size_t nbits, MetricType metric):
32
+ Index(d, metric), pq(d, M, nbits)
33
+ {
34
+ is_trained = false;
35
+ do_polysemous_training = false;
36
+ polysemous_ht = nbits * M + 1;
37
+ search_type = ST_PQ;
38
+ encode_signs = false;
39
+ }
40
+
41
+ IndexPQ::IndexPQ ()
42
+ {
43
+ metric_type = METRIC_L2;
44
+ is_trained = false;
45
+ do_polysemous_training = false;
46
+ polysemous_ht = pq.nbits * pq.M + 1;
47
+ search_type = ST_PQ;
48
+ encode_signs = false;
49
+ }
50
+
51
+
52
+ void IndexPQ::train (idx_t n, const float *x)
53
+ {
54
+ if (!do_polysemous_training) { // standard training
55
+ pq.train(n, x);
56
+ } else {
57
+ idx_t ntrain_perm = polysemous_training.ntrain_permutation;
58
+
59
+ if (ntrain_perm > n / 4)
60
+ ntrain_perm = n / 4;
61
+ if (verbose) {
62
+ printf ("PQ training on %ld points, remains %ld points: "
63
+ "training polysemous on %s\n",
64
+ n - ntrain_perm, ntrain_perm,
65
+ ntrain_perm == 0 ? "centroids" : "these");
66
+ }
67
+ pq.train(n - ntrain_perm, x);
68
+
69
+ polysemous_training.optimize_pq_for_hamming (
70
+ pq, ntrain_perm, x + (n - ntrain_perm) * d);
71
+ }
72
+ is_trained = true;
73
+ }
74
+
75
+
76
+ void IndexPQ::add (idx_t n, const float *x)
77
+ {
78
+ FAISS_THROW_IF_NOT (is_trained);
79
+ codes.resize ((n + ntotal) * pq.code_size);
80
+ pq.compute_codes (x, &codes[ntotal * pq.code_size], n);
81
+ ntotal += n;
82
+ }
83
+
84
+
85
+ size_t IndexPQ::remove_ids (const IDSelector & sel)
86
+ {
87
+ idx_t j = 0;
88
+ for (idx_t i = 0; i < ntotal; i++) {
89
+ if (sel.is_member (i)) {
90
+ // should be removed
91
+ } else {
92
+ if (i > j) {
93
+ memmove (&codes[pq.code_size * j], &codes[pq.code_size * i], pq.code_size);
94
+ }
95
+ j++;
96
+ }
97
+ }
98
+ size_t nremove = ntotal - j;
99
+ if (nremove > 0) {
100
+ ntotal = j;
101
+ codes.resize (ntotal * pq.code_size);
102
+ }
103
+ return nremove;
104
+ }
105
+
106
+
107
+ void IndexPQ::reset()
108
+ {
109
+ codes.clear();
110
+ ntotal = 0;
111
+ }
112
+
113
+ void IndexPQ::reconstruct_n (idx_t i0, idx_t ni, float *recons) const
114
+ {
115
+ FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
116
+ for (idx_t i = 0; i < ni; i++) {
117
+ const uint8_t * code = &codes[(i0 + i) * pq.code_size];
118
+ pq.decode (code, recons + i * d);
119
+ }
120
+ }
121
+
122
+
123
+ void IndexPQ::reconstruct (idx_t key, float * recons) const
124
+ {
125
+ FAISS_THROW_IF_NOT (key >= 0 && key < ntotal);
126
+ pq.decode (&codes[key * pq.code_size], recons);
127
+ }
128
+
129
+
130
+ namespace {
131
+
132
+
133
+ struct PQDis: DistanceComputer {
134
+ size_t d;
135
+ Index::idx_t nb;
136
+ const uint8_t *codes;
137
+ size_t code_size;
138
+ const ProductQuantizer & pq;
139
+ const float *sdc;
140
+ std::vector<float> precomputed_table;
141
+ size_t ndis;
142
+
143
+ float operator () (idx_t i) override
144
+ {
145
+ const uint8_t *code = codes + i * code_size;
146
+ const float *dt = precomputed_table.data();
147
+ float accu = 0;
148
+ for (int j = 0; j < pq.M; j++) {
149
+ accu += dt[*code++];
150
+ dt += 256;
151
+ }
152
+ ndis++;
153
+ return accu;
154
+ }
155
+
156
+ float symmetric_dis(idx_t i, idx_t j) override
157
+ {
158
+ const float * sdci = sdc;
159
+ float accu = 0;
160
+ const uint8_t *codei = codes + i * code_size;
161
+ const uint8_t *codej = codes + j * code_size;
162
+
163
+ for (int l = 0; l < pq.M; l++) {
164
+ accu += sdci[(*codei++) + (*codej++) * 256];
165
+ sdci += 256 * 256;
166
+ }
167
+ return accu;
168
+ }
169
+
170
+ explicit PQDis(const IndexPQ& storage, const float* /*q*/ = nullptr)
171
+ : pq(storage.pq) {
172
+ precomputed_table.resize(pq.M * pq.ksub);
173
+ nb = storage.ntotal;
174
+ d = storage.d;
175
+ codes = storage.codes.data();
176
+ code_size = pq.code_size;
177
+ FAISS_ASSERT(pq.ksub == 256);
178
+ FAISS_ASSERT(pq.sdc_table.size() == pq.ksub * pq.ksub * pq.M);
179
+ sdc = pq.sdc_table.data();
180
+ ndis = 0;
181
+ }
182
+
183
+ void set_query(const float *x) override {
184
+ pq.compute_distance_table(x, precomputed_table.data());
185
+ }
186
+ };
187
+
188
+
189
+ } // namespace
190
+
191
+
192
+ DistanceComputer * IndexPQ::get_distance_computer() const {
193
+ FAISS_THROW_IF_NOT(pq.nbits == 8);
194
+ return new PQDis(*this);
195
+ }
196
+
197
+
198
+ /*****************************************
199
+ * IndexPQ polysemous search routines
200
+ ******************************************/
201
+
202
+
203
+
204
+
205
+
206
+ void IndexPQ::search (idx_t n, const float *x, idx_t k,
207
+ float *distances, idx_t *labels) const
208
+ {
209
+ FAISS_THROW_IF_NOT (is_trained);
210
+ if (search_type == ST_PQ) { // Simple PQ search
211
+
212
+ if (metric_type == METRIC_L2) {
213
+ float_maxheap_array_t res = {
214
+ size_t(n), size_t(k), labels, distances };
215
+ pq.search (x, n, codes.data(), ntotal, &res, true);
216
+ } else {
217
+ float_minheap_array_t res = {
218
+ size_t(n), size_t(k), labels, distances };
219
+ pq.search_ip (x, n, codes.data(), ntotal, &res, true);
220
+ }
221
+ indexPQ_stats.nq += n;
222
+ indexPQ_stats.ncode += n * ntotal;
223
+
224
+ } else if (search_type == ST_polysemous ||
225
+ search_type == ST_polysemous_generalize) {
226
+
227
+ FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
228
+
229
+ search_core_polysemous (n, x, k, distances, labels);
230
+
231
+ } else { // code-to-code distances
232
+
233
+ uint8_t * q_codes = new uint8_t [n * pq.code_size];
234
+ ScopeDeleter<uint8_t> del (q_codes);
235
+
236
+
237
+ if (!encode_signs) {
238
+ pq.compute_codes (x, q_codes, n);
239
+ } else {
240
+ FAISS_THROW_IF_NOT (d == pq.nbits * pq.M);
241
+ memset (q_codes, 0, n * pq.code_size);
242
+ for (size_t i = 0; i < n; i++) {
243
+ const float *xi = x + i * d;
244
+ uint8_t *code = q_codes + i * pq.code_size;
245
+ for (int j = 0; j < d; j++)
246
+ if (xi[j] > 0) code [j>>3] |= 1 << (j & 7);
247
+ }
248
+ }
249
+
250
+ if (search_type == ST_SDC) {
251
+
252
+ float_maxheap_array_t res = {
253
+ size_t(n), size_t(k), labels, distances};
254
+
255
+ pq.search_sdc (q_codes, n, codes.data(), ntotal, &res, true);
256
+
257
+ } else {
258
+ int * idistances = new int [n * k];
259
+ ScopeDeleter<int> del (idistances);
260
+
261
+ int_maxheap_array_t res = {
262
+ size_t (n), size_t (k), labels, idistances};
263
+
264
+ if (search_type == ST_HE) {
265
+
266
+ hammings_knn_hc (&res, q_codes, codes.data(),
267
+ ntotal, pq.code_size, true);
268
+
269
+ } else if (search_type == ST_generalized_HE) {
270
+
271
+ generalized_hammings_knn_hc (&res, q_codes, codes.data(),
272
+ ntotal, pq.code_size, true);
273
+ }
274
+
275
+ // convert distances to floats
276
+ for (int i = 0; i < k * n; i++)
277
+ distances[i] = idistances[i];
278
+
279
+ }
280
+
281
+
282
+ indexPQ_stats.nq += n;
283
+ indexPQ_stats.ncode += n * ntotal;
284
+ }
285
+ }
286
+
287
+
288
+
289
+
290
+
291
+ void IndexPQStats::reset()
292
+ {
293
+ nq = ncode = n_hamming_pass = 0;
294
+ }
295
+
296
+ IndexPQStats indexPQ_stats;
297
+
298
+
299
+ template <class HammingComputer>
300
+ static size_t polysemous_inner_loop (
301
+ const IndexPQ & index,
302
+ const float *dis_table_qi, const uint8_t *q_code,
303
+ size_t k, float *heap_dis, int64_t *heap_ids)
304
+ {
305
+
306
+ int M = index.pq.M;
307
+ int code_size = index.pq.code_size;
308
+ int ksub = index.pq.ksub;
309
+ size_t ntotal = index.ntotal;
310
+ int ht = index.polysemous_ht;
311
+
312
+ const uint8_t *b_code = index.codes.data();
313
+
314
+ size_t n_pass_i = 0;
315
+
316
+ HammingComputer hc (q_code, code_size);
317
+
318
+ for (int64_t bi = 0; bi < ntotal; bi++) {
319
+ int hd = hc.hamming (b_code);
320
+
321
+ if (hd < ht) {
322
+ n_pass_i ++;
323
+
324
+ float dis = 0;
325
+ const float * dis_table = dis_table_qi;
326
+ for (int m = 0; m < M; m++) {
327
+ dis += dis_table [b_code[m]];
328
+ dis_table += ksub;
329
+ }
330
+
331
+ if (dis < heap_dis[0]) {
332
+ maxheap_pop (k, heap_dis, heap_ids);
333
+ maxheap_push (k, heap_dis, heap_ids, dis, bi);
334
+ }
335
+ }
336
+ b_code += code_size;
337
+ }
338
+ return n_pass_i;
339
+ }
340
+
341
+
342
+ void IndexPQ::search_core_polysemous (idx_t n, const float *x, idx_t k,
343
+ float *distances, idx_t *labels) const
344
+ {
345
+ FAISS_THROW_IF_NOT (pq.nbits == 8);
346
+
347
+ // PQ distance tables
348
+ float * dis_tables = new float [n * pq.ksub * pq.M];
349
+ ScopeDeleter<float> del (dis_tables);
350
+ pq.compute_distance_tables (n, x, dis_tables);
351
+
352
+ // Hamming embedding queries
353
+ uint8_t * q_codes = new uint8_t [n * pq.code_size];
354
+ ScopeDeleter<uint8_t> del2 (q_codes);
355
+
356
+ if (false) {
357
+ pq.compute_codes (x, q_codes, n);
358
+ } else {
359
+ #pragma omp parallel for
360
+ for (idx_t qi = 0; qi < n; qi++) {
361
+ pq.compute_code_from_distance_table
362
+ (dis_tables + qi * pq.M * pq.ksub,
363
+ q_codes + qi * pq.code_size);
364
+ }
365
+ }
366
+
367
+ size_t n_pass = 0;
368
+
369
+ #pragma omp parallel for reduction (+: n_pass)
370
+ for (idx_t qi = 0; qi < n; qi++) {
371
+ const uint8_t * q_code = q_codes + qi * pq.code_size;
372
+
373
+ const float * dis_table_qi = dis_tables + qi * pq.M * pq.ksub;
374
+
375
+ int64_t * heap_ids = labels + qi * k;
376
+ float *heap_dis = distances + qi * k;
377
+ maxheap_heapify (k, heap_dis, heap_ids);
378
+
379
+ if (search_type == ST_polysemous) {
380
+
381
+ switch (pq.code_size) {
382
+ case 4:
383
+ n_pass += polysemous_inner_loop<HammingComputer4>
384
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
385
+ break;
386
+ case 8:
387
+ n_pass += polysemous_inner_loop<HammingComputer8>
388
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
389
+ break;
390
+ case 16:
391
+ n_pass += polysemous_inner_loop<HammingComputer16>
392
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
393
+ break;
394
+ case 32:
395
+ n_pass += polysemous_inner_loop<HammingComputer32>
396
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
397
+ break;
398
+ case 20:
399
+ n_pass += polysemous_inner_loop<HammingComputer20>
400
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
401
+ break;
402
+ default:
403
+ if (pq.code_size % 8 == 0) {
404
+ n_pass += polysemous_inner_loop<HammingComputerM8>
405
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
406
+ } else if (pq.code_size % 4 == 0) {
407
+ n_pass += polysemous_inner_loop<HammingComputerM4>
408
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
409
+ } else {
410
+ FAISS_THROW_FMT(
411
+ "code size %zd not supported for polysemous",
412
+ pq.code_size);
413
+ }
414
+ break;
415
+ }
416
+ } else {
417
+ switch (pq.code_size) {
418
+ case 8:
419
+ n_pass += polysemous_inner_loop<GenHammingComputer8>
420
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
421
+ break;
422
+ case 16:
423
+ n_pass += polysemous_inner_loop<GenHammingComputer16>
424
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
425
+ break;
426
+ case 32:
427
+ n_pass += polysemous_inner_loop<GenHammingComputer32>
428
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
429
+ break;
430
+ default:
431
+ if (pq.code_size % 8 == 0) {
432
+ n_pass += polysemous_inner_loop<GenHammingComputerM8>
433
+ (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
434
+ } else {
435
+ FAISS_THROW_FMT(
436
+ "code size %zd not supported for polysemous",
437
+ pq.code_size);
438
+ }
439
+ break;
440
+ }
441
+ }
442
+ maxheap_reorder (k, heap_dis, heap_ids);
443
+ }
444
+
445
+ indexPQ_stats.nq += n;
446
+ indexPQ_stats.ncode += n * ntotal;
447
+ indexPQ_stats.n_hamming_pass += n_pass;
448
+
449
+
450
+ }
451
+
452
+
453
+ /* The standalone codec interface (just remaps to the PQ functions) */
454
+ size_t IndexPQ::sa_code_size () const
455
+ {
456
+ return pq.code_size;
457
+ }
458
+
459
+ void IndexPQ::sa_encode (idx_t n, const float *x, uint8_t *bytes) const
460
+ {
461
+ pq.compute_codes (x, bytes, n);
462
+ }
463
+
464
+ void IndexPQ::sa_decode (idx_t n, const uint8_t *bytes, float *x) const
465
+ {
466
+ pq.decode (bytes, x, n);
467
+ }
468
+
469
+
470
+
471
+
472
+ /*****************************************
473
+ * Stats of IndexPQ codes
474
+ ******************************************/
475
+
476
+
477
+
478
+
479
+ void IndexPQ::hamming_distance_table (idx_t n, const float *x,
480
+ int32_t *dis) const
481
+ {
482
+ uint8_t * q_codes = new uint8_t [n * pq.code_size];
483
+ ScopeDeleter<uint8_t> del (q_codes);
484
+
485
+ pq.compute_codes (x, q_codes, n);
486
+
487
+ hammings (q_codes, codes.data(), n, ntotal, pq.code_size, dis);
488
+ }
489
+
490
+
491
+ void IndexPQ::hamming_distance_histogram (idx_t n, const float *x,
492
+ idx_t nb, const float *xb,
493
+ int64_t *hist)
494
+ {
495
+ FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
496
+ FAISS_THROW_IF_NOT (pq.code_size % 8 == 0);
497
+ FAISS_THROW_IF_NOT (pq.nbits == 8);
498
+
499
+ // Hamming embedding queries
500
+ uint8_t * q_codes = new uint8_t [n * pq.code_size];
501
+ ScopeDeleter <uint8_t> del (q_codes);
502
+ pq.compute_codes (x, q_codes, n);
503
+
504
+ uint8_t * b_codes ;
505
+ ScopeDeleter <uint8_t> del_b_codes;
506
+
507
+ if (xb) {
508
+ b_codes = new uint8_t [nb * pq.code_size];
509
+ del_b_codes.set (b_codes);
510
+ pq.compute_codes (xb, b_codes, nb);
511
+ } else {
512
+ nb = ntotal;
513
+ b_codes = codes.data();
514
+ }
515
+ int nbits = pq.M * pq.nbits;
516
+ memset (hist, 0, sizeof(*hist) * (nbits + 1));
517
+ size_t bs = 256;
518
+
519
+ #pragma omp parallel
520
+ {
521
+ std::vector<int64_t> histi (nbits + 1);
522
+ hamdis_t *distances = new hamdis_t [nb * bs];
523
+ ScopeDeleter<hamdis_t> del (distances);
524
+ #pragma omp for
525
+ for (size_t q0 = 0; q0 < n; q0 += bs) {
526
+ // printf ("dis stats: %ld/%ld\n", q0, n);
527
+ size_t q1 = q0 + bs;
528
+ if (q1 > n) q1 = n;
529
+
530
+ hammings (q_codes + q0 * pq.code_size, b_codes,
531
+ q1 - q0, nb,
532
+ pq.code_size, distances);
533
+
534
+ for (size_t i = 0; i < nb * (q1 - q0); i++)
535
+ histi [distances [i]]++;
536
+ }
537
+ #pragma omp critical
538
+ {
539
+ for (int i = 0; i <= nbits; i++)
540
+ hist[i] += histi[i];
541
+ }
542
+ }
543
+
544
+ }
545
+
546
+
547
+
548
+
549
+
550
+
551
+
552
+
553
+
554
+
555
+
556
+
557
+
558
+
559
+
560
+
561
+
562
+
563
+
564
+
565
+ /*****************************************
566
+ * MultiIndexQuantizer
567
+ ******************************************/
568
+
569
+ namespace {
570
+
571
/** View over an array that is already sorted in ascending order, so the
 * n-th smallest element is simply element n.
 */
template <typename T>
struct PreSortedArray {

    const T * x;
    int N;

    explicit PreSortedArray (int N) {
        this->N = N;
    }

    /// attach the (already sorted) values
    void init (const T*x) {
        this->x = x;
    }

    /// smallest value
    T get_0 () {
        return *x;
    }

    /// difference between the n-th smallest and the (n-1)-th smallest
    T get_diff (int n) {
        const T * cur = x + n;
        return cur[0] - cur[-1];
    }

    /// map a rank (counted from the smallest) to an index in the array
    int get_ord (int n) {
        return n;
    }

};
598
+
599
/// Comparator that orders indices by the values they refer to in x.
template <typename T>
struct ArgSort {
    const T * x;

    /// true when the value at index lhs is strictly smaller than at rhs
    bool operator() (size_t lhs, size_t rhs) {
        const T & a = x[lhs];
        const T & b = x[rhs];
        return a < b;
    }
};
606
+
607
+
608
/** Wrapper that leaves the input array untouched and keeps a permutation
 * of its indices, ordered so that the referenced values are ascending.
 */
template <typename T>
struct SortedArray {
    const T * x;
    int N;
    std::vector<int> perm;

    explicit SortedArray (int N): N (N), perm (N) {
    }

    /// attach the values and sort the permutation by ascending value
    void init (const T*x) {
        this->x = x;
        int next = 0;
        for (int & slot : perm) {
            slot = next++;
        }
        std::sort (perm.begin(), perm.end(),
                   [x] (size_t i, size_t j) { return x[i] < x[j]; });
    }

    /// smallest value
    T get_0 () {
        return x[perm.front()];
    }

    /// difference between the n-th smallest and the (n-1)-th smallest
    T get_diff (int n) {
        const T & hi = x[perm[n]];
        const T & lo = x[perm[n - 1]];
        return hi - lo;
    }

    /// map a rank (counted from the smallest) to an index in the array
    int get_ord (int n) {
        return perm[n];
    }
};
645
+
646
+
647
+
648
+ /** Array has n values. Sort the k first ones and copy the other ones
649
+ * into elements k..n-1
650
+ */
651
+ template <class C>
652
+ void partial_sort (int k, int n,
653
+ const typename C::T * vals, typename C::TI * perm) {
654
+ // insert first k elts in heap
655
+ for (int i = 1; i < k; i++) {
656
+ indirect_heap_push<C> (i + 1, vals, perm, perm[i]);
657
+ }
658
+
659
+ // insert next n - k elts in heap
660
+ for (int i = k; i < n; i++) {
661
+ typename C::TI id = perm[i];
662
+ typename C::TI top = perm[0];
663
+
664
+ if (C::cmp(vals[top], vals[id])) {
665
+ indirect_heap_pop<C> (k, vals, perm);
666
+ indirect_heap_push<C> (k, vals, perm, id);
667
+ perm[i] = top;
668
+ } else {
669
+ // nothing, elt at i is good where it is.
670
+ }
671
+ }
672
+
673
+ // order the k first elements in heap
674
+ for (int i = k - 1; i > 0; i--) {
675
+ typename C::TI top = perm[0];
676
+ indirect_heap_pop<C> (i + 1, vals, perm);
677
+ perm[i] = top;
678
+ }
679
+ }
680
+
681
+ /** same as SortedArray, but only the k first elements are sorted */
682
+ template <typename T>
683
+ struct SemiSortedArray {
684
+ const T * x;
685
+ int N;
686
+
687
+ // type of the heap: CMax = sort ascending
688
+ typedef CMax<T, int> HC;
689
+ std::vector<int> perm;
690
+
691
+ int k; // k elements are sorted
692
+
693
+ int initial_k, k_factor;
694
+
695
+ explicit SemiSortedArray (int N) {
696
+ this->N = N;
697
+ perm.resize (N);
698
+ perm.resize (N);
699
+ initial_k = 3;
700
+ k_factor = 4;
701
+ }
702
+
703
+ void init (const T*x) {
704
+ this->x = x;
705
+ for (int n = 0; n < N; n++)
706
+ perm[n] = n;
707
+ k = 0;
708
+ grow (initial_k);
709
+ }
710
+
711
+ /// grow the sorted part of the array to size next_k
712
+ void grow (int next_k) {
713
+ if (next_k < N) {
714
+ partial_sort<HC> (next_k - k, N - k, x, &perm[k]);
715
+ k = next_k;
716
+ } else { // full sort of remainder of array
717
+ ArgSort<T> cmp = {x };
718
+ std::sort (perm.begin() + k, perm.end(), cmp);
719
+ k = N;
720
+ }
721
+ }
722
+
723
+ // get smallest value
724
+ T get_0 () {
725
+ return x[perm[0]];
726
+ }
727
+
728
+ // get delta between n-smallest and n-1 -smallest
729
+ T get_diff (int n) {
730
+ if (n >= k) {
731
+ // want to keep powers of 2 - 1
732
+ int next_k = (k + 1) * k_factor - 1;
733
+ grow (next_k);
734
+ }
735
+ return x[perm[n]] - x[perm[n - 1]];
736
+ }
737
+
738
+ // remap orders counted from smallest to indices in array
739
+ int get_ord (int n) {
740
+ assert (n < k);
741
+ return perm[n];
742
+ }
743
+ };
744
+
745
+
746
+
747
+ /*****************************************
748
+ * Find the k smallest sums of M terms, where each term is taken in a
749
+ * table x of n values.
750
+ *
751
+ * A combination of terms is encoded as a scalar 0 <= t < n^M. The
752
+ * combination t0 ... t(M-1) that correspond to the sum
753
+ *
754
+ * sum = x[0, t0] + x[1, t1] + .... + x[M-1, t(M-1)]
755
+ *
756
+ * is encoded as
757
+ *
758
+ * t = t0 + t1 * n + t2 * n^2 + ... + t(M-1) * n^(M-1)
759
+ *
760
+ * MinSumK is an object rather than a function, so that storage can be
761
+ * re-used over several computations with the same sizes. use_seen is
762
+ * good when there may be ties in the x array and it is a concern if
763
+ * occasionally several t's are returned.
764
+ *
765
+ * @param x size M * n, values to add up
766
+ * @parms k nb of results to retrieve
767
+ * @param M nb of terms
768
+ * @param n nb of distinct values
769
+ * @param sums output, size k, sorted
770
+ * @prarm terms output, size k, with encoding as above
771
+ *
772
+ ******************************************/
773
+ template <typename T, class SSA, bool use_seen>
774
+ struct MinSumK {
775
+ int K; ///< nb of sums to return
776
+ int M; ///< nb of elements to sum up
777
+ int nbit; ///< nb of bits to encode one entry
778
+ int N; ///< nb of possible elements for each of the M terms
779
+
780
+ /** the heap.
781
+ * We use a heap to maintain a queue of sums, with the associated
782
+ * terms involved in the sum.
783
+ */
784
+ typedef CMin<T, int64_t> HC;
785
+ size_t heap_capacity, heap_size;
786
+ T *bh_val;
787
+ int64_t *bh_ids;
788
+
789
+ std::vector <SSA> ssx;
790
+
791
+ // all results get pushed several times. When there are ties, they
792
+ // are popped interleaved with others, so it is not easy to
793
+ // identify them. Therefore, this bit array just marks elements
794
+ // that were seen before.
795
+ std::vector <uint8_t> seen;
796
+
797
+ MinSumK (int K, int M, int nbit, int N):
798
+ K(K), M(M), nbit(nbit), N(N) {
799
+ heap_capacity = K * M;
800
+ assert (N <= (1 << nbit));
801
+
802
+ // we'll do k steps, each step pushes at most M vals
803
+ bh_val = new T[heap_capacity];
804
+ bh_ids = new int64_t[heap_capacity];
805
+
806
+ if (use_seen) {
807
+ int64_t n_ids = weight(M);
808
+ seen.resize ((n_ids + 7) / 8);
809
+ }
810
+
811
+ for (int m = 0; m < M; m++)
812
+ ssx.push_back (SSA(N));
813
+
814
+ }
815
+
816
+ int64_t weight (int i) {
817
+ return 1 << (i * nbit);
818
+ }
819
+
820
+ bool is_seen (int64_t i) {
821
+ return (seen[i >> 3] >> (i & 7)) & 1;
822
+ }
823
+
824
+ void mark_seen (int64_t i) {
825
+ if (use_seen)
826
+ seen [i >> 3] |= 1 << (i & 7);
827
+ }
828
+
829
+ void run (const T *x, int64_t ldx,
830
+ T * sums, int64_t * terms) {
831
+ heap_size = 0;
832
+
833
+ for (int m = 0; m < M; m++) {
834
+ ssx[m].init(x);
835
+ x += ldx;
836
+ }
837
+
838
+ { // intial result: take min for all elements
839
+ T sum = 0;
840
+ terms[0] = 0;
841
+ mark_seen (0);
842
+ for (int m = 0; m < M; m++) {
843
+ sum += ssx[m].get_0();
844
+ }
845
+ sums[0] = sum;
846
+ for (int m = 0; m < M; m++) {
847
+ heap_push<HC> (++heap_size, bh_val, bh_ids,
848
+ sum + ssx[m].get_diff(1),
849
+ weight(m));
850
+ }
851
+ }
852
+
853
+ for (int k = 1; k < K; k++) {
854
+ // pop smallest value from heap
855
+ if (use_seen) {// skip already seen elements
856
+ while (is_seen (bh_ids[0])) {
857
+ assert (heap_size > 0);
858
+ heap_pop<HC> (heap_size--, bh_val, bh_ids);
859
+ }
860
+ }
861
+ assert (heap_size > 0);
862
+
863
+ T sum = sums[k] = bh_val[0];
864
+ int64_t ti = terms[k] = bh_ids[0];
865
+
866
+ if (use_seen) {
867
+ mark_seen (ti);
868
+ heap_pop<HC> (heap_size--, bh_val, bh_ids);
869
+ } else {
870
+ do {
871
+ heap_pop<HC> (heap_size--, bh_val, bh_ids);
872
+ } while (heap_size > 0 && bh_ids[0] == ti);
873
+ }
874
+
875
+ // enqueue followers
876
+ int64_t ii = ti;
877
+ for (int m = 0; m < M; m++) {
878
+ int64_t n = ii & ((1L << nbit) - 1);
879
+ ii >>= nbit;
880
+ if (n + 1 >= N) continue;
881
+
882
+ enqueue_follower (ti, m, n, sum);
883
+ }
884
+ }
885
+
886
+ /*
887
+ for (int k = 0; k < K; k++)
888
+ for (int l = k + 1; l < K; l++)
889
+ assert (terms[k] != terms[l]);
890
+ */
891
+
892
+ // convert indices by applying permutation
893
+ for (int k = 0; k < K; k++) {
894
+ int64_t ii = terms[k];
895
+ if (use_seen) {
896
+ // clear seen for reuse at next loop
897
+ seen[ii >> 3] = 0;
898
+ }
899
+ int64_t ti = 0;
900
+ for (int m = 0; m < M; m++) {
901
+ int64_t n = ii & ((1L << nbit) - 1);
902
+ ti += int64_t(ssx[m].get_ord(n)) << (nbit * m);
903
+ ii >>= nbit;
904
+ }
905
+ terms[k] = ti;
906
+ }
907
+ }
908
+
909
+
910
+ void enqueue_follower (int64_t ti, int m, int n, T sum) {
911
+ T next_sum = sum + ssx[m].get_diff(n + 1);
912
+ int64_t next_ti = ti + weight(m);
913
+ heap_push<HC> (++heap_size, bh_val, bh_ids, next_sum, next_ti);
914
+ }
915
+
916
+ ~MinSumK () {
917
+ delete [] bh_ids;
918
+ delete [] bh_val;
919
+ }
920
+ };
921
+
922
+ } // anonymous namespace
923
+
924
+
925
+ MultiIndexQuantizer::MultiIndexQuantizer (int d,
926
+ size_t M,
927
+ size_t nbits):
928
+ Index(d, METRIC_L2), pq(d, M, nbits)
929
+ {
930
+ is_trained = false;
931
+ pq.verbose = verbose;
932
+ }
933
+
934
+
935
+
936
+ void MultiIndexQuantizer::train(idx_t n, const float *x)
937
+ {
938
+ pq.verbose = verbose;
939
+ pq.train (n, x);
940
+ is_trained = true;
941
+ // count virtual elements in index
942
+ ntotal = 1;
943
+ for (int m = 0; m < pq.M; m++)
944
+ ntotal *= pq.ksub;
945
+ }
946
+
947
+
948
+ void MultiIndexQuantizer::search (idx_t n, const float *x, idx_t k,
949
+ float *distances, idx_t *labels) const {
950
+ if (n == 0) return;
951
+
952
+ // the allocation just below can be severe...
953
+ idx_t bs = 32768;
954
+ if (n > bs) {
955
+ for (idx_t i0 = 0; i0 < n; i0 += bs) {
956
+ idx_t i1 = std::min(i0 + bs, n);
957
+ if (verbose) {
958
+ printf("MultiIndexQuantizer::search: %ld:%ld / %ld\n",
959
+ i0, i1, n);
960
+ }
961
+ search (i1 - i0, x + i0 * d, k,
962
+ distances + i0 * k,
963
+ labels + i0 * k);
964
+ }
965
+ return;
966
+ }
967
+
968
+ float * dis_tables = new float [n * pq.ksub * pq.M];
969
+ ScopeDeleter<float> del (dis_tables);
970
+
971
+ pq.compute_distance_tables (n, x, dis_tables);
972
+
973
+ if (k == 1) {
974
+ // simple version that just finds the min in each table
975
+
976
+ #pragma omp parallel for
977
+ for (int i = 0; i < n; i++) {
978
+ const float * dis_table = dis_tables + i * pq.ksub * pq.M;
979
+ float dis = 0;
980
+ idx_t label = 0;
981
+
982
+ for (int s = 0; s < pq.M; s++) {
983
+ float vmin = HUGE_VALF;
984
+ idx_t lmin = -1;
985
+
986
+ for (idx_t j = 0; j < pq.ksub; j++) {
987
+ if (dis_table[j] < vmin) {
988
+ vmin = dis_table[j];
989
+ lmin = j;
990
+ }
991
+ }
992
+ dis += vmin;
993
+ label |= lmin << (s * pq.nbits);
994
+ dis_table += pq.ksub;
995
+ }
996
+
997
+ distances [i] = dis;
998
+ labels [i] = label;
999
+ }
1000
+
1001
+
1002
+ } else {
1003
+
1004
+ #pragma omp parallel if(n > 1)
1005
+ {
1006
+ MinSumK <float, SemiSortedArray<float>, false>
1007
+ msk(k, pq.M, pq.nbits, pq.ksub);
1008
+ #pragma omp for
1009
+ for (int i = 0; i < n; i++) {
1010
+ msk.run (dis_tables + i * pq.ksub * pq.M, pq.ksub,
1011
+ distances + i * k, labels + i * k);
1012
+
1013
+ }
1014
+ }
1015
+ }
1016
+
1017
+ }
1018
+
1019
+
1020
+ void MultiIndexQuantizer::reconstruct (idx_t key, float * recons) const
1021
+ {
1022
+
1023
+ int64_t jj = key;
1024
+ for (int m = 0; m < pq.M; m++) {
1025
+ int64_t n = jj & ((1L << pq.nbits) - 1);
1026
+ jj >>= pq.nbits;
1027
+ memcpy(recons, pq.get_centroids(m, n), sizeof(recons[0]) * pq.dsub);
1028
+ recons += pq.dsub;
1029
+ }
1030
+ }
1031
+
1032
+ void MultiIndexQuantizer::add(idx_t /*n*/, const float* /*x*/) {
1033
+ FAISS_THROW_MSG(
1034
+ "This index has virtual elements, "
1035
+ "it does not support add");
1036
+ }
1037
+
1038
+ void MultiIndexQuantizer::reset ()
1039
+ {
1040
+ FAISS_THROW_MSG ( "This index has virtual elements, "
1041
+ "it does not support reset");
1042
+ }
1043
+
1044
+
1045
+
1046
+
1047
+
1048
+
1049
+
1050
+
1051
+
1052
+
1053
+ /*****************************************
1054
+ * MultiIndexQuantizer2
1055
+ ******************************************/
1056
+
1057
+
1058
+
1059
+ MultiIndexQuantizer2::MultiIndexQuantizer2 (
1060
+ int d, size_t M, size_t nbits,
1061
+ Index **indexes):
1062
+ MultiIndexQuantizer (d, M, nbits)
1063
+ {
1064
+ assign_indexes.resize (M);
1065
+ for (int i = 0; i < M; i++) {
1066
+ FAISS_THROW_IF_NOT_MSG(
1067
+ indexes[i]->d == pq.dsub,
1068
+ "Provided sub-index has incorrect size");
1069
+ assign_indexes[i] = indexes[i];
1070
+ }
1071
+ own_fields = false;
1072
+ }
1073
+
1074
+ MultiIndexQuantizer2::MultiIndexQuantizer2 (
1075
+ int d, size_t nbits,
1076
+ Index *assign_index_0,
1077
+ Index *assign_index_1):
1078
+ MultiIndexQuantizer (d, 2, nbits)
1079
+ {
1080
+ FAISS_THROW_IF_NOT_MSG(
1081
+ assign_index_0->d == pq.dsub &&
1082
+ assign_index_1->d == pq.dsub,
1083
+ "Provided sub-index has incorrect size");
1084
+ assign_indexes.resize (2);
1085
+ assign_indexes [0] = assign_index_0;
1086
+ assign_indexes [1] = assign_index_1;
1087
+ own_fields = false;
1088
+ }
1089
+
1090
+ void MultiIndexQuantizer2::train(idx_t n, const float* x)
1091
+ {
1092
+ MultiIndexQuantizer::train(n, x);
1093
+ // add centroids to sub-indexes
1094
+ for (int i = 0; i < pq.M; i++) {
1095
+ assign_indexes[i]->add(pq.ksub, pq.get_centroids(i, 0));
1096
+ }
1097
+ }
1098
+
1099
+
1100
+ void MultiIndexQuantizer2::search(
1101
+ idx_t n, const float* x, idx_t K,
1102
+ float* distances, idx_t* labels) const
1103
+ {
1104
+
1105
+ if (n == 0) return;
1106
+
1107
+ int k2 = std::min(K, int64_t(pq.ksub));
1108
+
1109
+ int64_t M = pq.M;
1110
+ int64_t dsub = pq.dsub, ksub = pq.ksub;
1111
+
1112
+ // size (M, n, k2)
1113
+ std::vector<idx_t> sub_ids(n * M * k2);
1114
+ std::vector<float> sub_dis(n * M * k2);
1115
+ std::vector<float> xsub(n * dsub);
1116
+
1117
+ for (int m = 0; m < M; m++) {
1118
+ float *xdest = xsub.data();
1119
+ const float *xsrc = x + m * dsub;
1120
+ for (int j = 0; j < n; j++) {
1121
+ memcpy(xdest, xsrc, dsub * sizeof(xdest[0]));
1122
+ xsrc += d;
1123
+ xdest += dsub;
1124
+ }
1125
+
1126
+ assign_indexes[m]->search(
1127
+ n, xsub.data(), k2,
1128
+ &sub_dis[k2 * n * m],
1129
+ &sub_ids[k2 * n * m]);
1130
+ }
1131
+
1132
+ if (K == 1) {
1133
+ // simple version that just finds the min in each table
1134
+ assert (k2 == 1);
1135
+
1136
+ for (int i = 0; i < n; i++) {
1137
+ float dis = 0;
1138
+ idx_t label = 0;
1139
+
1140
+ for (int m = 0; m < M; m++) {
1141
+ float vmin = sub_dis[i + m * n];
1142
+ idx_t lmin = sub_ids[i + m * n];
1143
+ dis += vmin;
1144
+ label |= lmin << (m * pq.nbits);
1145
+ }
1146
+ distances [i] = dis;
1147
+ labels [i] = label;
1148
+ }
1149
+
1150
+ } else {
1151
+
1152
+ #pragma omp parallel if(n > 1)
1153
+ {
1154
+ MinSumK <float, PreSortedArray<float>, false>
1155
+ msk(K, pq.M, pq.nbits, k2);
1156
+ #pragma omp for
1157
+ for (int i = 0; i < n; i++) {
1158
+ idx_t *li = labels + i * K;
1159
+ msk.run (&sub_dis[i * k2], k2 * n,
1160
+ distances + i * K, li);
1161
+
1162
+ // remap ids
1163
+
1164
+ const idx_t *idmap0 = sub_ids.data() + i * k2;
1165
+ int64_t ld_idmap = k2 * n;
1166
+ int64_t mask1 = ksub - 1L;
1167
+
1168
+ for (int k = 0; k < K; k++) {
1169
+ const idx_t *idmap = idmap0;
1170
+ int64_t vin = li[k];
1171
+ int64_t vout = 0;
1172
+ int bs = 0;
1173
+ for (int m = 0; m < M; m++) {
1174
+ int64_t s = vin & mask1;
1175
+ vin >>= pq.nbits;
1176
+ vout |= idmap[s] << bs;
1177
+ bs += pq.nbits;
1178
+ idmap += ld_idmap;
1179
+ }
1180
+ li[k] = vout;
1181
+ }
1182
+ }
1183
+ }
1184
+ }
1185
+ }
1186
+
1187
+
1188
+ } // namespace faiss