faiss 0.1.0 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +103 -3
  4. data/ext/faiss/ext.cpp +99 -32
  5. data/ext/faiss/extconf.rb +12 -2
  6. data/lib/faiss/ext.bundle +0 -0
  7. data/lib/faiss/index.rb +3 -3
  8. data/lib/faiss/index_binary.rb +3 -3
  9. data/lib/faiss/kmeans.rb +1 -1
  10. data/lib/faiss/pca_matrix.rb +2 -2
  11. data/lib/faiss/product_quantizer.rb +3 -3
  12. data/lib/faiss/version.rb +1 -1
  13. data/vendor/faiss/AutoTune.cpp +719 -0
  14. data/vendor/faiss/AutoTune.h +212 -0
  15. data/vendor/faiss/Clustering.cpp +261 -0
  16. data/vendor/faiss/Clustering.h +101 -0
  17. data/vendor/faiss/IVFlib.cpp +339 -0
  18. data/vendor/faiss/IVFlib.h +132 -0
  19. data/vendor/faiss/Index.cpp +171 -0
  20. data/vendor/faiss/Index.h +261 -0
  21. data/vendor/faiss/Index2Layer.cpp +437 -0
  22. data/vendor/faiss/Index2Layer.h +85 -0
  23. data/vendor/faiss/IndexBinary.cpp +77 -0
  24. data/vendor/faiss/IndexBinary.h +163 -0
  25. data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
  26. data/vendor/faiss/IndexBinaryFlat.h +54 -0
  27. data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
  28. data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
  29. data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
  30. data/vendor/faiss/IndexBinaryHNSW.h +56 -0
  31. data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
  32. data/vendor/faiss/IndexBinaryIVF.h +211 -0
  33. data/vendor/faiss/IndexFlat.cpp +508 -0
  34. data/vendor/faiss/IndexFlat.h +175 -0
  35. data/vendor/faiss/IndexHNSW.cpp +1090 -0
  36. data/vendor/faiss/IndexHNSW.h +170 -0
  37. data/vendor/faiss/IndexIVF.cpp +909 -0
  38. data/vendor/faiss/IndexIVF.h +353 -0
  39. data/vendor/faiss/IndexIVFFlat.cpp +502 -0
  40. data/vendor/faiss/IndexIVFFlat.h +118 -0
  41. data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
  42. data/vendor/faiss/IndexIVFPQ.h +161 -0
  43. data/vendor/faiss/IndexIVFPQR.cpp +219 -0
  44. data/vendor/faiss/IndexIVFPQR.h +65 -0
  45. data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
  46. data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
  47. data/vendor/faiss/IndexLSH.cpp +225 -0
  48. data/vendor/faiss/IndexLSH.h +87 -0
  49. data/vendor/faiss/IndexLattice.cpp +143 -0
  50. data/vendor/faiss/IndexLattice.h +68 -0
  51. data/vendor/faiss/IndexPQ.cpp +1188 -0
  52. data/vendor/faiss/IndexPQ.h +199 -0
  53. data/vendor/faiss/IndexPreTransform.cpp +288 -0
  54. data/vendor/faiss/IndexPreTransform.h +91 -0
  55. data/vendor/faiss/IndexReplicas.cpp +123 -0
  56. data/vendor/faiss/IndexReplicas.h +76 -0
  57. data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
  58. data/vendor/faiss/IndexScalarQuantizer.h +127 -0
  59. data/vendor/faiss/IndexShards.cpp +317 -0
  60. data/vendor/faiss/IndexShards.h +100 -0
  61. data/vendor/faiss/InvertedLists.cpp +623 -0
  62. data/vendor/faiss/InvertedLists.h +334 -0
  63. data/vendor/faiss/LICENSE +21 -0
  64. data/vendor/faiss/MatrixStats.cpp +252 -0
  65. data/vendor/faiss/MatrixStats.h +62 -0
  66. data/vendor/faiss/MetaIndexes.cpp +351 -0
  67. data/vendor/faiss/MetaIndexes.h +126 -0
  68. data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
  69. data/vendor/faiss/OnDiskInvertedLists.h +127 -0
  70. data/vendor/faiss/VectorTransform.cpp +1157 -0
  71. data/vendor/faiss/VectorTransform.h +322 -0
  72. data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
  73. data/vendor/faiss/c_api/AutoTune_c.h +64 -0
  74. data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
  75. data/vendor/faiss/c_api/Clustering_c.h +117 -0
  76. data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
  77. data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
  78. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
  79. data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
  80. data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
  81. data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
  82. data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
  83. data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
  84. data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
  85. data/vendor/faiss/c_api/IndexShards_c.h +42 -0
  86. data/vendor/faiss/c_api/Index_c.cpp +105 -0
  87. data/vendor/faiss/c_api/Index_c.h +183 -0
  88. data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
  89. data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
  90. data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
  91. data/vendor/faiss/c_api/clone_index_c.h +32 -0
  92. data/vendor/faiss/c_api/error_c.h +42 -0
  93. data/vendor/faiss/c_api/error_impl.cpp +27 -0
  94. data/vendor/faiss/c_api/error_impl.h +16 -0
  95. data/vendor/faiss/c_api/faiss_c.h +58 -0
  96. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
  97. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
  98. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
  99. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
  100. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
  101. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
  102. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
  103. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
  104. data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
  105. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
  106. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
  107. data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
  108. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
  109. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
  110. data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
  111. data/vendor/faiss/c_api/index_factory_c.h +30 -0
  112. data/vendor/faiss/c_api/index_io_c.cpp +42 -0
  113. data/vendor/faiss/c_api/index_io_c.h +50 -0
  114. data/vendor/faiss/c_api/macros_impl.h +110 -0
  115. data/vendor/faiss/clone_index.cpp +147 -0
  116. data/vendor/faiss/clone_index.h +38 -0
  117. data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
  118. data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
  119. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
  120. data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
  121. data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
  122. data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
  123. data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
  124. data/vendor/faiss/gpu/GpuCloner.h +82 -0
  125. data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
  126. data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
  127. data/vendor/faiss/gpu/GpuDistance.h +52 -0
  128. data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
  129. data/vendor/faiss/gpu/GpuIndex.h +148 -0
  130. data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
  131. data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
  132. data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
  133. data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
  134. data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
  135. data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
  136. data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
  137. data/vendor/faiss/gpu/GpuResources.cpp +52 -0
  138. data/vendor/faiss/gpu/GpuResources.h +73 -0
  139. data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
  140. data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
  141. data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
  142. data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
  143. data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
  144. data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
  145. data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
  146. data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
  147. data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
  148. data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
  149. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
  150. data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
  151. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
  152. data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
  153. data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
  154. data/vendor/faiss/gpu/test/TestUtils.h +93 -0
  155. data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
  156. data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
  157. data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
  158. data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
  159. data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
  160. data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
  161. data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
  162. data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
  163. data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
  164. data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
  165. data/vendor/faiss/gpu/utils/Timer.h +52 -0
  166. data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
  167. data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
  168. data/vendor/faiss/impl/FaissAssert.h +95 -0
  169. data/vendor/faiss/impl/FaissException.cpp +66 -0
  170. data/vendor/faiss/impl/FaissException.h +71 -0
  171. data/vendor/faiss/impl/HNSW.cpp +818 -0
  172. data/vendor/faiss/impl/HNSW.h +275 -0
  173. data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
  174. data/vendor/faiss/impl/PolysemousTraining.h +158 -0
  175. data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
  176. data/vendor/faiss/impl/ProductQuantizer.h +242 -0
  177. data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
  178. data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
  179. data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
  180. data/vendor/faiss/impl/ThreadedIndex.h +80 -0
  181. data/vendor/faiss/impl/index_read.cpp +793 -0
  182. data/vendor/faiss/impl/index_write.cpp +558 -0
  183. data/vendor/faiss/impl/io.cpp +142 -0
  184. data/vendor/faiss/impl/io.h +98 -0
  185. data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
  186. data/vendor/faiss/impl/lattice_Zn.h +199 -0
  187. data/vendor/faiss/index_factory.cpp +392 -0
  188. data/vendor/faiss/index_factory.h +25 -0
  189. data/vendor/faiss/index_io.h +75 -0
  190. data/vendor/faiss/misc/test_blas.cpp +84 -0
  191. data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
  192. data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
  193. data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
  194. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
  195. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
  196. data/vendor/faiss/tests/test_merge.cpp +258 -0
  197. data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
  198. data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
  199. data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
  200. data/vendor/faiss/tests/test_params_override.cpp +231 -0
  201. data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
  202. data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
  203. data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
  204. data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
  205. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
  206. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
  207. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
  208. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
  209. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
  210. data/vendor/faiss/utils/Heap.cpp +122 -0
  211. data/vendor/faiss/utils/Heap.h +495 -0
  212. data/vendor/faiss/utils/WorkerThread.cpp +126 -0
  213. data/vendor/faiss/utils/WorkerThread.h +61 -0
  214. data/vendor/faiss/utils/distances.cpp +765 -0
  215. data/vendor/faiss/utils/distances.h +243 -0
  216. data/vendor/faiss/utils/distances_simd.cpp +809 -0
  217. data/vendor/faiss/utils/extra_distances.cpp +336 -0
  218. data/vendor/faiss/utils/extra_distances.h +54 -0
  219. data/vendor/faiss/utils/hamming-inl.h +472 -0
  220. data/vendor/faiss/utils/hamming.cpp +792 -0
  221. data/vendor/faiss/utils/hamming.h +220 -0
  222. data/vendor/faiss/utils/random.cpp +192 -0
  223. data/vendor/faiss/utils/random.h +60 -0
  224. data/vendor/faiss/utils/utils.cpp +783 -0
  225. data/vendor/faiss/utils/utils.h +181 -0
  226. metadata +216 -2
@@ -0,0 +1,158 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #ifndef FAISS_POLYSEMOUS_TRAINING_INCLUDED
11
+ #define FAISS_POLYSEMOUS_TRAINING_INCLUDED
12
+
13
+
14
+ #include <faiss/impl/ProductQuantizer.h>
15
+
16
+
17
+ namespace faiss {
18
+
19
+
20
/// parameters used for the simulated annealing method
///
/// Plain parameter bundle; it is inherited by SimulatedAnnealingOptimizer
/// and PolysemousTraining. The default values are set by the constructor,
/// whose definition is not in this header.
struct SimulatedAnnealingParameters {

    // optimization parameters
    double init_temperature;   // initial probability of accepting a bad swap
    double temperature_decay;  // at each iteration the temp is multiplied by this
    int n_iter; // nb of iterations
    int n_redo; // nb of runs of the simulation
    int seed;   // random seed
    int verbose; // verbosity setting; NOTE(review): exact levels are defined in the .cpp — confirm
    bool only_bit_flips; // restrict permutation changes to bit flips
    bool init_random;    // initialize with a random permutation (not identity)

    // set reasonable defaults
    SimulatedAnnealingParameters ();

};
37
+
38
+
39
/// abstract class for the loss function
///
/// Scores a permutation `perm`. Subclasses must implement compute_cost;
/// cost_update may be overridden with a faster incremental version.
struct PermutationObjective {

    // size of the permutations being scored (presumably perm has n
    // entries — TODO confirm against the implementations)
    int n;

    // cost of the permutation perm
    virtual double compute_cost (const int *perm) const = 0;

    // what would the cost update be if iw and jw were swapped?
    // default implementation just computes both and computes the difference
    virtual double cost_update (const int *perm, int iw, int jw) const;

    // virtual so that subclasses can be deleted through a base pointer
    virtual ~PermutationObjective () {}
};
52
+
53
+
54
/// Permutation objective that scores how well the distances obtained after
/// permuting indices (target_dis, presumably Hamming distances) reproduce
/// the source distances, weighted by dis_weight.
struct ReproduceDistancesObjective : PermutationObjective {

    // parameter of the distance weighting used by dis_weight
    double dis_weight_factor;

    // square of x
    static double sqr (double x) { return x * x; }

    // weighting of distances: it is more important to reproduce small
    // distances well
    double dis_weight (double x) const;

    std::vector<double> source_dis; ///< "real" corrected distances (size n^2)
    // non-owning pointer; NOTE(review): presumably must outlive this object
    const double *      target_dis; ///< wanted distances (size n^2)
    std::vector<double> weights;    ///< weights for each distance (size n^2)

    // source distance between elements i and j (presumably
    // source_dis[i * n + j] — TODO confirm)
    double get_source_dis (int i, int j) const;

    // cost = quadratic difference between actual distance and Hamming distance
    double compute_cost(const int* perm) const override;

    // what would the cost update be if iw and jw were swapped?
    // computed in O(n) instead of O(n^2) for the full re-computation
    double cost_update(const int* perm, int iw, int jw) const override;

    ReproduceDistancesObjective (
           int n,
           const double *source_dis_in,
           const double *target_dis_in,
           double dis_weight_factor);

    // per the parameter names: mean and standard deviation of the n2
    // values in tab, written to *mean_out and *stddev_out
    static void compute_mean_stdev (const double *tab, size_t n2,
                                    double *mean_out, double *stddev_out);

    // NOTE(review): presumably recomputes source_dis as an affine fit of
    // source_dis_in to the target distances — confirm in the .cpp
    void set_affine_target_dis (const double *source_dis_in);

    ~ReproduceDistancesObjective() override {}
};
90
+
91
// forward declaration: only pointers to RandomGenerator are used here
struct RandomGenerator;

/// Simulated annealing optimization algorithm for permutations.
///
/// Minimizes obj->compute_cost over permutations of size n, using the
/// annealing schedule inherited from SimulatedAnnealingParameters.
/// NOTE(review): this type holds raw pointers (obj, logfile, rnd) and has a
/// user-declared destructor. If the destructor frees rnd, copying an
/// instance would double-free it. Confirm ownership in the .cpp.
struct SimulatedAnnealingOptimizer: SimulatedAnnealingParameters {

    PermutationObjective *obj; ///< objective to minimize
    int n;         ///< size of the permutation
    FILE *logfile; ///< logs values of the cost function

    SimulatedAnnealingOptimizer (PermutationObjective *obj,
                                 const SimulatedAnnealingParameters &p);
    RandomGenerator *rnd; ///< random source for the swaps

    /// remember initial cost of optimization
    double init_cost;

    // main entry point. Perform the optimization loop, starting from
    // and modifying permutation in-place
    double optimize (int *perm);

    // run the optimization and return the best result in best_perm
    double run_optimization (int * best_perm);

    virtual ~SimulatedAnnealingOptimizer ();
};
116
+
117
+
118
+
119
+
120
/// optimizes the order of indices in a ProductQuantizer
///
/// Permutes the centroid indices of each sub-quantizer so that Hamming
/// distances between codes approximate distances between the centroids.
/// The annealing settings come from SimulatedAnnealingParameters.
struct PolysemousTraining: SimulatedAnnealingParameters {

    enum Optimization_type_t {
        OT_None,                       ///< no reordering
        OT_ReproduceDistances_affine,  ///< default
        OT_Ranking_weighted_diff  ///< same as _2, but use rank of y+ - rank of y-
                                  // NOTE(review): no "_2" enumerator exists in this version
    };
    Optimization_type_t optimization_type;

    // use 1/4 of the training points for the optimization, with
    // max. ntrain_permutation. If ntrain_permutation == 0: train on
    // centroids
    int ntrain_permutation;
    double dis_weight_factor; // decay of exp that weights distance loss

    // filename pattern for the logging of iterations
    std::string log_pattern;

    // sets default values
    PolysemousTraining ();

    /// reorder the centroids so that the Hamming distance becomes a
    /// good approximation of the SDC distance (called by train)
    /// n, x: presumably n training vectors of dimension pq.d — TODO confirm
    void optimize_pq_for_hamming (ProductQuantizer & pq,
                                  size_t n, const float *x) const;

    /// called by optimize_pq_for_hamming
    void optimize_ranking (ProductQuantizer &pq, size_t n, const float *x) const;
    /// called by optimize_pq_for_hamming
    void optimize_reproduce_distances (ProductQuantizer &pq) const;

};
153
+
154
+
155
+ } // namespace faiss
156
+
157
+
158
+ #endif
@@ -0,0 +1,876 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/ProductQuantizer.h>
11
+
12
+
13
+ #include <cstddef>
14
+ #include <cstring>
15
+ #include <cstdio>
16
+ #include <memory>
17
+
18
+ #include <algorithm>
19
+
20
+ #include <faiss/impl/FaissAssert.h>
21
+ #include <faiss/VectorTransform.h>
22
+ #include <faiss/IndexFlat.h>
23
+ #include <faiss/utils/distances.h>
24
+
25
+
26
extern "C" {

/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */

// Fortran BLAS single-precision matrix product
// C = alpha * op(A) * op(B) + beta * C (column-major; every argument is
// passed by pointer, per the Fortran calling convention).
// NOTE(review): FINTEGER is not defined in this file; it is presumably
// provided by the build to match the BLAS integer width — confirm.
int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
            n, FINTEGER *k, const float *alpha, const float *a,
            FINTEGER *lda, const float *b, FINTEGER *
            ldb, float *beta, float *c, FINTEGER *ldc);

}
36
+
37
+
38
+ namespace faiss {
39
+
40
+
41
+ /* compute an estimator using look-up tables for typical values of M */
42
+ template <typename CT, class C>
43
+ void pq_estimators_from_tables_Mmul4 (int M, const CT * codes,
44
+ size_t ncodes,
45
+ const float * __restrict dis_table,
46
+ size_t ksub,
47
+ size_t k,
48
+ float * heap_dis,
49
+ int64_t * heap_ids)
50
+ {
51
+
52
+ for (size_t j = 0; j < ncodes; j++) {
53
+ float dis = 0;
54
+ const float *dt = dis_table;
55
+
56
+ for (size_t m = 0; m < M; m+=4) {
57
+ float dism = 0;
58
+ dism = dt[*codes++]; dt += ksub;
59
+ dism += dt[*codes++]; dt += ksub;
60
+ dism += dt[*codes++]; dt += ksub;
61
+ dism += dt[*codes++]; dt += ksub;
62
+ dis += dism;
63
+ }
64
+
65
+ if (C::cmp (heap_dis[0], dis)) {
66
+ heap_pop<C> (k, heap_dis, heap_ids);
67
+ heap_push<C> (k, heap_dis, heap_ids, dis, j);
68
+ }
69
+ }
70
+ }
71
+
72
+
73
+ template <typename CT, class C>
74
+ void pq_estimators_from_tables_M4 (const CT * codes,
75
+ size_t ncodes,
76
+ const float * __restrict dis_table,
77
+ size_t ksub,
78
+ size_t k,
79
+ float * heap_dis,
80
+ int64_t * heap_ids)
81
+ {
82
+
83
+ for (size_t j = 0; j < ncodes; j++) {
84
+ float dis = 0;
85
+ const float *dt = dis_table;
86
+ dis = dt[*codes++]; dt += ksub;
87
+ dis += dt[*codes++]; dt += ksub;
88
+ dis += dt[*codes++]; dt += ksub;
89
+ dis += dt[*codes++];
90
+
91
+ if (C::cmp (heap_dis[0], dis)) {
92
+ heap_pop<C> (k, heap_dis, heap_ids);
93
+ heap_push<C> (k, heap_dis, heap_ids, dis, j);
94
+ }
95
+ }
96
+ }
97
+
98
+
99
+ template <typename CT, class C>
100
+ static inline void pq_estimators_from_tables (const ProductQuantizer& pq,
101
+ const CT * codes,
102
+ size_t ncodes,
103
+ const float * dis_table,
104
+ size_t k,
105
+ float * heap_dis,
106
+ int64_t * heap_ids)
107
+ {
108
+
109
+ if (pq.M == 4) {
110
+
111
+ pq_estimators_from_tables_M4<CT, C> (codes, ncodes,
112
+ dis_table, pq.ksub, k,
113
+ heap_dis, heap_ids);
114
+ return;
115
+ }
116
+
117
+ if (pq.M % 4 == 0) {
118
+ pq_estimators_from_tables_Mmul4<CT, C> (pq.M, codes, ncodes,
119
+ dis_table, pq.ksub, k,
120
+ heap_dis, heap_ids);
121
+ return;
122
+ }
123
+
124
+ /* Default is relatively slow */
125
+ const size_t M = pq.M;
126
+ const size_t ksub = pq.ksub;
127
+ for (size_t j = 0; j < ncodes; j++) {
128
+ float dis = 0;
129
+ const float * __restrict dt = dis_table;
130
+ for (int m = 0; m < M; m++) {
131
+ dis += dt[*codes++];
132
+ dt += ksub;
133
+ }
134
+ if (C::cmp (heap_dis[0], dis)) {
135
+ heap_pop<C> (k, heap_dis, heap_ids);
136
+ heap_push<C> (k, heap_dis, heap_ids, dis, j);
137
+ }
138
+ }
139
+ }
140
+
141
+ template <class C>
142
+ static inline void pq_estimators_from_tables_generic(const ProductQuantizer& pq,
143
+ size_t nbits,
144
+ const uint8_t *codes,
145
+ size_t ncodes,
146
+ const float *dis_table,
147
+ size_t k,
148
+ float *heap_dis,
149
+ int64_t *heap_ids)
150
+ {
151
+ const size_t M = pq.M;
152
+ const size_t ksub = pq.ksub;
153
+ for (size_t j = 0; j < ncodes; ++j) {
154
+ faiss::ProductQuantizer::PQDecoderGeneric decoder(
155
+ codes + j * pq.code_size, nbits
156
+ );
157
+ float dis = 0;
158
+ const float * __restrict dt = dis_table;
159
+ for (size_t m = 0; m < M; m++) {
160
+ uint64_t c = decoder.decode();
161
+ dis += dt[c];
162
+ dt += ksub;
163
+ }
164
+
165
+ if (C::cmp(heap_dis[0], dis)) {
166
+ heap_pop<C>(k, heap_dis, heap_ids);
167
+ heap_push<C>(k, heap_dis, heap_ids, dis, j);
168
+ }
169
+ }
170
+ }
171
+
172
+ /*********************************************
173
+ * PQ implementation
174
+ *********************************************/
175
+
176
+
177
+
178
+ ProductQuantizer::ProductQuantizer (size_t d, size_t M, size_t nbits):
179
+ d(d), M(M), nbits(nbits), assign_index(nullptr)
180
+ {
181
+ set_derived_values ();
182
+ }
183
+
184
+ ProductQuantizer::ProductQuantizer ()
185
+ : ProductQuantizer(0, 1, 0) {}
186
+
187
+ void ProductQuantizer::set_derived_values () {
188
+ // quite a few derived values
189
+ FAISS_THROW_IF_NOT (d % M == 0);
190
+ dsub = d / M;
191
+ code_size = (nbits * M + 7) / 8;
192
+ ksub = 1 << nbits;
193
+ centroids.resize (d * ksub);
194
+ verbose = false;
195
+ train_type = Train_default;
196
+ }
197
+
198
+ void ProductQuantizer::set_params (const float * centroids_, int m)
199
+ {
200
+ memcpy (get_centroids(m, 0), centroids_,
201
+ ksub * dsub * sizeof (centroids_[0]));
202
+ }
203
+
204
+
205
/// Hypercube initialization of one sub-quantizer's centroids.
///
/// Let mean be the per-dimension mean of the n vectors of x (n * d floats)
/// and maxm its largest absolute component. Centroid i
/// (0 <= i < 2^nbits, stored at centroids + i * d) equals mean. In each of
/// the first nbits dimensions j, it is shifted by +maxm when bit j of i is
/// set and by -maxm otherwise. The caller must ensure nbits <= d.
static void init_hypercube (int d, int nbits,
                            int n, const float * x,
                            float *centroids)
{
    // per-dimension mean of the training vectors
    std::vector<float> mean (d);
    for (int row = 0; row < n; row++) {
        const float *vec = x + row * d;
        for (int dim = 0; dim < d; dim++) {
            mean[dim] += vec[dim];
        }
    }

    float maxm = 0;
    for (float &component : mean) {
        component /= n;
        if (fabs(component) > maxm) maxm = fabs(component);
    }

    const int ncentroids = 1 << nbits;
    for (int c = 0; c < ncentroids; c++) {
        float *cent = centroids + c * d;
        int dim = 0;
        // hypercube corner: +/- maxm along the first nbits dimensions
        for (; dim < nbits; dim++) {
            const float sign = ((c >> dim) & 1) ? 1 : -1;
            cent[dim] = mean[dim] + sign * maxm;
        }
        // remaining dimensions stay at the mean
        for (; dim < d; dim++) {
            cent[dim] = mean[dim];
        }
    }
}
231
+
232
+ static void init_hypercube_pca (int d, int nbits,
233
+ int n, const float * x,
234
+ float *centroids)
235
+ {
236
+ PCAMatrix pca (d, nbits);
237
+ pca.train (n, x);
238
+
239
+
240
+ for (int i = 0; i < (1 << nbits); i++) {
241
+ float * cent = centroids + i * d;
242
+ for (int j = 0; j < d; j++) {
243
+ cent[j] = pca.mean[j];
244
+ float f = 1.0;
245
+ for (int k = 0; k < nbits; k++)
246
+ cent[j] += f *
247
+ sqrt (pca.eigenvalues [k]) *
248
+ (((i >> k) & 1) ? 1 : -1) *
249
+ pca.PCAMat [j + k * d];
250
+ }
251
+ }
252
+
253
+ }
254
+
255
+ void ProductQuantizer::train (int n, const float * x)
256
+ {
257
+ if (train_type != Train_shared) {
258
+ train_type_t final_train_type;
259
+ final_train_type = train_type;
260
+ if (train_type == Train_hypercube ||
261
+ train_type == Train_hypercube_pca) {
262
+ if (dsub < nbits) {
263
+ final_train_type = Train_default;
264
+ printf ("cannot train hypercube: nbits=%ld > log2(d=%ld)\n",
265
+ nbits, dsub);
266
+ }
267
+ }
268
+
269
+ float * xslice = new float[n * dsub];
270
+ ScopeDeleter<float> del (xslice);
271
+ for (int m = 0; m < M; m++) {
272
+ for (int j = 0; j < n; j++)
273
+ memcpy (xslice + j * dsub,
274
+ x + j * d + m * dsub,
275
+ dsub * sizeof(float));
276
+
277
+ Clustering clus (dsub, ksub, cp);
278
+
279
+ // we have some initialization for the centroids
280
+ if (final_train_type != Train_default) {
281
+ clus.centroids.resize (dsub * ksub);
282
+ }
283
+
284
+ switch (final_train_type) {
285
+ case Train_hypercube:
286
+ init_hypercube (dsub, nbits, n, xslice,
287
+ clus.centroids.data ());
288
+ break;
289
+ case Train_hypercube_pca:
290
+ init_hypercube_pca (dsub, nbits, n, xslice,
291
+ clus.centroids.data ());
292
+ break;
293
+ case Train_hot_start:
294
+ memcpy (clus.centroids.data(),
295
+ get_centroids (m, 0),
296
+ dsub * ksub * sizeof (float));
297
+ break;
298
+ default: ;
299
+ }
300
+
301
+ if(verbose) {
302
+ clus.verbose = true;
303
+ printf ("Training PQ slice %d/%zd\n", m, M);
304
+ }
305
+ IndexFlatL2 index (dsub);
306
+ clus.train (n, xslice, assign_index ? *assign_index : index);
307
+ set_params (clus.centroids.data(), m);
308
+ }
309
+
310
+
311
+ } else {
312
+
313
+ Clustering clus (dsub, ksub, cp);
314
+
315
+ if(verbose) {
316
+ clus.verbose = true;
317
+ printf ("Training all PQ slices at once\n");
318
+ }
319
+
320
+ IndexFlatL2 index (dsub);
321
+
322
+ clus.train (n * M, x, assign_index ? *assign_index : index);
323
+ for (int m = 0; m < M; m++) {
324
+ set_params (clus.centroids.data(), m);
325
+ }
326
+
327
+ }
328
+ }
329
+
330
+ template<class PQEncoder>
331
+ void compute_code(const ProductQuantizer& pq, const float *x, uint8_t *code) {
332
+ float distances [pq.ksub];
333
+ PQEncoder encoder(code, pq.nbits);
334
+ for (size_t m = 0; m < pq.M; m++) {
335
+ float mindis = 1e20;
336
+ uint64_t idxm = 0;
337
+ const float * xsub = x + m * pq.dsub;
338
+
339
+ fvec_L2sqr_ny(distances, xsub, pq.get_centroids(m, 0), pq.dsub, pq.ksub);
340
+
341
+ /* Find best centroid */
342
+ for (size_t i = 0; i < pq.ksub; i++) {
343
+ float dis = distances[i];
344
+ if (dis < mindis) {
345
+ mindis = dis;
346
+ idxm = i;
347
+ }
348
+ }
349
+
350
+ encoder.encode(idxm);
351
+ }
352
+ }
353
+
354
+ void ProductQuantizer::compute_code(const float * x, uint8_t * code) const {
355
+ switch (nbits) {
356
+ case 8:
357
+ faiss::compute_code<PQEncoder8>(*this, x, code);
358
+ break;
359
+
360
+ case 16:
361
+ faiss::compute_code<PQEncoder16>(*this, x, code);
362
+ break;
363
+
364
+ default:
365
+ faiss::compute_code<PQEncoderGeneric>(*this, x, code);
366
+ break;
367
+ }
368
+ }
369
+
370
+ template<class PQDecoder>
371
+ void decode(const ProductQuantizer& pq, const uint8_t *code, float *x)
372
+ {
373
+ PQDecoder decoder(code, pq.nbits);
374
+ for (size_t m = 0; m < pq.M; m++) {
375
+ uint64_t c = decoder.decode();
376
+ memcpy(x + m * pq.dsub, pq.get_centroids(m, c), sizeof(float) * pq.dsub);
377
+ }
378
+ }
379
+
380
+ void ProductQuantizer::decode (const uint8_t *code, float *x) const
381
+ {
382
+ switch (nbits) {
383
+ case 8:
384
+ faiss::decode<PQDecoder8>(*this, code, x);
385
+ break;
386
+
387
+ case 16:
388
+ faiss::decode<PQDecoder16>(*this, code, x);
389
+ break;
390
+
391
+ default:
392
+ faiss::decode<PQDecoderGeneric>(*this, code, x);
393
+ break;
394
+ }
395
+ }
396
+
397
+
398
+ void ProductQuantizer::decode (const uint8_t *code, float *x, size_t n) const
399
+ {
400
+ for (size_t i = 0; i < n; i++) {
401
+ this->decode (code + code_size * i, x + d * i);
402
+ }
403
+ }
404
+
405
+
406
+ void ProductQuantizer::compute_code_from_distance_table (const float *tab,
407
+ uint8_t *code) const
408
+ {
409
+ PQEncoderGeneric encoder(code, nbits);
410
+ for (size_t m = 0; m < M; m++) {
411
+ float mindis = 1e20;
412
+ uint64_t idxm = 0;
413
+
414
+ /* Find best centroid */
415
+ for (size_t j = 0; j < ksub; j++) {
416
+ float dis = *tab++;
417
+ if (dis < mindis) {
418
+ mindis = dis;
419
+ idxm = j;
420
+ }
421
+ }
422
+
423
+ encoder.encode(idxm);
424
+ }
425
+ }
426
+
427
+ void ProductQuantizer::compute_codes_with_assign_index (
428
+ const float * x,
429
+ uint8_t * codes,
430
+ size_t n)
431
+ {
432
+ FAISS_THROW_IF_NOT (assign_index && assign_index->d == dsub);
433
+
434
+ for (size_t m = 0; m < M; m++) {
435
+ assign_index->reset ();
436
+ assign_index->add (ksub, get_centroids (m, 0));
437
+ size_t bs = 65536;
438
+ float * xslice = new float[bs * dsub];
439
+ ScopeDeleter<float> del (xslice);
440
+ idx_t *assign = new idx_t[bs];
441
+ ScopeDeleter<idx_t> del2 (assign);
442
+
443
+ for (size_t i0 = 0; i0 < n; i0 += bs) {
444
+ size_t i1 = std::min(i0 + bs, n);
445
+
446
+ for (size_t i = i0; i < i1; i++) {
447
+ memcpy (xslice + (i - i0) * dsub,
448
+ x + i * d + m * dsub,
449
+ dsub * sizeof(float));
450
+ }
451
+
452
+ assign_index->assign (i1 - i0, xslice, assign);
453
+
454
+ if (nbits == 8) {
455
+ uint8_t *c = codes + code_size * i0 + m;
456
+ for (size_t i = i0; i < i1; i++) {
457
+ *c = assign[i - i0];
458
+ c += M;
459
+ }
460
+ } else if (nbits == 16) {
461
+ uint16_t *c = (uint16_t*)(codes + code_size * i0 + m * 2);
462
+ for (size_t i = i0; i < i1; i++) {
463
+ *c = assign[i - i0];
464
+ c += M;
465
+ }
466
+ } else {
467
+ for (size_t i = i0; i < i1; ++i) {
468
+ uint8_t *c = codes + code_size * i + ((m * nbits) / 8);
469
+ uint8_t offset = (m * nbits) % 8;
470
+ uint64_t ass = assign[i - i0];
471
+
472
+ PQEncoderGeneric encoder(c, nbits, offset);
473
+ encoder.encode(ass);
474
+ }
475
+ }
476
+
477
+ }
478
+ }
479
+
480
+ }
481
+
482
+ void ProductQuantizer::compute_codes (const float * x,
483
+ uint8_t * codes,
484
+ size_t n) const
485
+ {
486
+ // process by blocks to avoid using too much RAM
487
+ size_t bs = 256 * 1024;
488
+ if (n > bs) {
489
+ for (size_t i0 = 0; i0 < n; i0 += bs) {
490
+ size_t i1 = std::min(i0 + bs, n);
491
+ compute_codes (x + d * i0, codes + code_size * i0, i1 - i0);
492
+ }
493
+ return;
494
+ }
495
+
496
+ if (dsub < 16) { // simple direct computation
497
+
498
+ #pragma omp parallel for
499
+ for (size_t i = 0; i < n; i++)
500
+ compute_code (x + i * d, codes + i * code_size);
501
+
502
+ } else { // worthwile to use BLAS
503
+ float *dis_tables = new float [n * ksub * M];
504
+ ScopeDeleter<float> del (dis_tables);
505
+ compute_distance_tables (n, x, dis_tables);
506
+
507
+ #pragma omp parallel for
508
+ for (size_t i = 0; i < n; i++) {
509
+ uint8_t * code = codes + i * code_size;
510
+ const float * tab = dis_tables + i * ksub * M;
511
+ compute_code_from_distance_table (tab, code);
512
+ }
513
+ }
514
+ }
515
+
516
+
517
+ void ProductQuantizer::compute_distance_table (const float * x,
518
+ float * dis_table) const
519
+ {
520
+ size_t m;
521
+
522
+ for (m = 0; m < M; m++) {
523
+ fvec_L2sqr_ny (dis_table + m * ksub,
524
+ x + m * dsub,
525
+ get_centroids(m, 0),
526
+ dsub,
527
+ ksub);
528
+ }
529
+ }
530
+
531
+ void ProductQuantizer::compute_inner_prod_table (const float * x,
532
+ float * dis_table) const
533
+ {
534
+ size_t m;
535
+
536
+ for (m = 0; m < M; m++) {
537
+ fvec_inner_products_ny (dis_table + m * ksub,
538
+ x + m * dsub,
539
+ get_centroids(m, 0),
540
+ dsub,
541
+ ksub);
542
+ }
543
+ }
544
+
545
+
546
+ void ProductQuantizer::compute_distance_tables (
547
+ size_t nx,
548
+ const float * x,
549
+ float * dis_tables) const
550
+ {
551
+
552
+ if (dsub < 16) {
553
+
554
+ #pragma omp parallel for
555
+ for (size_t i = 0; i < nx; i++) {
556
+ compute_distance_table (x + i * d, dis_tables + i * ksub * M);
557
+ }
558
+
559
+ } else { // use BLAS
560
+
561
+ for (int m = 0; m < M; m++) {
562
+ pairwise_L2sqr (dsub,
563
+ nx, x + dsub * m,
564
+ ksub, centroids.data() + m * dsub * ksub,
565
+ dis_tables + ksub * m,
566
+ d, dsub, ksub * M);
567
+ }
568
+ }
569
+ }
570
+
571
+ void ProductQuantizer::compute_inner_prod_tables (
572
+ size_t nx,
573
+ const float * x,
574
+ float * dis_tables) const
575
+ {
576
+
577
+ if (dsub < 16) {
578
+
579
+ #pragma omp parallel for
580
+ for (size_t i = 0; i < nx; i++) {
581
+ compute_inner_prod_table (x + i * d, dis_tables + i * ksub * M);
582
+ }
583
+
584
+ } else { // use BLAS
585
+
586
+ // compute distance tables
587
+ for (int m = 0; m < M; m++) {
588
+ FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub,
589
+ dsubi = dsub, di = d;
590
+ float one = 1.0, zero = 0;
591
+
592
+ sgemm_ ("Transposed", "Not transposed",
593
+ &ksubi, &nxi, &dsubi,
594
+ &one, &centroids [m * dsub * ksub], &dsubi,
595
+ x + dsub * m, &di,
596
+ &zero, dis_tables + ksub * m, &ldc);
597
+ }
598
+
599
+ }
600
+ }
601
+
602
+ template <class C>
603
+ static void pq_knn_search_with_tables (
604
+ const ProductQuantizer& pq,
605
+ size_t nbits,
606
+ const float *dis_tables,
607
+ const uint8_t * codes,
608
+ const size_t ncodes,
609
+ HeapArray<C> * res,
610
+ bool init_finalize_heap)
611
+ {
612
+ size_t k = res->k, nx = res->nh;
613
+ size_t ksub = pq.ksub, M = pq.M;
614
+
615
+
616
+ #pragma omp parallel for
617
+ for (size_t i = 0; i < nx; i++) {
618
+ /* query preparation for asymmetric search: compute look-up tables */
619
+ const float* dis_table = dis_tables + i * ksub * M;
620
+
621
+ /* Compute distances and keep smallest values */
622
+ int64_t * __restrict heap_ids = res->ids + i * k;
623
+ float * __restrict heap_dis = res->val + i * k;
624
+
625
+ if (init_finalize_heap) {
626
+ heap_heapify<C> (k, heap_dis, heap_ids);
627
+ }
628
+
629
+ switch (nbits) {
630
+ case 8:
631
+ pq_estimators_from_tables<uint8_t, C> (pq,
632
+ codes, ncodes,
633
+ dis_table,
634
+ k, heap_dis, heap_ids);
635
+ break;
636
+
637
+ case 16:
638
+ pq_estimators_from_tables<uint16_t, C> (pq,
639
+ (uint16_t*)codes, ncodes,
640
+ dis_table,
641
+ k, heap_dis, heap_ids);
642
+ break;
643
+
644
+ default:
645
+ pq_estimators_from_tables_generic<C> (pq,
646
+ nbits,
647
+ codes, ncodes,
648
+ dis_table,
649
+ k, heap_dis, heap_ids);
650
+ break;
651
+ }
652
+
653
+ if (init_finalize_heap) {
654
+ heap_reorder<C> (k, heap_dis, heap_ids);
655
+ }
656
+ }
657
+ }
658
+
659
+ void ProductQuantizer::search (const float * __restrict x,
660
+ size_t nx,
661
+ const uint8_t * codes,
662
+ const size_t ncodes,
663
+ float_maxheap_array_t * res,
664
+ bool init_finalize_heap) const
665
+ {
666
+ FAISS_THROW_IF_NOT (nx == res->nh);
667
+ std::unique_ptr<float[]> dis_tables(new float [nx * ksub * M]);
668
+ compute_distance_tables (nx, x, dis_tables.get());
669
+
670
+ pq_knn_search_with_tables<CMax<float, int64_t>> (
671
+ *this, nbits, dis_tables.get(), codes, ncodes, res, init_finalize_heap);
672
+ }
673
+
674
+ void ProductQuantizer::search_ip (const float * __restrict x,
675
+ size_t nx,
676
+ const uint8_t * codes,
677
+ const size_t ncodes,
678
+ float_minheap_array_t * res,
679
+ bool init_finalize_heap) const
680
+ {
681
+ FAISS_THROW_IF_NOT (nx == res->nh);
682
+ std::unique_ptr<float[]> dis_tables(new float [nx * ksub * M]);
683
+ compute_inner_prod_tables (nx, x, dis_tables.get());
684
+
685
+ pq_knn_search_with_tables<CMin<float, int64_t> > (
686
+ *this, nbits, dis_tables.get(), codes, ncodes, res, init_finalize_heap);
687
+ }
688
+
689
+
690
+
691
/* Returns x squared. */
static float sqr (float x)
{
    const float squared = x * x;
    return squared;
}
694
+
695
+ void ProductQuantizer::compute_sdc_table ()
696
+ {
697
+ sdc_table.resize (M * ksub * ksub);
698
+
699
+ for (int m = 0; m < M; m++) {
700
+
701
+ const float *cents = centroids.data() + m * ksub * dsub;
702
+ float * dis_tab = sdc_table.data() + m * ksub * ksub;
703
+
704
+ // TODO optimize with BLAS
705
+ for (int i = 0; i < ksub; i++) {
706
+ const float *centi = cents + i * dsub;
707
+ for (int j = 0; j < ksub; j++) {
708
+ float accu = 0;
709
+ const float *centj = cents + j * dsub;
710
+ for (int k = 0; k < dsub; k++)
711
+ accu += sqr (centi[k] - centj[k]);
712
+ dis_tab [i + j * ksub] = accu;
713
+ }
714
+ }
715
+ }
716
+ }
717
+
718
+ void ProductQuantizer::search_sdc (const uint8_t * qcodes,
719
+ size_t nq,
720
+ const uint8_t * bcodes,
721
+ const size_t nb,
722
+ float_maxheap_array_t * res,
723
+ bool init_finalize_heap) const
724
+ {
725
+ FAISS_THROW_IF_NOT (sdc_table.size() == M * ksub * ksub);
726
+ FAISS_THROW_IF_NOT (nbits == 8);
727
+ size_t k = res->k;
728
+
729
+
730
+ #pragma omp parallel for
731
+ for (size_t i = 0; i < nq; i++) {
732
+
733
+ /* Compute distances and keep smallest values */
734
+ idx_t * heap_ids = res->ids + i * k;
735
+ float * heap_dis = res->val + i * k;
736
+ const uint8_t * qcode = qcodes + i * code_size;
737
+
738
+ if (init_finalize_heap)
739
+ maxheap_heapify (k, heap_dis, heap_ids);
740
+
741
+ const uint8_t * bcode = bcodes;
742
+ for (size_t j = 0; j < nb; j++) {
743
+ float dis = 0;
744
+ const float * tab = sdc_table.data();
745
+ for (int m = 0; m < M; m++) {
746
+ dis += tab[bcode[m] + qcode[m] * ksub];
747
+ tab += ksub * ksub;
748
+ }
749
+ if (dis < heap_dis[0]) {
750
+ maxheap_pop (k, heap_dis, heap_ids);
751
+ maxheap_push (k, heap_dis, heap_ids, dis, j);
752
+ }
753
+ bcode += code_size;
754
+ }
755
+
756
+ if (init_finalize_heap)
757
+ maxheap_reorder (k, heap_dis, heap_ids);
758
+ }
759
+
760
+ }
761
+
762
+
763
+ ProductQuantizer::PQEncoderGeneric::PQEncoderGeneric(uint8_t *code, int nbits,
764
+ uint8_t offset)
765
+ : code(code), offset(offset), nbits(nbits), reg(0) {
766
+ assert(nbits <= 64);
767
+ if (offset > 0) {
768
+ reg = (*code & ((1 << offset) - 1));
769
+ }
770
+ }
771
+
772
+ void ProductQuantizer::PQEncoderGeneric::encode(uint64_t x) {
773
+ reg |= (uint8_t)(x << offset);
774
+ x >>= (8 - offset);
775
+ if (offset + nbits >= 8) {
776
+ *code++ = reg;
777
+
778
+ for (int i = 0; i < (nbits - (8 - offset)) / 8; ++i) {
779
+ *code++ = (uint8_t)x;
780
+ x >>= 8;
781
+ }
782
+
783
+ offset += nbits;
784
+ offset &= 7;
785
+ reg = (uint8_t)x;
786
+ } else {
787
+ offset += nbits;
788
+ }
789
+ }
790
+
791
+ ProductQuantizer::PQEncoderGeneric::~PQEncoderGeneric() {
792
+ if (offset > 0) {
793
+ *code = reg;
794
+ }
795
+ }
796
+
797
+
798
+ ProductQuantizer::PQEncoder8::PQEncoder8(uint8_t *code, int nbits)
799
+ : code(code) {
800
+ assert(8 == nbits);
801
+ }
802
+
803
+ void ProductQuantizer::PQEncoder8::encode(uint64_t x) {
804
+ *code++ = (uint8_t)x;
805
+ }
806
+
807
+
808
+ ProductQuantizer::PQEncoder16::PQEncoder16(uint8_t *code, int nbits)
809
+ : code((uint16_t *)code) {
810
+ assert(16 == nbits);
811
+ }
812
+
813
+ void ProductQuantizer::PQEncoder16::encode(uint64_t x) {
814
+ *code++ = (uint16_t)x;
815
+ }
816
+
817
+
818
+ ProductQuantizer::PQDecoderGeneric::PQDecoderGeneric(const uint8_t *code,
819
+ int nbits)
820
+ : code(code),
821
+ offset(0),
822
+ nbits(nbits),
823
+ mask((1ull << nbits) - 1),
824
+ reg(0) {
825
+ assert(nbits <= 64);
826
+ }
827
+
828
+ uint64_t ProductQuantizer::PQDecoderGeneric::decode() {
829
+ if (offset == 0) {
830
+ reg = *code;
831
+ }
832
+ uint64_t c = (reg >> offset);
833
+
834
+ if (offset + nbits >= 8) {
835
+ uint64_t e = 8 - offset;
836
+ ++code;
837
+ for (int i = 0; i < (nbits - (8 - offset)) / 8; ++i) {
838
+ c |= ((uint64_t)(*code++) << e);
839
+ e += 8;
840
+ }
841
+
842
+ offset += nbits;
843
+ offset &= 7;
844
+ if (offset > 0) {
845
+ reg = *code;
846
+ c |= ((uint64_t)reg << e);
847
+ }
848
+ } else {
849
+ offset += nbits;
850
+ }
851
+
852
+ return c & mask;
853
+ }
854
+
855
+
856
+ ProductQuantizer::PQDecoder8::PQDecoder8(const uint8_t *code, int nbits)
857
+ : code(code) {
858
+ assert(8 == nbits);
859
+ }
860
+
861
+ uint64_t ProductQuantizer::PQDecoder8::decode() {
862
+ return (uint64_t)(*code++);
863
+ }
864
+
865
+
866
+ ProductQuantizer::PQDecoder16::PQDecoder16(const uint8_t *code, int nbits)
867
+ : code((uint16_t *)code) {
868
+ assert(16 == nbits);
869
+ }
870
+
871
+ uint64_t ProductQuantizer::PQDecoder16::decode() {
872
+ return (uint64_t)(*code++);
873
+ }
874
+
875
+
876
+ } // namespace faiss