faiss 0.1.0 → 0.1.1

Files changed (226)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +103 -3
  4. data/ext/faiss/ext.cpp +99 -32
  5. data/ext/faiss/extconf.rb +12 -2
  6. data/lib/faiss/ext.bundle +0 -0
  7. data/lib/faiss/index.rb +3 -3
  8. data/lib/faiss/index_binary.rb +3 -3
  9. data/lib/faiss/kmeans.rb +1 -1
  10. data/lib/faiss/pca_matrix.rb +2 -2
  11. data/lib/faiss/product_quantizer.rb +3 -3
  12. data/lib/faiss/version.rb +1 -1
  13. data/vendor/faiss/AutoTune.cpp +719 -0
  14. data/vendor/faiss/AutoTune.h +212 -0
  15. data/vendor/faiss/Clustering.cpp +261 -0
  16. data/vendor/faiss/Clustering.h +101 -0
  17. data/vendor/faiss/IVFlib.cpp +339 -0
  18. data/vendor/faiss/IVFlib.h +132 -0
  19. data/vendor/faiss/Index.cpp +171 -0
  20. data/vendor/faiss/Index.h +261 -0
  21. data/vendor/faiss/Index2Layer.cpp +437 -0
  22. data/vendor/faiss/Index2Layer.h +85 -0
  23. data/vendor/faiss/IndexBinary.cpp +77 -0
  24. data/vendor/faiss/IndexBinary.h +163 -0
  25. data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
  26. data/vendor/faiss/IndexBinaryFlat.h +54 -0
  27. data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
  28. data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
  29. data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
  30. data/vendor/faiss/IndexBinaryHNSW.h +56 -0
  31. data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
  32. data/vendor/faiss/IndexBinaryIVF.h +211 -0
  33. data/vendor/faiss/IndexFlat.cpp +508 -0
  34. data/vendor/faiss/IndexFlat.h +175 -0
  35. data/vendor/faiss/IndexHNSW.cpp +1090 -0
  36. data/vendor/faiss/IndexHNSW.h +170 -0
  37. data/vendor/faiss/IndexIVF.cpp +909 -0
  38. data/vendor/faiss/IndexIVF.h +353 -0
  39. data/vendor/faiss/IndexIVFFlat.cpp +502 -0
  40. data/vendor/faiss/IndexIVFFlat.h +118 -0
  41. data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
  42. data/vendor/faiss/IndexIVFPQ.h +161 -0
  43. data/vendor/faiss/IndexIVFPQR.cpp +219 -0
  44. data/vendor/faiss/IndexIVFPQR.h +65 -0
  45. data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
  46. data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
  47. data/vendor/faiss/IndexLSH.cpp +225 -0
  48. data/vendor/faiss/IndexLSH.h +87 -0
  49. data/vendor/faiss/IndexLattice.cpp +143 -0
  50. data/vendor/faiss/IndexLattice.h +68 -0
  51. data/vendor/faiss/IndexPQ.cpp +1188 -0
  52. data/vendor/faiss/IndexPQ.h +199 -0
  53. data/vendor/faiss/IndexPreTransform.cpp +288 -0
  54. data/vendor/faiss/IndexPreTransform.h +91 -0
  55. data/vendor/faiss/IndexReplicas.cpp +123 -0
  56. data/vendor/faiss/IndexReplicas.h +76 -0
  57. data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
  58. data/vendor/faiss/IndexScalarQuantizer.h +127 -0
  59. data/vendor/faiss/IndexShards.cpp +317 -0
  60. data/vendor/faiss/IndexShards.h +100 -0
  61. data/vendor/faiss/InvertedLists.cpp +623 -0
  62. data/vendor/faiss/InvertedLists.h +334 -0
  63. data/vendor/faiss/LICENSE +21 -0
  64. data/vendor/faiss/MatrixStats.cpp +252 -0
  65. data/vendor/faiss/MatrixStats.h +62 -0
  66. data/vendor/faiss/MetaIndexes.cpp +351 -0
  67. data/vendor/faiss/MetaIndexes.h +126 -0
  68. data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
  69. data/vendor/faiss/OnDiskInvertedLists.h +127 -0
  70. data/vendor/faiss/VectorTransform.cpp +1157 -0
  71. data/vendor/faiss/VectorTransform.h +322 -0
  72. data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
  73. data/vendor/faiss/c_api/AutoTune_c.h +64 -0
  74. data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
  75. data/vendor/faiss/c_api/Clustering_c.h +117 -0
  76. data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
  77. data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
  78. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
  79. data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
  80. data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
  81. data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
  82. data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
  83. data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
  84. data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
  85. data/vendor/faiss/c_api/IndexShards_c.h +42 -0
  86. data/vendor/faiss/c_api/Index_c.cpp +105 -0
  87. data/vendor/faiss/c_api/Index_c.h +183 -0
  88. data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
  89. data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
  90. data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
  91. data/vendor/faiss/c_api/clone_index_c.h +32 -0
  92. data/vendor/faiss/c_api/error_c.h +42 -0
  93. data/vendor/faiss/c_api/error_impl.cpp +27 -0
  94. data/vendor/faiss/c_api/error_impl.h +16 -0
  95. data/vendor/faiss/c_api/faiss_c.h +58 -0
  96. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
  97. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
  98. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
  99. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
  100. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
  101. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
  102. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
  103. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
  104. data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
  105. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
  106. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
  107. data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
  108. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
  109. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
  110. data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
  111. data/vendor/faiss/c_api/index_factory_c.h +30 -0
  112. data/vendor/faiss/c_api/index_io_c.cpp +42 -0
  113. data/vendor/faiss/c_api/index_io_c.h +50 -0
  114. data/vendor/faiss/c_api/macros_impl.h +110 -0
  115. data/vendor/faiss/clone_index.cpp +147 -0
  116. data/vendor/faiss/clone_index.h +38 -0
  117. data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
  118. data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
  119. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
  120. data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
  121. data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
  122. data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
  123. data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
  124. data/vendor/faiss/gpu/GpuCloner.h +82 -0
  125. data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
  126. data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
  127. data/vendor/faiss/gpu/GpuDistance.h +52 -0
  128. data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
  129. data/vendor/faiss/gpu/GpuIndex.h +148 -0
  130. data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
  131. data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
  132. data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
  133. data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
  134. data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
  135. data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
  136. data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
  137. data/vendor/faiss/gpu/GpuResources.cpp +52 -0
  138. data/vendor/faiss/gpu/GpuResources.h +73 -0
  139. data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
  140. data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
  141. data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
  142. data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
  143. data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
  144. data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
  145. data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
  146. data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
  147. data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
  148. data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
  149. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
  150. data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
  151. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
  152. data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
  153. data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
  154. data/vendor/faiss/gpu/test/TestUtils.h +93 -0
  155. data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
  156. data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
  157. data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
  158. data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
  159. data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
  160. data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
  161. data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
  162. data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
  163. data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
  164. data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
  165. data/vendor/faiss/gpu/utils/Timer.h +52 -0
  166. data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
  167. data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
  168. data/vendor/faiss/impl/FaissAssert.h +95 -0
  169. data/vendor/faiss/impl/FaissException.cpp +66 -0
  170. data/vendor/faiss/impl/FaissException.h +71 -0
  171. data/vendor/faiss/impl/HNSW.cpp +818 -0
  172. data/vendor/faiss/impl/HNSW.h +275 -0
  173. data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
  174. data/vendor/faiss/impl/PolysemousTraining.h +158 -0
  175. data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
  176. data/vendor/faiss/impl/ProductQuantizer.h +242 -0
  177. data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
  178. data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
  179. data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
  180. data/vendor/faiss/impl/ThreadedIndex.h +80 -0
  181. data/vendor/faiss/impl/index_read.cpp +793 -0
  182. data/vendor/faiss/impl/index_write.cpp +558 -0
  183. data/vendor/faiss/impl/io.cpp +142 -0
  184. data/vendor/faiss/impl/io.h +98 -0
  185. data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
  186. data/vendor/faiss/impl/lattice_Zn.h +199 -0
  187. data/vendor/faiss/index_factory.cpp +392 -0
  188. data/vendor/faiss/index_factory.h +25 -0
  189. data/vendor/faiss/index_io.h +75 -0
  190. data/vendor/faiss/misc/test_blas.cpp +84 -0
  191. data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
  192. data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
  193. data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
  194. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
  195. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
  196. data/vendor/faiss/tests/test_merge.cpp +258 -0
  197. data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
  198. data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
  199. data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
  200. data/vendor/faiss/tests/test_params_override.cpp +231 -0
  201. data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
  202. data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
  203. data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
  204. data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
  205. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
  206. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
  207. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
  208. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
  209. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
  210. data/vendor/faiss/utils/Heap.cpp +122 -0
  211. data/vendor/faiss/utils/Heap.h +495 -0
  212. data/vendor/faiss/utils/WorkerThread.cpp +126 -0
  213. data/vendor/faiss/utils/WorkerThread.h +61 -0
  214. data/vendor/faiss/utils/distances.cpp +765 -0
  215. data/vendor/faiss/utils/distances.h +243 -0
  216. data/vendor/faiss/utils/distances_simd.cpp +809 -0
  217. data/vendor/faiss/utils/extra_distances.cpp +336 -0
  218. data/vendor/faiss/utils/extra_distances.h +54 -0
  219. data/vendor/faiss/utils/hamming-inl.h +472 -0
  220. data/vendor/faiss/utils/hamming.cpp +792 -0
  221. data/vendor/faiss/utils/hamming.h +220 -0
  222. data/vendor/faiss/utils/random.cpp +192 -0
  223. data/vendor/faiss/utils/random.h +60 -0
  224. data/vendor/faiss/utils/utils.cpp +783 -0
  225. data/vendor/faiss/utils/utils.h +181 -0
  226. metadata +216 -2
data/vendor/faiss/impl/PolysemousTraining.h
@@ -0,0 +1,158 @@
+ /**
+  * Copyright (c) Facebook, Inc. and its affiliates.
+  *
+  * This source code is licensed under the MIT license found in the
+  * LICENSE file in the root directory of this source tree.
+  */
+
+ // -*- c++ -*-
+
+ #ifndef FAISS_POLYSEMOUS_TRAINING_INCLUDED
+ #define FAISS_POLYSEMOUS_TRAINING_INCLUDED
+
+
+ #include <faiss/impl/ProductQuantizer.h>
+
+
+ namespace faiss {
+
+
+ /// parameters used for the simulated annealing method
+ struct SimulatedAnnealingParameters {
+
+     // optimization parameters
+     double init_temperature;   // init probaility of accepting a bad swap
+     double temperature_decay;  // at each iteration the temp is multiplied by this
+     int n_iter;                // nb of iterations
+     int n_redo;                // nb of runs of the simulation
+     int seed;                  // random seed
+     int verbose;
+     bool only_bit_flips;       // restrict permutation changes to bit flips
+     bool init_random;          // intialize with a random permutation (not identity)
+
+     // set reasonable defaults
+     SimulatedAnnealingParameters ();
+
+ };
+
+
+ /// abstract class for the loss function
+ struct PermutationObjective {
+
+     int n;
+
+     virtual double compute_cost (const int *perm) const = 0;
+
+     // what would the cost update be if iw and jw were swapped?
+     // default implementation just computes both and computes the difference
+     virtual double cost_update (const int *perm, int iw, int jw) const;
+
+     virtual ~PermutationObjective () {}
+ };
+
+
+ struct ReproduceDistancesObjective : PermutationObjective {
+
+     double dis_weight_factor;
+
+     static double sqr (double x) { return x * x; }
+
+     // weihgting of distances: it is more important to reproduce small
+     // distances well
+     double dis_weight (double x) const;
+
+     std::vector<double> source_dis; ///< "real" corrected distances (size n^2)
+     const double * target_dis;      ///< wanted distances (size n^2)
+     std::vector<double> weights;    ///< weights for each distance (size n^2)
+
+     double get_source_dis (int i, int j) const;
+
+     // cost = quadratic difference between actual distance and Hamming distance
+     double compute_cost(const int* perm) const override;
+
+     // what would the cost update be if iw and jw were swapped?
+     // computed in O(n) instead of O(n^2) for the full re-computation
+     double cost_update(const int* perm, int iw, int jw) const override;
+
+     ReproduceDistancesObjective (
+         int n,
+         const double *source_dis_in,
+         const double *target_dis_in,
+         double dis_weight_factor);
+
+     static void compute_mean_stdev (const double *tab, size_t n2,
+                                     double *mean_out, double *stddev_out);
+
+     void set_affine_target_dis (const double *source_dis_in);
+
+     ~ReproduceDistancesObjective() override {}
+ };
+
+ struct RandomGenerator;
+
+ /// Simulated annealing optimization algorithm for permutations.
+ struct SimulatedAnnealingOptimizer: SimulatedAnnealingParameters {
+
+     PermutationObjective *obj;
+     int n;         ///< size of the permutation
+     FILE *logfile; /// logs values of the cost function
+
+     SimulatedAnnealingOptimizer (PermutationObjective *obj,
+                                  const SimulatedAnnealingParameters &p);
+     RandomGenerator *rnd;
+
+     /// remember intial cost of optimization
+     double init_cost;
+
+     // main entry point. Perform the optimization loop, starting from
+     // and modifying permutation in-place
+     double optimize (int *perm);
+
+     // run the optimization and return the best result in best_perm
+     double run_optimization (int * best_perm);
+
+     virtual ~SimulatedAnnealingOptimizer ();
+ };
+
+
+
+
+ /// optimizes the order of indices in a ProductQuantizer
+ struct PolysemousTraining: SimulatedAnnealingParameters {
+
+     enum Optimization_type_t {
+         OT_None,
+         OT_ReproduceDistances_affine,  ///< default
+         OT_Ranking_weighted_diff  /// same as _2, but use rank of y+ - rank of y-
+     };
+     Optimization_type_t optimization_type;
+
+     // use 1/4 of the training points for the optimization, with
+     // max. ntrain_permutation. If ntrain_permutation == 0: train on
+     // centroids
+     int ntrain_permutation;
+     double dis_weight_factor; // decay of exp that weights distance loss
+
+     // filename pattern for the logging of iterations
+     std::string log_pattern;
+
+     // sets default values
+     PolysemousTraining ();
+
+     /// reorder the centroids so that the Hamming distace becomes a
+     /// good approximation of the SDC distance (called by train)
+     void optimize_pq_for_hamming (ProductQuantizer & pq,
+                                   size_t n, const float *x) const;
+
+     /// called by optimize_pq_for_hamming
+     void optimize_ranking (ProductQuantizer &pq, size_t n, const float *x) const;
+     /// called by optimize_pq_for_hamming
+     void optimize_reproduce_distances (ProductQuantizer &pq) const;
+
+ };
+
+
+ } // namespace faiss
+
+
+ #endif
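
The header above declares PolysemousTraining::optimize_pq_for_hamming, which reorders the centroids of an already-trained ProductQuantizer so that Hamming distances between codes approximate the symmetric (SDC) distances; per the header comment, OT_ReproduceDistances_affine is the default optimization type. As a rough sketch of how that entry point might be invoked directly (the toy dimensions, random data and main() wrapper below are illustrative assumptions, not part of the gem):

#include <random>
#include <vector>

#include <faiss/impl/PolysemousTraining.h>
#include <faiss/impl/ProductQuantizer.h>

int main() {
    // Illustrative sizes: d = 64 dims, M = 8 sub-quantizers, 8 bits each.
    size_t d = 64, M = 8, nbits = 8;
    size_t n = 10000;

    // Random training data stands in for real vectors.
    std::mt19937 rng(123);
    std::uniform_real_distribution<float> dist(0.f, 1.f);
    std::vector<float> x(n * d);
    for (auto& v : x) v = dist(rng);

    faiss::ProductQuantizer pq(d, M, nbits);
    pq.train(n, x.data());  // per-slice k-means training comes first

    // Reorder each sub-quantizer's centroid indices so that Hamming
    // distance between codes tracks the SDC distance.
    faiss::PolysemousTraining pt;  // defaults set by its constructor
    pt.optimize_pq_for_hamming(pq, n, x.data());
    return 0;
}

The header notes that optimize_pq_for_hamming is normally called from an index's train method, so calling it by hand only makes sense on a quantizer that has already been trained.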
data/vendor/faiss/impl/ProductQuantizer.cpp
@@ -0,0 +1,876 @@
+ /**
+  * Copyright (c) Facebook, Inc. and its affiliates.
+  *
+  * This source code is licensed under the MIT license found in the
+  * LICENSE file in the root directory of this source tree.
+  */
+
+ // -*- c++ -*-
+
+ #include <faiss/impl/ProductQuantizer.h>
+
+
+ #include <cstddef>
+ #include <cstring>
+ #include <cstdio>
+ #include <memory>
+
+ #include <algorithm>
+
+ #include <faiss/impl/FaissAssert.h>
+ #include <faiss/VectorTransform.h>
+ #include <faiss/IndexFlat.h>
+ #include <faiss/utils/distances.h>
+
+
+ extern "C" {
+
+ /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
+
+ int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
+             n, FINTEGER *k, const float *alpha, const float *a,
+             FINTEGER *lda, const float *b, FINTEGER *
+             ldb, float *beta, float *c, FINTEGER *ldc);
+
+ }
+
+
+ namespace faiss {
+
+
+ /* compute an estimator using look-up tables for typical values of M */
+ template <typename CT, class C>
+ void pq_estimators_from_tables_Mmul4 (int M, const CT * codes,
+                                       size_t ncodes,
+                                       const float * __restrict dis_table,
+                                       size_t ksub,
+                                       size_t k,
+                                       float * heap_dis,
+                                       int64_t * heap_ids)
+ {
+
+     for (size_t j = 0; j < ncodes; j++) {
+         float dis = 0;
+         const float *dt = dis_table;
+
+         for (size_t m = 0; m < M; m+=4) {
+             float dism = 0;
+             dism = dt[*codes++]; dt += ksub;
+             dism += dt[*codes++]; dt += ksub;
+             dism += dt[*codes++]; dt += ksub;
+             dism += dt[*codes++]; dt += ksub;
+             dis += dism;
+         }
+
+         if (C::cmp (heap_dis[0], dis)) {
+             heap_pop<C> (k, heap_dis, heap_ids);
+             heap_push<C> (k, heap_dis, heap_ids, dis, j);
+         }
+     }
+ }
+
+
+ template <typename CT, class C>
+ void pq_estimators_from_tables_M4 (const CT * codes,
+                                    size_t ncodes,
+                                    const float * __restrict dis_table,
+                                    size_t ksub,
+                                    size_t k,
+                                    float * heap_dis,
+                                    int64_t * heap_ids)
+ {
+
+     for (size_t j = 0; j < ncodes; j++) {
+         float dis = 0;
+         const float *dt = dis_table;
+         dis = dt[*codes++]; dt += ksub;
+         dis += dt[*codes++]; dt += ksub;
+         dis += dt[*codes++]; dt += ksub;
+         dis += dt[*codes++];
+
+         if (C::cmp (heap_dis[0], dis)) {
+             heap_pop<C> (k, heap_dis, heap_ids);
+             heap_push<C> (k, heap_dis, heap_ids, dis, j);
+         }
+     }
+ }
+
+
+ template <typename CT, class C>
+ static inline void pq_estimators_from_tables (const ProductQuantizer& pq,
+                                               const CT * codes,
+                                               size_t ncodes,
+                                               const float * dis_table,
+                                               size_t k,
+                                               float * heap_dis,
+                                               int64_t * heap_ids)
+ {
+
+     if (pq.M == 4) {
+
+         pq_estimators_from_tables_M4<CT, C> (codes, ncodes,
+                                              dis_table, pq.ksub, k,
+                                              heap_dis, heap_ids);
+         return;
+     }
+
+     if (pq.M % 4 == 0) {
+         pq_estimators_from_tables_Mmul4<CT, C> (pq.M, codes, ncodes,
+                                                 dis_table, pq.ksub, k,
+                                                 heap_dis, heap_ids);
+         return;
+     }
+
+     /* Default is relatively slow */
+     const size_t M = pq.M;
+     const size_t ksub = pq.ksub;
+     for (size_t j = 0; j < ncodes; j++) {
+         float dis = 0;
+         const float * __restrict dt = dis_table;
+         for (int m = 0; m < M; m++) {
+             dis += dt[*codes++];
+             dt += ksub;
+         }
+         if (C::cmp (heap_dis[0], dis)) {
+             heap_pop<C> (k, heap_dis, heap_ids);
+             heap_push<C> (k, heap_dis, heap_ids, dis, j);
+         }
+     }
+ }
+
+ template <class C>
+ static inline void pq_estimators_from_tables_generic(const ProductQuantizer& pq,
+                                                      size_t nbits,
+                                                      const uint8_t *codes,
+                                                      size_t ncodes,
+                                                      const float *dis_table,
+                                                      size_t k,
+                                                      float *heap_dis,
+                                                      int64_t *heap_ids)
+ {
+     const size_t M = pq.M;
+     const size_t ksub = pq.ksub;
+     for (size_t j = 0; j < ncodes; ++j) {
+         faiss::ProductQuantizer::PQDecoderGeneric decoder(
+             codes + j * pq.code_size, nbits
+         );
+         float dis = 0;
+         const float * __restrict dt = dis_table;
+         for (size_t m = 0; m < M; m++) {
+             uint64_t c = decoder.decode();
+             dis += dt[c];
+             dt += ksub;
+         }
+
+         if (C::cmp(heap_dis[0], dis)) {
+             heap_pop<C>(k, heap_dis, heap_ids);
+             heap_push<C>(k, heap_dis, heap_ids, dis, j);
+         }
+     }
+ }
+
+ /*********************************************
+  * PQ implementation
+  *********************************************/
+
+
+
+ ProductQuantizer::ProductQuantizer (size_t d, size_t M, size_t nbits):
+     d(d), M(M), nbits(nbits), assign_index(nullptr)
+ {
+     set_derived_values ();
+ }
+
+ ProductQuantizer::ProductQuantizer ()
+     : ProductQuantizer(0, 1, 0) {}
+
+ void ProductQuantizer::set_derived_values () {
+     // quite a few derived values
+     FAISS_THROW_IF_NOT (d % M == 0);
+     dsub = d / M;
+     code_size = (nbits * M + 7) / 8;
+     ksub = 1 << nbits;
+     centroids.resize (d * ksub);
+     verbose = false;
+     train_type = Train_default;
+ }
+
+ void ProductQuantizer::set_params (const float * centroids_, int m)
+ {
+     memcpy (get_centroids(m, 0), centroids_,
+             ksub * dsub * sizeof (centroids_[0]));
+ }
+
+
+ static void init_hypercube (int d, int nbits,
+                             int n, const float * x,
+                             float *centroids)
+ {
+
+     std::vector<float> mean (d);
+     for (int i = 0; i < n; i++)
+         for (int j = 0; j < d; j++)
+             mean [j] += x[i * d + j];
+
+     float maxm = 0;
+     for (int j = 0; j < d; j++) {
+         mean [j] /= n;
+         if (fabs(mean[j]) > maxm) maxm = fabs(mean[j]);
+     }
+
+     for (int i = 0; i < (1 << nbits); i++) {
+         float * cent = centroids + i * d;
+         for (int j = 0; j < nbits; j++)
+             cent[j] = mean [j] + (((i >> j) & 1) ? 1 : -1) * maxm;
+         for (int j = nbits; j < d; j++)
+             cent[j] = mean [j];
+     }
+
+
+ }
+
+ static void init_hypercube_pca (int d, int nbits,
+                                 int n, const float * x,
+                                 float *centroids)
+ {
+     PCAMatrix pca (d, nbits);
+     pca.train (n, x);
+
+
+     for (int i = 0; i < (1 << nbits); i++) {
+         float * cent = centroids + i * d;
+         for (int j = 0; j < d; j++) {
+             cent[j] = pca.mean[j];
+             float f = 1.0;
+             for (int k = 0; k < nbits; k++)
+                 cent[j] += f *
+                     sqrt (pca.eigenvalues [k]) *
+                     (((i >> k) & 1) ? 1 : -1) *
+                     pca.PCAMat [j + k * d];
+         }
+     }
+
+ }
+
+ void ProductQuantizer::train (int n, const float * x)
+ {
+     if (train_type != Train_shared) {
+         train_type_t final_train_type;
+         final_train_type = train_type;
+         if (train_type == Train_hypercube ||
+             train_type == Train_hypercube_pca) {
+             if (dsub < nbits) {
+                 final_train_type = Train_default;
+                 printf ("cannot train hypercube: nbits=%ld > log2(d=%ld)\n",
+                         nbits, dsub);
+             }
+         }
+
+         float * xslice = new float[n * dsub];
+         ScopeDeleter<float> del (xslice);
+         for (int m = 0; m < M; m++) {
+             for (int j = 0; j < n; j++)
+                 memcpy (xslice + j * dsub,
+                         x + j * d + m * dsub,
+                         dsub * sizeof(float));
+
+             Clustering clus (dsub, ksub, cp);
+
+             // we have some initialization for the centroids
+             if (final_train_type != Train_default) {
+                 clus.centroids.resize (dsub * ksub);
+             }
+
+             switch (final_train_type) {
+             case Train_hypercube:
+                 init_hypercube (dsub, nbits, n, xslice,
+                                 clus.centroids.data ());
+                 break;
+             case Train_hypercube_pca:
+                 init_hypercube_pca (dsub, nbits, n, xslice,
+                                     clus.centroids.data ());
+                 break;
+             case Train_hot_start:
+                 memcpy (clus.centroids.data(),
+                         get_centroids (m, 0),
+                         dsub * ksub * sizeof (float));
+                 break;
+             default: ;
+             }
+
+             if(verbose) {
+                 clus.verbose = true;
+                 printf ("Training PQ slice %d/%zd\n", m, M);
+             }
+             IndexFlatL2 index (dsub);
+             clus.train (n, xslice, assign_index ? *assign_index : index);
+             set_params (clus.centroids.data(), m);
+         }
+
+
+     } else {
+
+         Clustering clus (dsub, ksub, cp);
+
+         if(verbose) {
+             clus.verbose = true;
+             printf ("Training all PQ slices at once\n");
+         }
+
+         IndexFlatL2 index (dsub);
+
+         clus.train (n * M, x, assign_index ? *assign_index : index);
+         for (int m = 0; m < M; m++) {
+             set_params (clus.centroids.data(), m);
+         }
+
+     }
+ }
+
+ template<class PQEncoder>
+ void compute_code(const ProductQuantizer& pq, const float *x, uint8_t *code) {
+     float distances [pq.ksub];
+     PQEncoder encoder(code, pq.nbits);
+     for (size_t m = 0; m < pq.M; m++) {
+         float mindis = 1e20;
+         uint64_t idxm = 0;
+         const float * xsub = x + m * pq.dsub;
+
+         fvec_L2sqr_ny(distances, xsub, pq.get_centroids(m, 0), pq.dsub, pq.ksub);
+
+         /* Find best centroid */
+         for (size_t i = 0; i < pq.ksub; i++) {
+             float dis = distances[i];
+             if (dis < mindis) {
+                 mindis = dis;
+                 idxm = i;
+             }
+         }
+
+         encoder.encode(idxm);
+     }
+ }
+
+ void ProductQuantizer::compute_code(const float * x, uint8_t * code) const {
+     switch (nbits) {
+     case 8:
+         faiss::compute_code<PQEncoder8>(*this, x, code);
+         break;
+
+     case 16:
+         faiss::compute_code<PQEncoder16>(*this, x, code);
+         break;
+
+     default:
+         faiss::compute_code<PQEncoderGeneric>(*this, x, code);
+         break;
+     }
+ }
+
+ template<class PQDecoder>
+ void decode(const ProductQuantizer& pq, const uint8_t *code, float *x)
+ {
+     PQDecoder decoder(code, pq.nbits);
+     for (size_t m = 0; m < pq.M; m++) {
+         uint64_t c = decoder.decode();
+         memcpy(x + m * pq.dsub, pq.get_centroids(m, c), sizeof(float) * pq.dsub);
+     }
+ }
+
+ void ProductQuantizer::decode (const uint8_t *code, float *x) const
+ {
+     switch (nbits) {
+     case 8:
+         faiss::decode<PQDecoder8>(*this, code, x);
+         break;
+
+     case 16:
+         faiss::decode<PQDecoder16>(*this, code, x);
+         break;
+
+     default:
+         faiss::decode<PQDecoderGeneric>(*this, code, x);
+         break;
+     }
+ }
+
+
+ void ProductQuantizer::decode (const uint8_t *code, float *x, size_t n) const
+ {
+     for (size_t i = 0; i < n; i++) {
+         this->decode (code + code_size * i, x + d * i);
+     }
+ }
+
+
+ void ProductQuantizer::compute_code_from_distance_table (const float *tab,
+                                                          uint8_t *code) const
+ {
+     PQEncoderGeneric encoder(code, nbits);
+     for (size_t m = 0; m < M; m++) {
+         float mindis = 1e20;
+         uint64_t idxm = 0;
+
+         /* Find best centroid */
+         for (size_t j = 0; j < ksub; j++) {
+             float dis = *tab++;
+             if (dis < mindis) {
+                 mindis = dis;
+                 idxm = j;
+             }
+         }
+
+         encoder.encode(idxm);
+     }
+ }
+
+ void ProductQuantizer::compute_codes_with_assign_index (
+         const float * x,
+         uint8_t * codes,
+         size_t n)
+ {
+     FAISS_THROW_IF_NOT (assign_index && assign_index->d == dsub);
+
+     for (size_t m = 0; m < M; m++) {
+         assign_index->reset ();
+         assign_index->add (ksub, get_centroids (m, 0));
+         size_t bs = 65536;
+         float * xslice = new float[bs * dsub];
+         ScopeDeleter<float> del (xslice);
+         idx_t *assign = new idx_t[bs];
+         ScopeDeleter<idx_t> del2 (assign);
+
+         for (size_t i0 = 0; i0 < n; i0 += bs) {
+             size_t i1 = std::min(i0 + bs, n);
+
+             for (size_t i = i0; i < i1; i++) {
+                 memcpy (xslice + (i - i0) * dsub,
+                         x + i * d + m * dsub,
+                         dsub * sizeof(float));
+             }
+
+             assign_index->assign (i1 - i0, xslice, assign);
+
+             if (nbits == 8) {
+                 uint8_t *c = codes + code_size * i0 + m;
+                 for (size_t i = i0; i < i1; i++) {
+                     *c = assign[i - i0];
+                     c += M;
+                 }
+             } else if (nbits == 16) {
+                 uint16_t *c = (uint16_t*)(codes + code_size * i0 + m * 2);
+                 for (size_t i = i0; i < i1; i++) {
+                     *c = assign[i - i0];
+                     c += M;
+                 }
+             } else {
+                 for (size_t i = i0; i < i1; ++i) {
+                     uint8_t *c = codes + code_size * i + ((m * nbits) / 8);
+                     uint8_t offset = (m * nbits) % 8;
+                     uint64_t ass = assign[i - i0];
+
+                     PQEncoderGeneric encoder(c, nbits, offset);
+                     encoder.encode(ass);
+                 }
+             }
+
+         }
+     }
+
+ }
+
+ void ProductQuantizer::compute_codes (const float * x,
+                                       uint8_t * codes,
+                                       size_t n) const
+ {
+     // process by blocks to avoid using too much RAM
+     size_t bs = 256 * 1024;
+     if (n > bs) {
+         for (size_t i0 = 0; i0 < n; i0 += bs) {
+             size_t i1 = std::min(i0 + bs, n);
+             compute_codes (x + d * i0, codes + code_size * i0, i1 - i0);
+         }
+         return;
+     }
+
+     if (dsub < 16) { // simple direct computation
+
+ #pragma omp parallel for
+         for (size_t i = 0; i < n; i++)
+             compute_code (x + i * d, codes + i * code_size);
+
+     } else { // worthwile to use BLAS
+         float *dis_tables = new float [n * ksub * M];
+         ScopeDeleter<float> del (dis_tables);
+         compute_distance_tables (n, x, dis_tables);
+
+ #pragma omp parallel for
+         for (size_t i = 0; i < n; i++) {
+             uint8_t * code = codes + i * code_size;
+             const float * tab = dis_tables + i * ksub * M;
+             compute_code_from_distance_table (tab, code);
+         }
+     }
+ }
+
+
+ void ProductQuantizer::compute_distance_table (const float * x,
+                                                float * dis_table) const
+ {
+     size_t m;
+
+     for (m = 0; m < M; m++) {
+         fvec_L2sqr_ny (dis_table + m * ksub,
+                        x + m * dsub,
+                        get_centroids(m, 0),
+                        dsub,
+                        ksub);
+     }
+ }
+
+ void ProductQuantizer::compute_inner_prod_table (const float * x,
+                                                  float * dis_table) const
+ {
+     size_t m;
+
+     for (m = 0; m < M; m++) {
+         fvec_inner_products_ny (dis_table + m * ksub,
+                                 x + m * dsub,
+                                 get_centroids(m, 0),
+                                 dsub,
+                                 ksub);
+     }
+ }
+
+
+ void ProductQuantizer::compute_distance_tables (
+         size_t nx,
+         const float * x,
+         float * dis_tables) const
+ {
+
+     if (dsub < 16) {
+
+ #pragma omp parallel for
+         for (size_t i = 0; i < nx; i++) {
+             compute_distance_table (x + i * d, dis_tables + i * ksub * M);
+         }
+
+     } else { // use BLAS
+
+         for (int m = 0; m < M; m++) {
+             pairwise_L2sqr (dsub,
+                             nx, x + dsub * m,
+                             ksub, centroids.data() + m * dsub * ksub,
+                             dis_tables + ksub * m,
+                             d, dsub, ksub * M);
+         }
+     }
+ }
+
+ void ProductQuantizer::compute_inner_prod_tables (
+         size_t nx,
+         const float * x,
+         float * dis_tables) const
+ {
+
+     if (dsub < 16) {
+
+ #pragma omp parallel for
+         for (size_t i = 0; i < nx; i++) {
+             compute_inner_prod_table (x + i * d, dis_tables + i * ksub * M);
+         }
+
+     } else { // use BLAS
+
+         // compute distance tables
+         for (int m = 0; m < M; m++) {
+             FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub,
+                 dsubi = dsub, di = d;
+             float one = 1.0, zero = 0;
+
+             sgemm_ ("Transposed", "Not transposed",
+                     &ksubi, &nxi, &dsubi,
+                     &one, &centroids [m * dsub * ksub], &dsubi,
+                     x + dsub * m, &di,
+                     &zero, dis_tables + ksub * m, &ldc);
+         }
+
+     }
+ }
+
+ template <class C>
+ static void pq_knn_search_with_tables (
+         const ProductQuantizer& pq,
+         size_t nbits,
+         const float *dis_tables,
+         const uint8_t * codes,
+         const size_t ncodes,
+         HeapArray<C> * res,
+         bool init_finalize_heap)
+ {
+     size_t k = res->k, nx = res->nh;
+     size_t ksub = pq.ksub, M = pq.M;
+
+
+ #pragma omp parallel for
+     for (size_t i = 0; i < nx; i++) {
+         /* query preparation for asymmetric search: compute look-up tables */
+         const float* dis_table = dis_tables + i * ksub * M;
+
+         /* Compute distances and keep smallest values */
+         int64_t * __restrict heap_ids = res->ids + i * k;
+         float * __restrict heap_dis = res->val + i * k;
+
+         if (init_finalize_heap) {
+             heap_heapify<C> (k, heap_dis, heap_ids);
+         }
+
+         switch (nbits) {
+         case 8:
+             pq_estimators_from_tables<uint8_t, C> (pq,
+                                                    codes, ncodes,
+                                                    dis_table,
+                                                    k, heap_dis, heap_ids);
+             break;
+
+         case 16:
+             pq_estimators_from_tables<uint16_t, C> (pq,
+                                                     (uint16_t*)codes, ncodes,
+                                                     dis_table,
+                                                     k, heap_dis, heap_ids);
+             break;
+
+         default:
+             pq_estimators_from_tables_generic<C> (pq,
+                                                   nbits,
+                                                   codes, ncodes,
+                                                   dis_table,
+                                                   k, heap_dis, heap_ids);
+             break;
+         }
+
+         if (init_finalize_heap) {
+             heap_reorder<C> (k, heap_dis, heap_ids);
+         }
+     }
+ }
+
+ void ProductQuantizer::search (const float * __restrict x,
+                                size_t nx,
+                                const uint8_t * codes,
+                                const size_t ncodes,
+                                float_maxheap_array_t * res,
+                                bool init_finalize_heap) const
+ {
+     FAISS_THROW_IF_NOT (nx == res->nh);
+     std::unique_ptr<float[]> dis_tables(new float [nx * ksub * M]);
+     compute_distance_tables (nx, x, dis_tables.get());
+
+     pq_knn_search_with_tables<CMax<float, int64_t>> (
+         *this, nbits, dis_tables.get(), codes, ncodes, res, init_finalize_heap);
+ }
+
+ void ProductQuantizer::search_ip (const float * __restrict x,
+                                   size_t nx,
+                                   const uint8_t * codes,
+                                   const size_t ncodes,
+                                   float_minheap_array_t * res,
+                                   bool init_finalize_heap) const
+ {
+     FAISS_THROW_IF_NOT (nx == res->nh);
+     std::unique_ptr<float[]> dis_tables(new float [nx * ksub * M]);
+     compute_inner_prod_tables (nx, x, dis_tables.get());
+
+     pq_knn_search_with_tables<CMin<float, int64_t> > (
+         *this, nbits, dis_tables.get(), codes, ncodes, res, init_finalize_heap);
+ }
+
+
+
+ static float sqr (float x) {
+     return x * x;
+ }
+
+ void ProductQuantizer::compute_sdc_table ()
+ {
+     sdc_table.resize (M * ksub * ksub);
+
+     for (int m = 0; m < M; m++) {
+
+         const float *cents = centroids.data() + m * ksub * dsub;
+         float * dis_tab = sdc_table.data() + m * ksub * ksub;
+
+         // TODO optimize with BLAS
+         for (int i = 0; i < ksub; i++) {
+             const float *centi = cents + i * dsub;
+             for (int j = 0; j < ksub; j++) {
+                 float accu = 0;
+                 const float *centj = cents + j * dsub;
+                 for (int k = 0; k < dsub; k++)
+                     accu += sqr (centi[k] - centj[k]);
+                 dis_tab [i + j * ksub] = accu;
+             }
+         }
+     }
+ }
+
+ void ProductQuantizer::search_sdc (const uint8_t * qcodes,
+                                    size_t nq,
+                                    const uint8_t * bcodes,
+                                    const size_t nb,
+                                    float_maxheap_array_t * res,
+                                    bool init_finalize_heap) const
+ {
+     FAISS_THROW_IF_NOT (sdc_table.size() == M * ksub * ksub);
+     FAISS_THROW_IF_NOT (nbits == 8);
+     size_t k = res->k;
+
+
+ #pragma omp parallel for
+     for (size_t i = 0; i < nq; i++) {
+
+         /* Compute distances and keep smallest values */
+         idx_t * heap_ids = res->ids + i * k;
+         float * heap_dis = res->val + i * k;
+         const uint8_t * qcode = qcodes + i * code_size;
+
+         if (init_finalize_heap)
+             maxheap_heapify (k, heap_dis, heap_ids);
+
+         const uint8_t * bcode = bcodes;
+         for (size_t j = 0; j < nb; j++) {
+             float dis = 0;
+             const float * tab = sdc_table.data();
+             for (int m = 0; m < M; m++) {
+                 dis += tab[bcode[m] + qcode[m] * ksub];
+                 tab += ksub * ksub;
+             }
+             if (dis < heap_dis[0]) {
+                 maxheap_pop (k, heap_dis, heap_ids);
+                 maxheap_push (k, heap_dis, heap_ids, dis, j);
+             }
+             bcode += code_size;
+         }
+
+         if (init_finalize_heap)
+             maxheap_reorder (k, heap_dis, heap_ids);
+     }
+
+ }
+
+
+ ProductQuantizer::PQEncoderGeneric::PQEncoderGeneric(uint8_t *code, int nbits,
+                                                      uint8_t offset)
+     : code(code), offset(offset), nbits(nbits), reg(0) {
+     assert(nbits <= 64);
+     if (offset > 0) {
+         reg = (*code & ((1 << offset) - 1));
+     }
+ }
+
+ void ProductQuantizer::PQEncoderGeneric::encode(uint64_t x) {
+     reg |= (uint8_t)(x << offset);
+     x >>= (8 - offset);
+     if (offset + nbits >= 8) {
+         *code++ = reg;
+
+         for (int i = 0; i < (nbits - (8 - offset)) / 8; ++i) {
+             *code++ = (uint8_t)x;
+             x >>= 8;
+         }
+
+         offset += nbits;
+         offset &= 7;
+         reg = (uint8_t)x;
+     } else {
+         offset += nbits;
+     }
+ }
+
+ ProductQuantizer::PQEncoderGeneric::~PQEncoderGeneric() {
+     if (offset > 0) {
+         *code = reg;
+     }
+ }
+
+
+ ProductQuantizer::PQEncoder8::PQEncoder8(uint8_t *code, int nbits)
+     : code(code) {
+     assert(8 == nbits);
+ }
+
+ void ProductQuantizer::PQEncoder8::encode(uint64_t x) {
+     *code++ = (uint8_t)x;
+ }
+
+
+ ProductQuantizer::PQEncoder16::PQEncoder16(uint8_t *code, int nbits)
+     : code((uint16_t *)code) {
+     assert(16 == nbits);
+ }
+
+ void ProductQuantizer::PQEncoder16::encode(uint64_t x) {
+     *code++ = (uint16_t)x;
+ }
+
+
+ ProductQuantizer::PQDecoderGeneric::PQDecoderGeneric(const uint8_t *code,
+                                                      int nbits)
+     : code(code),
+       offset(0),
+       nbits(nbits),
+       mask((1ull << nbits) - 1),
+       reg(0) {
+     assert(nbits <= 64);
+ }
+
+ uint64_t ProductQuantizer::PQDecoderGeneric::decode() {
+     if (offset == 0) {
+         reg = *code;
+     }
+     uint64_t c = (reg >> offset);
+
+     if (offset + nbits >= 8) {
+         uint64_t e = 8 - offset;
+         ++code;
+         for (int i = 0; i < (nbits - (8 - offset)) / 8; ++i) {
+             c |= ((uint64_t)(*code++) << e);
+             e += 8;
+         }
+
+         offset += nbits;
+         offset &= 7;
+         if (offset > 0) {
+             reg = *code;
+             c |= ((uint64_t)reg << e);
+         }
+     } else {
+         offset += nbits;
+     }
+
+     return c & mask;
+ }
+
+
+ ProductQuantizer::PQDecoder8::PQDecoder8(const uint8_t *code, int nbits)
+     : code(code) {
+     assert(8 == nbits);
+ }
+
+ uint64_t ProductQuantizer::PQDecoder8::decode() {
+     return (uint64_t)(*code++);
+ }
+
+
+ ProductQuantizer::PQDecoder16::PQDecoder16(const uint8_t *code, int nbits)
+     : code((uint16_t *)code) {
+     assert(16 == nbits);
+ }
+
+ uint64_t ProductQuantizer::PQDecoder16::decode() {
+     return (uint64_t)(*code++);
+ }
+
+
+ } // namespace faiss
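
To make the flow implemented above concrete — train, compute_codes, decode and asymmetric search through a heap result array — here is a minimal self-contained sketch. The toy sizes, random data and main() wrapper are illustrative assumptions; only the ProductQuantizer and float_maxheap_array_t members used here appear in the vendored sources shown above.

#include <cstdint>
#include <cstdio>
#include <random>
#include <vector>

#include <faiss/impl/ProductQuantizer.h>
#include <faiss/utils/Heap.h>

int main() {
    // d must be a multiple of M (set_derived_values enforces this).
    size_t d = 32, M = 4, nbits = 8;
    size_t nb = 10000, nq = 5, k = 3;

    std::mt19937 rng(42);
    std::uniform_real_distribution<float> dist(0.f, 1.f);
    std::vector<float> xb(nb * d), xq(nq * d);
    for (auto& v : xb) v = dist(rng);
    for (auto& v : xq) v = dist(rng);

    faiss::ProductQuantizer pq(d, M, nbits);
    pq.train(nb, xb.data());

    // Encode: each vector becomes code_size = (M * nbits + 7) / 8 bytes.
    std::vector<uint8_t> codes(nb * pq.code_size);
    pq.compute_codes(xb.data(), codes.data(), nb);

    // Lossy reconstruction of the first database vector from its code.
    std::vector<float> recons(d);
    pq.decode(codes.data(), recons.data());

    // Asymmetric k-NN search: float queries against PQ codes; results land
    // in a max-heap array that keeps the k smallest squared distances.
    std::vector<float> distances(nq * k);
    std::vector<int64_t> labels(nq * k);
    faiss::float_maxheap_array_t res;
    res.nh = nq;
    res.k = k;
    res.ids = labels.data();
    res.val = distances.data();
    pq.search(xq.data(), nq, codes.data(), nb, &res, true);

    std::printf("nearest code to query 0: %lld (squared L2 %.3f)\n",
                (long long)labels[0], distances[0]);
    return 0;
}

search_ip works the same way with a float_minheap_array_t, since larger inner products are better, while search_sdc compares codes to codes using the table built by compute_sdc_table.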