faiss 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -10,62 +10,56 @@
10
10
  #ifndef FAISS_POLYSEMOUS_TRAINING_INCLUDED
11
11
  #define FAISS_POLYSEMOUS_TRAINING_INCLUDED
12
12
 
13
-
14
13
  #include <faiss/impl/ProductQuantizer.h>
15
14
 
16
-
17
15
  namespace faiss {
18
16
 
19
-
20
17
  /// parameters used for the simulated annealing method
21
18
  struct SimulatedAnnealingParameters {
22
-
23
19
  // optimization parameters
24
- double init_temperature; // init probability of accepting a bad swap
25
- double temperature_decay; // at each iteration the temp is multiplied by this
26
- int n_iter; // nb of iterations
27
- int n_redo; // nb of runs of the simulation
28
- int seed; // random seed
20
+ double init_temperature; // init probability of accepting a bad swap
21
+ double temperature_decay; // at each iteration the temp is multiplied by
22
+ // this
23
+ int n_iter; // nb of iterations
24
+ int n_redo; // nb of runs of the simulation
25
+ int seed; // random seed
29
26
  int verbose;
30
27
  bool only_bit_flips; // restrict permutation changes to bit flips
31
- bool init_random; // initialize with a random permutation (not identity)
28
+ bool init_random; // initialize with a random permutation (not identity)
32
29
 
33
30
  // set reasonable defaults
34
- SimulatedAnnealingParameters ();
35
-
31
+ SimulatedAnnealingParameters();
36
32
  };
37
33
 
38
-
39
34
  /// abstract class for the loss function
40
35
  struct PermutationObjective {
41
-
42
36
  int n;
43
37
 
44
- virtual double compute_cost (const int *perm) const = 0;
38
+ virtual double compute_cost(const int* perm) const = 0;
45
39
 
46
40
  // what would the cost update be if iw and jw were swapped?
47
41
  // default implementation just computes both and computes the difference
48
- virtual double cost_update (const int *perm, int iw, int jw) const;
42
+ virtual double cost_update(const int* perm, int iw, int jw) const;
49
43
 
50
- virtual ~PermutationObjective () {}
44
+ virtual ~PermutationObjective() {}
51
45
  };
52
46
 
53
-
54
47
  struct ReproduceDistancesObjective : PermutationObjective {
55
-
56
48
  double dis_weight_factor;
57
49
 
58
- static double sqr (double x) { return x * x; }
50
+ static double sqr(double x) {
51
+ return x * x;
52
+ }
59
53
 
60
54
  // weighting of distances: it is more important to reproduce small
61
55
  // distances well
62
- double dis_weight (double x) const;
56
+ double dis_weight(double x) const;
63
57
 
64
58
  std::vector<double> source_dis; ///< "real" corrected distances (size n^2)
65
- const double * target_dis; ///< wanted distances (size n^2)
59
+ const double* target_dis; ///< wanted distances (size n^2)
66
60
  std::vector<double> weights; ///< weights for each distance (size n^2)
67
61
 
68
- double get_source_dis (int i, int j) const;
62
+ double get_source_dis(int i, int j) const;
69
63
 
70
64
  // cost = quadratic difference between actual distance and Hamming distance
71
65
  double compute_cost(const int* perm) const override;
@@ -74,16 +68,19 @@ struct ReproduceDistancesObjective : PermutationObjective {
74
68
  // computed in O(n) instead of O(n^2) for the full re-computation
75
69
  double cost_update(const int* perm, int iw, int jw) const override;
76
70
 
77
- ReproduceDistancesObjective (
78
- int n,
79
- const double *source_dis_in,
80
- const double *target_dis_in,
81
- double dis_weight_factor);
71
+ ReproduceDistancesObjective(
72
+ int n,
73
+ const double* source_dis_in,
74
+ const double* target_dis_in,
75
+ double dis_weight_factor);
82
76
 
83
- static void compute_mean_stdev (const double *tab, size_t n2,
84
- double *mean_out, double *stddev_out);
77
+ static void compute_mean_stdev(
78
+ const double* tab,
79
+ size_t n2,
80
+ double* mean_out,
81
+ double* stddev_out);
85
82
 
86
- void set_affine_target_dis (const double *source_dis_in);
83
+ void set_affine_target_dis(const double* source_dis_in);
87
84
 
88
85
  ~ReproduceDistancesObjective() override {}
89
86
  };
@@ -91,39 +88,36 @@ struct ReproduceDistancesObjective : PermutationObjective {
91
88
  struct RandomGenerator;
92
89
 
93
90
  /// Simulated annealing optimization algorithm for permutations.
94
- struct SimulatedAnnealingOptimizer: SimulatedAnnealingParameters {
95
-
96
- PermutationObjective *obj;
91
+ struct SimulatedAnnealingOptimizer : SimulatedAnnealingParameters {
92
+ PermutationObjective* obj;
97
93
  int n; ///< size of the permutation
98
- FILE *logfile; /// logs values of the cost function
94
+ FILE* logfile; /// logs values of the cost function
99
95
 
100
- SimulatedAnnealingOptimizer (PermutationObjective *obj,
101
- const SimulatedAnnealingParameters &p);
102
- RandomGenerator *rnd;
96
+ SimulatedAnnealingOptimizer(
97
+ PermutationObjective* obj,
98
+ const SimulatedAnnealingParameters& p);
99
+ RandomGenerator* rnd;
103
100
 
104
101
  /// remember initial cost of optimization
105
102
  double init_cost;
106
103
 
107
104
  // main entry point. Perform the optimization loop, starting from
108
105
  // and modifying permutation in-place
109
- double optimize (int *perm);
106
+ double optimize(int* perm);
110
107
 
111
108
  // run the optimization and return the best result in best_perm
112
- double run_optimization (int * best_perm);
109
+ double run_optimization(int* best_perm);
113
110
 
114
- virtual ~SimulatedAnnealingOptimizer ();
111
+ virtual ~SimulatedAnnealingOptimizer();
115
112
  };
116
113
 
117
-
118
-
119
-
120
114
  /// optimizes the order of indices in a ProductQuantizer
121
- struct PolysemousTraining: SimulatedAnnealingParameters {
122
-
115
+ struct PolysemousTraining : SimulatedAnnealingParameters {
123
116
  enum Optimization_type_t {
124
117
  OT_None,
125
- OT_ReproduceDistances_affine, ///< default
126
- OT_Ranking_weighted_diff ///< same as _2, but use rank of y+ - rank of y-
118
+ OT_ReproduceDistances_affine, ///< default
119
+ OT_Ranking_weighted_diff ///< same as _2, but use rank of y+ - rank of
120
+ ///< y-
127
121
  };
128
122
  Optimization_type_t optimization_type;
129
123
 
@@ -133,26 +127,29 @@ struct PolysemousTraining: SimulatedAnnealingParameters {
133
127
  int ntrain_permutation;
134
128
  double dis_weight_factor; ///< decay of exp that weights distance loss
135
129
 
130
+ /// refuse to train if it would require more than that amount of RAM
131
+ size_t max_memory;
132
+
136
133
  // filename pattern for the logging of iterations
137
134
  std::string log_pattern;
138
135
 
139
136
  // sets default values
140
- PolysemousTraining ();
137
+ PolysemousTraining();
141
138
 
142
139
  /// reorder the centroids so that the Hamming distance becomes a
143
140
  /// good approximation of the SDC distance (called by train)
144
- void optimize_pq_for_hamming (ProductQuantizer & pq,
145
- size_t n, const float *x) const;
141
+ void optimize_pq_for_hamming(ProductQuantizer& pq, size_t n, const float* x)
142
+ const;
146
143
 
147
144
  /// called by optimize_pq_for_hamming
148
- void optimize_ranking (ProductQuantizer &pq, size_t n, const float *x) const;
145
+ void optimize_ranking(ProductQuantizer& pq, size_t n, const float* x) const;
149
146
  /// called by optimize_pq_for_hamming
150
- void optimize_reproduce_distances (ProductQuantizer &pq) const;
147
+ void optimize_reproduce_distances(ProductQuantizer& pq) const;
151
148
 
149
+ /// make sure we don't blow up the memory
150
+ size_t memory_usage_per_thread(const ProductQuantizer& pq) const;
152
151
  };
153
152
 
154
-
155
153
  } // namespace faiss
156
154
 
157
-
158
155
  #endif
@@ -7,20 +7,18 @@
7
7
 
8
8
  namespace faiss {
9
9
 
10
- inline
11
- PQEncoderGeneric::PQEncoderGeneric(uint8_t *code, int nbits,
12
- uint8_t offset)
13
- : code(code), offset(offset), nbits(nbits), reg(0)
14
- {
10
+ inline PQEncoderGeneric::PQEncoderGeneric(
11
+ uint8_t* code,
12
+ int nbits,
13
+ uint8_t offset)
14
+ : code(code), offset(offset), nbits(nbits), reg(0) {
15
15
  assert(nbits <= 64);
16
16
  if (offset > 0) {
17
17
  reg = (*code & ((1 << offset) - 1));
18
18
  }
19
19
  }
20
20
 
21
- inline
22
- void PQEncoderGeneric::encode(uint64_t x)
23
- {
21
+ inline void PQEncoderGeneric::encode(uint64_t x) {
24
22
  reg |= (uint8_t)(x << offset);
25
23
  x >>= (8 - offset);
26
24
  if (offset + nbits >= 8) {
@@ -39,51 +37,39 @@ void PQEncoderGeneric::encode(uint64_t x)
39
37
  }
40
38
  }
41
39
 
42
- inline
43
- PQEncoderGeneric::~PQEncoderGeneric()
44
- {
40
+ inline PQEncoderGeneric::~PQEncoderGeneric() {
45
41
  if (offset > 0) {
46
42
  *code = reg;
47
43
  }
48
44
  }
49
45
 
50
-
51
- inline
52
- PQEncoder8::PQEncoder8(uint8_t *code, int nbits)
53
- : code(code) {
46
+ inline PQEncoder8::PQEncoder8(uint8_t* code, int nbits) : code(code) {
54
47
  assert(8 == nbits);
55
48
  }
56
49
 
57
- inline
58
- void PQEncoder8::encode(uint64_t x) {
50
+ inline void PQEncoder8::encode(uint64_t x) {
59
51
  *code++ = (uint8_t)x;
60
52
  }
61
53
 
62
- inline
63
- PQEncoder16::PQEncoder16(uint8_t *code, int nbits)
64
- : code((uint16_t *)code) {
54
+ inline PQEncoder16::PQEncoder16(uint8_t* code, int nbits)
55
+ : code((uint16_t*)code) {
65
56
  assert(16 == nbits);
66
57
  }
67
58
 
68
- inline
69
- void PQEncoder16::encode(uint64_t x) {
59
+ inline void PQEncoder16::encode(uint64_t x) {
70
60
  *code++ = (uint16_t)x;
71
61
  }
72
62
 
73
-
74
- inline
75
- PQDecoderGeneric::PQDecoderGeneric(const uint8_t *code,
76
- int nbits)
77
- : code(code),
78
- offset(0),
79
- nbits(nbits),
80
- mask((1ull << nbits) - 1),
81
- reg(0) {
63
+ inline PQDecoderGeneric::PQDecoderGeneric(const uint8_t* code, int nbits)
64
+ : code(code),
65
+ offset(0),
66
+ nbits(nbits),
67
+ mask((1ull << nbits) - 1),
68
+ reg(0) {
82
69
  assert(nbits <= 64);
83
70
  }
84
71
 
85
- inline
86
- uint64_t PQDecoderGeneric::decode() {
72
+ inline uint64_t PQDecoderGeneric::decode() {
87
73
  if (offset == 0) {
88
74
  reg = *code;
89
75
  }
@@ -110,27 +96,20 @@ uint64_t PQDecoderGeneric::decode() {
110
96
  return c & mask;
111
97
  }
112
98
 
113
-
114
- inline
115
- PQDecoder8::PQDecoder8(const uint8_t *code, int nbits)
116
- : code(code) {
117
- assert(8 == nbits);
99
+ inline PQDecoder8::PQDecoder8(const uint8_t* code, int nbits_in) : code(code) {
100
+ assert(8 == nbits_in);
118
101
  }
119
102
 
120
- inline
121
- uint64_t PQDecoder8::decode() {
103
+ inline uint64_t PQDecoder8::decode() {
122
104
  return (uint64_t)(*code++);
123
105
  }
124
106
 
125
-
126
- inline
127
- PQDecoder16::PQDecoder16(const uint8_t *code, int nbits)
128
- : code((uint16_t *)code) {
129
- assert(16 == nbits);
107
+ inline PQDecoder16::PQDecoder16(const uint8_t* code, int nbits_in)
108
+ : code((uint16_t*)code) {
109
+ assert(16 == nbits_in);
130
110
  }
131
111
 
132
- inline
133
- uint64_t PQDecoder16::decode() {
112
+ inline uint64_t PQDecoder16::decode() {
134
113
  return (uint64_t)(*code++);
135
114
  }
136
115
 
@@ -9,113 +9,118 @@
9
9
 
10
10
  #include <faiss/impl/ProductQuantizer.h>
11
11
 
12
-
13
12
  #include <cstddef>
14
- #include <cstring>
15
13
  #include <cstdio>
14
+ #include <cstring>
16
15
  #include <memory>
17
16
 
18
17
  #include <algorithm>
19
18
 
20
- #include <faiss/impl/FaissAssert.h>
21
- #include <faiss/VectorTransform.h>
22
19
  #include <faiss/IndexFlat.h>
20
+ #include <faiss/VectorTransform.h>
21
+ #include <faiss/impl/FaissAssert.h>
23
22
  #include <faiss/utils/distances.h>
24
23
 
25
-
26
24
  extern "C" {
27
25
 
28
26
  /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
29
27
 
30
- int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
31
- n, FINTEGER *k, const float *alpha, const float *a,
32
- FINTEGER *lda, const float *b, FINTEGER *
33
- ldb, float *beta, float *c, FINTEGER *ldc);
34
-
28
+ int sgemm_(
29
+ const char* transa,
30
+ const char* transb,
31
+ FINTEGER* m,
32
+ FINTEGER* n,
33
+ FINTEGER* k,
34
+ const float* alpha,
35
+ const float* a,
36
+ FINTEGER* lda,
37
+ const float* b,
38
+ FINTEGER* ldb,
39
+ float* beta,
40
+ float* c,
41
+ FINTEGER* ldc);
35
42
  }
36
43
 
37
-
38
44
  namespace faiss {
39
45
 
40
-
41
46
  /* compute an estimator using look-up tables for typical values of M */
42
47
  template <typename CT, class C>
43
- void pq_estimators_from_tables_Mmul4 (int M, const CT * codes,
44
- size_t ncodes,
45
- const float * __restrict dis_table,
46
- size_t ksub,
47
- size_t k,
48
- float * heap_dis,
49
- int64_t * heap_ids)
50
- {
51
-
48
+ void pq_estimators_from_tables_Mmul4(
49
+ int M,
50
+ const CT* codes,
51
+ size_t ncodes,
52
+ const float* __restrict dis_table,
53
+ size_t ksub,
54
+ size_t k,
55
+ float* heap_dis,
56
+ int64_t* heap_ids) {
52
57
  for (size_t j = 0; j < ncodes; j++) {
53
58
  float dis = 0;
54
- const float *dt = dis_table;
59
+ const float* dt = dis_table;
55
60
 
56
- for (size_t m = 0; m < M; m+=4) {
61
+ for (size_t m = 0; m < M; m += 4) {
57
62
  float dism = 0;
58
- dism = dt[*codes++]; dt += ksub;
59
- dism += dt[*codes++]; dt += ksub;
60
- dism += dt[*codes++]; dt += ksub;
61
- dism += dt[*codes++]; dt += ksub;
63
+ dism = dt[*codes++];
64
+ dt += ksub;
65
+ dism += dt[*codes++];
66
+ dt += ksub;
67
+ dism += dt[*codes++];
68
+ dt += ksub;
69
+ dism += dt[*codes++];
70
+ dt += ksub;
62
71
  dis += dism;
63
72
  }
64
73
 
65
- if (C::cmp (heap_dis[0], dis)) {
66
- heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
74
+ if (C::cmp(heap_dis[0], dis)) {
75
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
67
76
  }
68
77
  }
69
78
  }
70
79
 
71
-
72
80
  template <typename CT, class C>
73
- void pq_estimators_from_tables_M4 (const CT * codes,
74
- size_t ncodes,
75
- const float * __restrict dis_table,
76
- size_t ksub,
77
- size_t k,
78
- float * heap_dis,
79
- int64_t * heap_ids)
80
- {
81
-
81
+ void pq_estimators_from_tables_M4(
82
+ const CT* codes,
83
+ size_t ncodes,
84
+ const float* __restrict dis_table,
85
+ size_t ksub,
86
+ size_t k,
87
+ float* heap_dis,
88
+ int64_t* heap_ids) {
82
89
  for (size_t j = 0; j < ncodes; j++) {
83
90
  float dis = 0;
84
- const float *dt = dis_table;
85
- dis = dt[*codes++]; dt += ksub;
86
- dis += dt[*codes++]; dt += ksub;
87
- dis += dt[*codes++]; dt += ksub;
91
+ const float* dt = dis_table;
92
+ dis = dt[*codes++];
93
+ dt += ksub;
94
+ dis += dt[*codes++];
95
+ dt += ksub;
96
+ dis += dt[*codes++];
97
+ dt += ksub;
88
98
  dis += dt[*codes++];
89
99
 
90
- if (C::cmp (heap_dis[0], dis)) {
91
- heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
100
+ if (C::cmp(heap_dis[0], dis)) {
101
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
92
102
  }
93
103
  }
94
104
  }
95
105
 
96
-
97
106
  template <typename CT, class C>
98
- static inline void pq_estimators_from_tables (const ProductQuantizer& pq,
99
- const CT * codes,
100
- size_t ncodes,
101
- const float * dis_table,
102
- size_t k,
103
- float * heap_dis,
104
- int64_t * heap_ids)
105
- {
106
-
107
- if (pq.M == 4) {
108
-
109
- pq_estimators_from_tables_M4<CT, C> (codes, ncodes,
110
- dis_table, pq.ksub, k,
111
- heap_dis, heap_ids);
107
+ static inline void pq_estimators_from_tables(
108
+ const ProductQuantizer& pq,
109
+ const CT* codes,
110
+ size_t ncodes,
111
+ const float* dis_table,
112
+ size_t k,
113
+ float* heap_dis,
114
+ int64_t* heap_ids) {
115
+ if (pq.M == 4) {
116
+ pq_estimators_from_tables_M4<CT, C>(
117
+ codes, ncodes, dis_table, pq.ksub, k, heap_dis, heap_ids);
112
118
  return;
113
119
  }
114
120
 
115
121
  if (pq.M % 4 == 0) {
116
- pq_estimators_from_tables_Mmul4<CT, C> (pq.M, codes, ncodes,
117
- dis_table, pq.ksub, k,
118
- heap_dis, heap_ids);
122
+ pq_estimators_from_tables_Mmul4<CT, C>(
123
+ pq.M, codes, ncodes, dis_table, pq.ksub, k, heap_dis, heap_ids);
119
124
  return;
120
125
  }
121
126
 
@@ -124,132 +129,124 @@ static inline void pq_estimators_from_tables (const ProductQuantizer& pq,
124
129
  const size_t ksub = pq.ksub;
125
130
  for (size_t j = 0; j < ncodes; j++) {
126
131
  float dis = 0;
127
- const float * __restrict dt = dis_table;
132
+ const float* __restrict dt = dis_table;
128
133
  for (int m = 0; m < M; m++) {
129
134
  dis += dt[*codes++];
130
135
  dt += ksub;
131
136
  }
132
- if (C::cmp (heap_dis[0], dis)) {
133
- heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
137
+ if (C::cmp(heap_dis[0], dis)) {
138
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
134
139
  }
135
140
  }
136
141
  }
137
142
 
138
143
  template <class C>
139
- static inline void pq_estimators_from_tables_generic(const ProductQuantizer& pq,
140
- size_t nbits,
141
- const uint8_t *codes,
142
- size_t ncodes,
143
- const float *dis_table,
144
- size_t k,
145
- float *heap_dis,
146
- int64_t *heap_ids)
147
- {
148
- const size_t M = pq.M;
149
- const size_t ksub = pq.ksub;
150
- for (size_t j = 0; j < ncodes; ++j) {
151
- PQDecoderGeneric decoder(
152
- codes + j * pq.code_size, nbits
153
- );
154
- float dis = 0;
155
- const float * __restrict dt = dis_table;
156
- for (size_t m = 0; m < M; m++) {
157
- uint64_t c = decoder.decode();
158
- dis += dt[c];
159
- dt += ksub;
160
- }
144
+ static inline void pq_estimators_from_tables_generic(
145
+ const ProductQuantizer& pq,
146
+ size_t nbits,
147
+ const uint8_t* codes,
148
+ size_t ncodes,
149
+ const float* dis_table,
150
+ size_t k,
151
+ float* heap_dis,
152
+ int64_t* heap_ids) {
153
+ const size_t M = pq.M;
154
+ const size_t ksub = pq.ksub;
155
+ for (size_t j = 0; j < ncodes; ++j) {
156
+ PQDecoderGeneric decoder(codes + j * pq.code_size, nbits);
157
+ float dis = 0;
158
+ const float* __restrict dt = dis_table;
159
+ for (size_t m = 0; m < M; m++) {
160
+ uint64_t c = decoder.decode();
161
+ dis += dt[c];
162
+ dt += ksub;
163
+ }
161
164
 
162
- if (C::cmp(heap_dis[0], dis)) {
163
- heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
165
+ if (C::cmp(heap_dis[0], dis)) {
166
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
167
+ }
164
168
  }
165
- }
166
169
  }
167
170
 
168
171
  /*********************************************
169
172
  * PQ implementation
170
173
  *********************************************/
171
174
 
172
-
173
-
174
- ProductQuantizer::ProductQuantizer (size_t d, size_t M, size_t nbits):
175
- d(d), M(M), nbits(nbits), assign_index(nullptr)
176
- {
177
- set_derived_values ();
175
+ ProductQuantizer::ProductQuantizer(size_t d, size_t M, size_t nbits)
176
+ : d(d), M(M), nbits(nbits), assign_index(nullptr) {
177
+ set_derived_values();
178
178
  }
179
179
 
180
- ProductQuantizer::ProductQuantizer ()
181
- : ProductQuantizer(0, 1, 0) {}
180
+ ProductQuantizer::ProductQuantizer() : ProductQuantizer(0, 1, 0) {}
182
181
 
183
- void ProductQuantizer::set_derived_values () {
182
+ void ProductQuantizer::set_derived_values() {
184
183
  // quite a few derived values
185
- FAISS_THROW_IF_NOT_MSG (d % M == 0, "The dimension of the vector (d) should be a multiple of the number of subquantizers (M)");
184
+ FAISS_THROW_IF_NOT_MSG(
185
+ d % M == 0,
186
+ "The dimension of the vector (d) should be a multiple of the number of subquantizers (M)");
186
187
  dsub = d / M;
187
188
  code_size = (nbits * M + 7) / 8;
188
189
  ksub = 1 << nbits;
189
- centroids.resize (d * ksub);
190
+ centroids.resize(d * ksub);
190
191
  verbose = false;
191
192
  train_type = Train_default;
192
193
  }
193
194
 
194
- void ProductQuantizer::set_params (const float * centroids_, int m)
195
- {
196
- memcpy (get_centroids(m, 0), centroids_,
197
- ksub * dsub * sizeof (centroids_[0]));
195
+ void ProductQuantizer::set_params(const float* centroids_, int m) {
196
+ memcpy(get_centroids(m, 0),
197
+ centroids_,
198
+ ksub * dsub * sizeof(centroids_[0]));
198
199
  }
199
200
 
200
-
201
- static void init_hypercube (int d, int nbits,
202
- int n, const float * x,
203
- float *centroids)
204
- {
205
-
206
- std::vector<float> mean (d);
201
+ static void init_hypercube(
202
+ int d,
203
+ int nbits,
204
+ int n,
205
+ const float* x,
206
+ float* centroids) {
207
+ std::vector<float> mean(d);
207
208
  for (int i = 0; i < n; i++)
208
209
  for (int j = 0; j < d; j++)
209
- mean [j] += x[i * d + j];
210
+ mean[j] += x[i * d + j];
210
211
 
211
212
  float maxm = 0;
212
213
  for (int j = 0; j < d; j++) {
213
- mean [j] /= n;
214
- if (fabs(mean[j]) > maxm) maxm = fabs(mean[j]);
214
+ mean[j] /= n;
215
+ if (fabs(mean[j]) > maxm)
216
+ maxm = fabs(mean[j]);
215
217
  }
216
218
 
217
219
  for (int i = 0; i < (1 << nbits); i++) {
218
- float * cent = centroids + i * d;
220
+ float* cent = centroids + i * d;
219
221
  for (int j = 0; j < nbits; j++)
220
- cent[j] = mean [j] + (((i >> j) & 1) ? 1 : -1) * maxm;
222
+ cent[j] = mean[j] + (((i >> j) & 1) ? 1 : -1) * maxm;
221
223
  for (int j = nbits; j < d; j++)
222
- cent[j] = mean [j];
224
+ cent[j] = mean[j];
223
225
  }
224
-
225
-
226
226
  }
227
227
 
228
- static void init_hypercube_pca (int d, int nbits,
229
- int n, const float * x,
230
- float *centroids)
231
- {
232
- PCAMatrix pca (d, nbits);
233
- pca.train (n, x);
234
-
228
+ static void init_hypercube_pca(
229
+ int d,
230
+ int nbits,
231
+ int n,
232
+ const float* x,
233
+ float* centroids) {
234
+ PCAMatrix pca(d, nbits);
235
+ pca.train(n, x);
235
236
 
236
237
  for (int i = 0; i < (1 << nbits); i++) {
237
- float * cent = centroids + i * d;
238
+ float* cent = centroids + i * d;
238
239
  for (int j = 0; j < d; j++) {
239
240
  cent[j] = pca.mean[j];
240
241
  float f = 1.0;
241
242
  for (int k = 0; k < nbits; k++)
242
- cent[j] += f *
243
- sqrt (pca.eigenvalues [k]) *
244
- (((i >> k) & 1) ? 1 : -1) *
245
- pca.PCAMat [j + k * d];
243
+ cent[j] += f * sqrt(pca.eigenvalues[k]) *
244
+ (((i >> k) & 1) ? 1 : -1) * pca.PCAMat[j + k * d];
246
245
  }
247
246
  }
248
-
249
247
  }
250
248
 
251
- void ProductQuantizer::train (int n, const float * x)
252
- {
249
+ void ProductQuantizer::train(int n, const float* x) {
253
250
  if (train_type != Train_shared) {
254
251
  train_type_t final_train_type;
255
252
  final_train_type = train_type;
@@ -257,234 +254,229 @@ void ProductQuantizer::train (int n, const float * x)
257
254
  train_type == Train_hypercube_pca) {
258
255
  if (dsub < nbits) {
259
256
  final_train_type = Train_default;
260
- printf ("cannot train hypercube: nbits=%zd > log2(d=%zd)\n",
261
- nbits, dsub);
257
+ printf("cannot train hypercube: nbits=%zd > log2(d=%zd)\n",
258
+ nbits,
259
+ dsub);
262
260
  }
263
261
  }
264
262
 
265
- float * xslice = new float[n * dsub];
266
- ScopeDeleter<float> del (xslice);
263
+ float* xslice = new float[n * dsub];
264
+ ScopeDeleter<float> del(xslice);
267
265
  for (int m = 0; m < M; m++) {
268
266
  for (int j = 0; j < n; j++)
269
- memcpy (xslice + j * dsub,
270
- x + j * d + m * dsub,
271
- dsub * sizeof(float));
267
+ memcpy(xslice + j * dsub,
268
+ x + j * d + m * dsub,
269
+ dsub * sizeof(float));
272
270
 
273
- Clustering clus (dsub, ksub, cp);
271
+ Clustering clus(dsub, ksub, cp);
274
272
 
275
273
  // we have some initialization for the centroids
276
274
  if (final_train_type != Train_default) {
277
- clus.centroids.resize (dsub * ksub);
275
+ clus.centroids.resize(dsub * ksub);
278
276
  }
279
277
 
280
278
  switch (final_train_type) {
281
- case Train_hypercube:
282
- init_hypercube (dsub, nbits, n, xslice,
283
- clus.centroids.data ());
284
- break;
285
- case Train_hypercube_pca:
286
- init_hypercube_pca (dsub, nbits, n, xslice,
287
- clus.centroids.data ());
288
- break;
289
- case Train_hot_start:
290
- memcpy (clus.centroids.data(),
291
- get_centroids (m, 0),
292
- dsub * ksub * sizeof (float));
293
- break;
294
- default: ;
279
+ case Train_hypercube:
280
+ init_hypercube(
281
+ dsub, nbits, n, xslice, clus.centroids.data());
282
+ break;
283
+ case Train_hypercube_pca:
284
+ init_hypercube_pca(
285
+ dsub, nbits, n, xslice, clus.centroids.data());
286
+ break;
287
+ case Train_hot_start:
288
+ memcpy(clus.centroids.data(),
289
+ get_centroids(m, 0),
290
+ dsub * ksub * sizeof(float));
291
+ break;
292
+ default:;
295
293
  }
296
294
 
297
- if(verbose) {
295
+ if (verbose) {
298
296
  clus.verbose = true;
299
- printf ("Training PQ slice %d/%zd\n", m, M);
297
+ printf("Training PQ slice %d/%zd\n", m, M);
300
298
  }
301
- IndexFlatL2 index (dsub);
302
- clus.train (n, xslice, assign_index ? *assign_index : index);
303
- set_params (clus.centroids.data(), m);
299
+ IndexFlatL2 index(dsub);
300
+ clus.train(n, xslice, assign_index ? *assign_index : index);
301
+ set_params(clus.centroids.data(), m);
304
302
  }
305
303
 
306
-
307
304
  } else {
305
+ Clustering clus(dsub, ksub, cp);
308
306
 
309
- Clustering clus (dsub, ksub, cp);
310
-
311
- if(verbose) {
307
+ if (verbose) {
312
308
  clus.verbose = true;
313
- printf ("Training all PQ slices at once\n");
309
+ printf("Training all PQ slices at once\n");
314
310
  }
315
311
 
316
- IndexFlatL2 index (dsub);
312
+ IndexFlatL2 index(dsub);
317
313
 
318
- clus.train (n * M, x, assign_index ? *assign_index : index);
314
+ clus.train(n * M, x, assign_index ? *assign_index : index);
319
315
  for (int m = 0; m < M; m++) {
320
- set_params (clus.centroids.data(), m);
316
+ set_params(clus.centroids.data(), m);
321
317
  }
322
-
323
318
  }
324
319
  }
325
320
 
326
- template<class PQEncoder>
327
- void compute_code(const ProductQuantizer& pq, const float *x, uint8_t *code) {
328
- std::vector<float> distances(pq.ksub);
329
- PQEncoder encoder(code, pq.nbits);
330
- for (size_t m = 0; m < pq.M; m++) {
331
- float mindis = 1e20;
332
- uint64_t idxm = 0;
333
- const float * xsub = x + m * pq.dsub;
334
-
335
- fvec_L2sqr_ny(distances.data(), xsub, pq.get_centroids(m, 0), pq.dsub, pq.ksub);
336
-
337
- /* Find best centroid */
338
- for (size_t i = 0; i < pq.ksub; i++) {
339
- float dis = distances[i];
340
- if (dis < mindis) {
341
- mindis = dis;
342
- idxm = i;
343
- }
344
- }
321
+ template <class PQEncoder>
322
+ void compute_code(const ProductQuantizer& pq, const float* x, uint8_t* code) {
323
+ std::vector<float> distances(pq.ksub);
324
+ PQEncoder encoder(code, pq.nbits);
325
+ for (size_t m = 0; m < pq.M; m++) {
326
+ float mindis = 1e20;
327
+ uint64_t idxm = 0;
328
+ const float* xsub = x + m * pq.dsub;
329
+
330
+ fvec_L2sqr_ny(
331
+ distances.data(),
332
+ xsub,
333
+ pq.get_centroids(m, 0),
334
+ pq.dsub,
335
+ pq.ksub);
336
+
337
+ /* Find best centroid */
338
+ for (size_t i = 0; i < pq.ksub; i++) {
339
+ float dis = distances[i];
340
+ if (dis < mindis) {
341
+ mindis = dis;
342
+ idxm = i;
343
+ }
344
+ }
345
345
 
346
- encoder.encode(idxm);
347
- }
346
+ encoder.encode(idxm);
347
+ }
348
348
  }
349
349
 
350
- void ProductQuantizer::compute_code(const float * x, uint8_t * code) const {
351
- switch (nbits) {
352
- case 8:
353
- faiss::compute_code<PQEncoder8>(*this, x, code);
354
- break;
350
+ void ProductQuantizer::compute_code(const float* x, uint8_t* code) const {
351
+ switch (nbits) {
352
+ case 8:
353
+ faiss::compute_code<PQEncoder8>(*this, x, code);
354
+ break;
355
355
 
356
- case 16:
357
- faiss::compute_code<PQEncoder16>(*this, x, code);
358
- break;
356
+ case 16:
357
+ faiss::compute_code<PQEncoder16>(*this, x, code);
358
+ break;
359
359
 
360
- default:
361
- faiss::compute_code<PQEncoderGeneric>(*this, x, code);
362
- break;
363
- }
360
+ default:
361
+ faiss::compute_code<PQEncoderGeneric>(*this, x, code);
362
+ break;
363
+ }
364
364
  }
365
365
 
366
- template<class PQDecoder>
367
- void decode(const ProductQuantizer& pq, const uint8_t *code, float *x)
368
- {
369
- PQDecoder decoder(code, pq.nbits);
370
- for (size_t m = 0; m < pq.M; m++) {
371
- uint64_t c = decoder.decode();
372
- memcpy(x + m * pq.dsub, pq.get_centroids(m, c), sizeof(float) * pq.dsub);
373
- }
366
+ template <class PQDecoder>
367
+ void decode(const ProductQuantizer& pq, const uint8_t* code, float* x) {
368
+ PQDecoder decoder(code, pq.nbits);
369
+ for (size_t m = 0; m < pq.M; m++) {
370
+ uint64_t c = decoder.decode();
371
+ memcpy(x + m * pq.dsub,
372
+ pq.get_centroids(m, c),
373
+ sizeof(float) * pq.dsub);
374
+ }
374
375
  }
375
376
 
376
- void ProductQuantizer::decode (const uint8_t *code, float *x) const
377
- {
378
- switch (nbits) {
379
- case 8:
380
- faiss::decode<PQDecoder8>(*this, code, x);
381
- break;
382
-
383
- case 16:
384
- faiss::decode<PQDecoder16>(*this, code, x);
385
- break;
386
-
387
- default:
388
- faiss::decode<PQDecoderGeneric>(*this, code, x);
389
- break;
390
- }
391
- }
377
+ void ProductQuantizer::decode(const uint8_t* code, float* x) const {
378
+ switch (nbits) {
379
+ case 8:
380
+ faiss::decode<PQDecoder8>(*this, code, x);
381
+ break;
382
+
383
+ case 16:
384
+ faiss::decode<PQDecoder16>(*this, code, x);
385
+ break;
392
386
 
387
+ default:
388
+ faiss::decode<PQDecoderGeneric>(*this, code, x);
389
+ break;
390
+ }
391
+ }
393
392
 
394
- void ProductQuantizer::decode (const uint8_t *code, float *x, size_t n) const
395
- {
393
+ void ProductQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
396
394
  for (size_t i = 0; i < n; i++) {
397
- this->decode (code + code_size * i, x + d * i);
395
+ this->decode(code + code_size * i, x + d * i);
398
396
  }
399
397
  }
400
398
 
399
+ void ProductQuantizer::compute_code_from_distance_table(
400
+ const float* tab,
401
+ uint8_t* code) const {
402
+ PQEncoderGeneric encoder(code, nbits);
403
+ for (size_t m = 0; m < M; m++) {
404
+ float mindis = 1e20;
405
+ uint64_t idxm = 0;
406
+
407
+ /* Find best centroid */
408
+ for (size_t j = 0; j < ksub; j++) {
409
+ float dis = *tab++;
410
+ if (dis < mindis) {
411
+ mindis = dis;
412
+ idxm = j;
413
+ }
414
+ }
401
415
 
402
- void ProductQuantizer::compute_code_from_distance_table (const float *tab,
403
- uint8_t *code) const
404
- {
405
- PQEncoderGeneric encoder(code, nbits);
406
- for (size_t m = 0; m < M; m++) {
407
- float mindis = 1e20;
408
- uint64_t idxm = 0;
409
-
410
- /* Find best centroid */
411
- for (size_t j = 0; j < ksub; j++) {
412
- float dis = *tab++;
413
- if (dis < mindis) {
414
- mindis = dis;
415
- idxm = j;
416
- }
416
+ encoder.encode(idxm);
417
417
  }
418
-
419
- encoder.encode(idxm);
420
- }
421
418
  }
422
419
 
423
- void ProductQuantizer::compute_codes_with_assign_index (
424
- const float * x,
425
- uint8_t * codes,
426
- size_t n)
427
- {
428
- FAISS_THROW_IF_NOT (assign_index && assign_index->d == dsub);
420
+ void ProductQuantizer::compute_codes_with_assign_index(
421
+ const float* x,
422
+ uint8_t* codes,
423
+ size_t n) {
424
+ FAISS_THROW_IF_NOT(assign_index && assign_index->d == dsub);
429
425
 
430
426
  for (size_t m = 0; m < M; m++) {
431
- assign_index->reset ();
432
- assign_index->add (ksub, get_centroids (m, 0));
427
+ assign_index->reset();
428
+ assign_index->add(ksub, get_centroids(m, 0));
433
429
  size_t bs = 65536;
434
- float * xslice = new float[bs * dsub];
435
- ScopeDeleter<float> del (xslice);
436
- idx_t *assign = new idx_t[bs];
437
- ScopeDeleter<idx_t> del2 (assign);
430
+ float* xslice = new float[bs * dsub];
431
+ ScopeDeleter<float> del(xslice);
432
+ idx_t* assign = new idx_t[bs];
433
+ ScopeDeleter<idx_t> del2(assign);
438
434
 
439
435
  for (size_t i0 = 0; i0 < n; i0 += bs) {
440
436
  size_t i1 = std::min(i0 + bs, n);
441
437
 
442
438
  for (size_t i = i0; i < i1; i++) {
443
- memcpy (xslice + (i - i0) * dsub,
444
- x + i * d + m * dsub,
445
- dsub * sizeof(float));
439
+ memcpy(xslice + (i - i0) * dsub,
440
+ x + i * d + m * dsub,
441
+ dsub * sizeof(float));
446
442
  }
447
443
 
448
- assign_index->assign (i1 - i0, xslice, assign);
444
+ assign_index->assign(i1 - i0, xslice, assign);
449
445
 
450
446
  if (nbits == 8) {
451
- uint8_t *c = codes + code_size * i0 + m;
452
- for (size_t i = i0; i < i1; i++) {
453
- *c = assign[i - i0];
454
- c += M;
455
- }
447
+ uint8_t* c = codes + code_size * i0 + m;
448
+ for (size_t i = i0; i < i1; i++) {
449
+ *c = assign[i - i0];
450
+ c += M;
451
+ }
456
452
  } else if (nbits == 16) {
457
- uint16_t *c = (uint16_t*)(codes + code_size * i0 + m * 2);
458
- for (size_t i = i0; i < i1; i++) {
459
- *c = assign[i - i0];
460
- c += M;
461
- }
453
+ uint16_t* c = (uint16_t*)(codes + code_size * i0 + m * 2);
454
+ for (size_t i = i0; i < i1; i++) {
455
+ *c = assign[i - i0];
456
+ c += M;
457
+ }
462
458
  } else {
463
- for (size_t i = i0; i < i1; ++i) {
464
- uint8_t *c = codes + code_size * i + ((m * nbits) / 8);
465
- uint8_t offset = (m * nbits) % 8;
466
- uint64_t ass = assign[i - i0];
467
-
468
- PQEncoderGeneric encoder(c, nbits, offset);
469
- encoder.encode(ass);
470
- }
459
+ for (size_t i = i0; i < i1; ++i) {
460
+ uint8_t* c = codes + code_size * i + ((m * nbits) / 8);
461
+ uint8_t offset = (m * nbits) % 8;
462
+ uint64_t ass = assign[i - i0];
463
+
464
+ PQEncoderGeneric encoder(c, nbits, offset);
465
+ encoder.encode(ass);
466
+ }
471
467
  }
472
-
473
468
  }
474
469
  }
475
-
476
470
  }
477
471
 
478
- void ProductQuantizer::compute_codes (const float * x,
479
- uint8_t * codes,
480
- size_t n) const
481
- {
482
- // process by blocks to avoid using too much RAM
472
+ void ProductQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
473
+ const {
474
+ // process by blocks to avoid using too much RAM
483
475
  size_t bs = 256 * 1024;
484
476
  if (n > bs) {
485
477
  for (size_t i0 = 0; i0 < n; i0 += bs) {
486
478
  size_t i1 = std::min(i0 + bs, n);
487
- compute_codes (x + d * i0, codes + code_size * i0, i1 - i0);
479
+ compute_codes(x + d * i0, codes + code_size * i0, i1 - i0);
488
480
  }
489
481
  return;
490
482
  }
@@ -493,282 +485,300 @@ void ProductQuantizer::compute_codes (const float * x,
493
485
 
494
486
  #pragma omp parallel for
495
487
  for (int64_t i = 0; i < n; i++)
496
- compute_code (x + i * d, codes + i * code_size);
488
+ compute_code(x + i * d, codes + i * code_size);
497
489
 
498
490
  } else { // worthwile to use BLAS
499
- float *dis_tables = new float [n * ksub * M];
500
- ScopeDeleter<float> del (dis_tables);
501
- compute_distance_tables (n, x, dis_tables);
491
+ float* dis_tables = new float[n * ksub * M];
492
+ ScopeDeleter<float> del(dis_tables);
493
+ compute_distance_tables(n, x, dis_tables);
502
494
 
503
495
  #pragma omp parallel for
504
496
  for (int64_t i = 0; i < n; i++) {
505
- uint8_t * code = codes + i * code_size;
506
- const float * tab = dis_tables + i * ksub * M;
507
- compute_code_from_distance_table (tab, code);
497
+ uint8_t* code = codes + i * code_size;
498
+ const float* tab = dis_tables + i * ksub * M;
499
+ compute_code_from_distance_table(tab, code);
508
500
  }
509
501
  }
510
502
  }
511
503
 
512
-
513
- void ProductQuantizer::compute_distance_table (const float * x,
514
- float * dis_table) const
515
- {
504
+ void ProductQuantizer::compute_distance_table(const float* x, float* dis_table)
505
+ const {
516
506
  size_t m;
517
507
 
518
508
  for (m = 0; m < M; m++) {
519
- fvec_L2sqr_ny (dis_table + m * ksub,
520
- x + m * dsub,
521
- get_centroids(m, 0),
522
- dsub,
523
- ksub);
509
+ fvec_L2sqr_ny(
510
+ dis_table + m * ksub,
511
+ x + m * dsub,
512
+ get_centroids(m, 0),
513
+ dsub,
514
+ ksub);
524
515
  }
525
516
  }
526
517
 
527
- void ProductQuantizer::compute_inner_prod_table (const float * x,
528
- float * dis_table) const
529
- {
518
+ void ProductQuantizer::compute_inner_prod_table(
519
+ const float* x,
520
+ float* dis_table) const {
530
521
  size_t m;
531
522
 
532
523
  for (m = 0; m < M; m++) {
533
- fvec_inner_products_ny (dis_table + m * ksub,
534
- x + m * dsub,
535
- get_centroids(m, 0),
536
- dsub,
537
- ksub);
524
+ fvec_inner_products_ny(
525
+ dis_table + m * ksub,
526
+ x + m * dsub,
527
+ get_centroids(m, 0),
528
+ dsub,
529
+ ksub);
538
530
  }
539
531
  }
540
532
 
541
-
542
- void ProductQuantizer::compute_distance_tables (
543
- size_t nx,
544
- const float * x,
545
- float * dis_tables) const
546
- {
547
-
548
- #ifdef __AVX2__
533
+ void ProductQuantizer::compute_distance_tables(
534
+ size_t nx,
535
+ const float* x,
536
+ float* dis_tables) const {
537
+ #if defined(__AVX2__) || defined(__aarch64__)
549
538
  if (dsub == 2 && nbits < 8) { // interesting for a narrow range of settings
550
539
  compute_PQ_dis_tables_dsub2(
551
- d, ksub, centroids.data(),
552
- nx, x, false, dis_tables
553
- );
540
+ d, ksub, centroids.data(), nx, x, false, dis_tables);
554
541
  } else
555
542
  #endif
556
- if (dsub < 16) {
543
+ if (dsub < 16) {
557
544
 
558
545
  #pragma omp parallel for
559
546
  for (int64_t i = 0; i < nx; i++) {
560
- compute_distance_table (x + i * d, dis_tables + i * ksub * M);
547
+ compute_distance_table(x + i * d, dis_tables + i * ksub * M);
561
548
  }
562
549
 
563
550
  } else { // use BLAS
564
551
 
565
552
  for (int m = 0; m < M; m++) {
566
- pairwise_L2sqr (dsub,
567
- nx, x + dsub * m,
568
- ksub, centroids.data() + m * dsub * ksub,
569
- dis_tables + ksub * m,
570
- d, dsub, ksub * M);
553
+ pairwise_L2sqr(
554
+ dsub,
555
+ nx,
556
+ x + dsub * m,
557
+ ksub,
558
+ centroids.data() + m * dsub * ksub,
559
+ dis_tables + ksub * m,
560
+ d,
561
+ dsub,
562
+ ksub * M);
571
563
  }
572
564
  }
573
565
  }
574
566
 
575
- void ProductQuantizer::compute_inner_prod_tables (
576
- size_t nx,
577
- const float * x,
578
- float * dis_tables) const
579
- {
580
- #ifdef __AVX2__
567
+ void ProductQuantizer::compute_inner_prod_tables(
568
+ size_t nx,
569
+ const float* x,
570
+ float* dis_tables) const {
571
+ #if defined(__AVX2__) || defined(__aarch64__)
581
572
  if (dsub == 2 && nbits < 8) {
582
573
  compute_PQ_dis_tables_dsub2(
583
- d, ksub, centroids.data(),
584
- nx, x, true, dis_tables
585
- );
574
+ d, ksub, centroids.data(), nx, x, true, dis_tables);
586
575
  } else
587
576
  #endif
588
- if (dsub < 16) {
577
+ if (dsub < 16) {
589
578
 
590
579
  #pragma omp parallel for
591
580
  for (int64_t i = 0; i < nx; i++) {
592
- compute_inner_prod_table (x + i * d, dis_tables + i * ksub * M);
581
+ compute_inner_prod_table(x + i * d, dis_tables + i * ksub * M);
593
582
  }
594
583
 
595
584
  } else { // use BLAS
596
585
 
597
586
  // compute distance tables
598
587
  for (int m = 0; m < M; m++) {
599
- FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub,
600
- dsubi = dsub, di = d;
588
+ FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub, dsubi = dsub,
589
+ di = d;
601
590
  float one = 1.0, zero = 0;
602
591
 
603
- sgemm_ ("Transposed", "Not transposed",
604
- &ksubi, &nxi, &dsubi,
605
- &one, &centroids [m * dsub * ksub], &dsubi,
606
- x + dsub * m, &di,
607
- &zero, dis_tables + ksub * m, &ldc);
592
+ sgemm_("Transposed",
593
+ "Not transposed",
594
+ &ksubi,
595
+ &nxi,
596
+ &dsubi,
597
+ &one,
598
+ &centroids[m * dsub * ksub],
599
+ &dsubi,
600
+ x + dsub * m,
601
+ &di,
602
+ &zero,
603
+ dis_tables + ksub * m,
604
+ &ldc);
608
605
  }
609
-
610
606
  }
611
607
  }
612
608
 
613
609
  template <class C>
614
- static void pq_knn_search_with_tables (
615
- const ProductQuantizer& pq,
616
- size_t nbits,
617
- const float *dis_tables,
618
- const uint8_t * codes,
619
- const size_t ncodes,
620
- HeapArray<C> * res,
621
- bool init_finalize_heap)
622
- {
610
+ static void pq_knn_search_with_tables(
611
+ const ProductQuantizer& pq,
612
+ size_t nbits,
613
+ const float* dis_tables,
614
+ const uint8_t* codes,
615
+ const size_t ncodes,
616
+ HeapArray<C>* res,
617
+ bool init_finalize_heap) {
623
618
  size_t k = res->k, nx = res->nh;
624
619
  size_t ksub = pq.ksub, M = pq.M;
625
620
 
626
-
627
621
  #pragma omp parallel for
628
622
  for (int64_t i = 0; i < nx; i++) {
629
623
  /* query preparation for asymmetric search: compute look-up tables */
630
624
  const float* dis_table = dis_tables + i * ksub * M;
631
625
 
632
626
  /* Compute distances and keep smallest values */
633
- int64_t * __restrict heap_ids = res->ids + i * k;
634
- float * __restrict heap_dis = res->val + i * k;
627
+ int64_t* __restrict heap_ids = res->ids + i * k;
628
+ float* __restrict heap_dis = res->val + i * k;
635
629
 
636
630
  if (init_finalize_heap) {
637
- heap_heapify<C> (k, heap_dis, heap_ids);
631
+ heap_heapify<C>(k, heap_dis, heap_ids);
638
632
  }
639
633
 
640
634
  switch (nbits) {
641
- case 8:
642
- pq_estimators_from_tables<uint8_t, C> (pq,
643
- codes, ncodes,
644
- dis_table,
645
- k, heap_dis, heap_ids);
646
- break;
647
-
648
- case 16:
649
- pq_estimators_from_tables<uint16_t, C> (pq,
650
- (uint16_t*)codes, ncodes,
651
- dis_table,
652
- k, heap_dis, heap_ids);
653
- break;
654
-
655
- default:
656
- pq_estimators_from_tables_generic<C> (pq,
657
- nbits,
658
- codes, ncodes,
659
- dis_table,
660
- k, heap_dis, heap_ids);
661
- break;
635
+ case 8:
636
+ pq_estimators_from_tables<uint8_t, C>(
637
+ pq, codes, ncodes, dis_table, k, heap_dis, heap_ids);
638
+ break;
639
+
640
+ case 16:
641
+ pq_estimators_from_tables<uint16_t, C>(
642
+ pq,
643
+ (uint16_t*)codes,
644
+ ncodes,
645
+ dis_table,
646
+ k,
647
+ heap_dis,
648
+ heap_ids);
649
+ break;
650
+
651
+ default:
652
+ pq_estimators_from_tables_generic<C>(
653
+ pq,
654
+ nbits,
655
+ codes,
656
+ ncodes,
657
+ dis_table,
658
+ k,
659
+ heap_dis,
660
+ heap_ids);
661
+ break;
662
662
  }
663
663
 
664
664
  if (init_finalize_heap) {
665
- heap_reorder<C> (k, heap_dis, heap_ids);
665
+ heap_reorder<C>(k, heap_dis, heap_ids);
666
666
  }
667
667
  }
668
668
  }
669
669
 
670
- void ProductQuantizer::search (const float * __restrict x,
671
- size_t nx,
672
- const uint8_t * codes,
673
- const size_t ncodes,
674
- float_maxheap_array_t * res,
675
- bool init_finalize_heap) const
676
- {
677
- FAISS_THROW_IF_NOT (nx == res->nh);
678
- std::unique_ptr<float[]> dis_tables(new float [nx * ksub * M]);
679
- compute_distance_tables (nx, x, dis_tables.get());
680
-
681
- pq_knn_search_with_tables<CMax<float, int64_t>> (
682
- *this, nbits, dis_tables.get(), codes, ncodes, res, init_finalize_heap);
670
+ void ProductQuantizer::search(
671
+ const float* __restrict x,
672
+ size_t nx,
673
+ const uint8_t* codes,
674
+ const size_t ncodes,
675
+ float_maxheap_array_t* res,
676
+ bool init_finalize_heap) const {
677
+ FAISS_THROW_IF_NOT(nx == res->nh);
678
+ std::unique_ptr<float[]> dis_tables(new float[nx * ksub * M]);
679
+ compute_distance_tables(nx, x, dis_tables.get());
680
+
681
+ pq_knn_search_with_tables<CMax<float, int64_t>>(
682
+ *this,
683
+ nbits,
684
+ dis_tables.get(),
685
+ codes,
686
+ ncodes,
687
+ res,
688
+ init_finalize_heap);
683
689
  }
684
690
 
685
- void ProductQuantizer::search_ip (const float * __restrict x,
686
- size_t nx,
687
- const uint8_t * codes,
688
- const size_t ncodes,
689
- float_minheap_array_t * res,
690
- bool init_finalize_heap) const
691
- {
692
- FAISS_THROW_IF_NOT (nx == res->nh);
693
- std::unique_ptr<float[]> dis_tables(new float [nx * ksub * M]);
694
- compute_inner_prod_tables (nx, x, dis_tables.get());
695
-
696
- pq_knn_search_with_tables<CMin<float, int64_t> > (
697
- *this, nbits, dis_tables.get(), codes, ncodes, res, init_finalize_heap);
691
+ void ProductQuantizer::search_ip(
692
+ const float* __restrict x,
693
+ size_t nx,
694
+ const uint8_t* codes,
695
+ const size_t ncodes,
696
+ float_minheap_array_t* res,
697
+ bool init_finalize_heap) const {
698
+ FAISS_THROW_IF_NOT(nx == res->nh);
699
+ std::unique_ptr<float[]> dis_tables(new float[nx * ksub * M]);
700
+ compute_inner_prod_tables(nx, x, dis_tables.get());
701
+
702
+ pq_knn_search_with_tables<CMin<float, int64_t>>(
703
+ *this,
704
+ nbits,
705
+ dis_tables.get(),
706
+ codes,
707
+ ncodes,
708
+ res,
709
+ init_finalize_heap);
698
710
  }
699
711
 
700
-
701
-
702
- static float sqr (float x) {
712
+ static float sqr(float x) {
703
713
  return x * x;
704
714
  }
705
715
 
706
- void ProductQuantizer::compute_sdc_table ()
707
- {
708
- sdc_table.resize (M * ksub * ksub);
709
-
710
- for (int m = 0; m < M; m++) {
716
+ void ProductQuantizer::compute_sdc_table() {
717
+ sdc_table.resize(M * ksub * ksub);
711
718
 
712
- const float *cents = centroids.data() + m * ksub * dsub;
713
- float * dis_tab = sdc_table.data() + m * ksub * ksub;
714
-
715
- // TODO optimize with BLAS
716
- for (int i = 0; i < ksub; i++) {
717
- const float *centi = cents + i * dsub;
718
- for (int j = 0; j < ksub; j++) {
719
- float accu = 0;
720
- const float *centj = cents + j * dsub;
721
- for (int k = 0; k < dsub; k++)
722
- accu += sqr (centi[k] - centj[k]);
723
- dis_tab [i + j * ksub] = accu;
724
- }
719
+ if (dsub < 4) {
720
+ #pragma omp parallel for
721
+ for (int mk = 0; mk < M * ksub; mk++) {
722
+ // allow omp to schedule in a more fine-grained way
723
+ // `collapse` is not supported in OpenMP 2.x
724
+ int m = mk / ksub;
725
+ int k = mk % ksub;
726
+ const float* cents = centroids.data() + m * ksub * dsub;
727
+ const float* centi = cents + k * dsub;
728
+ float* dis_tab = sdc_table.data() + m * ksub * ksub;
729
+ fvec_L2sqr_ny(dis_tab + k * ksub, centi, cents, dsub, ksub);
730
+ }
731
+ } else {
732
+ // NOTE: it would disable the omp loop in pairwise_L2sqr
733
+ // but still accelerate especially when M >= 4
734
+ #pragma omp parallel for
735
+ for (int m = 0; m < M; m++) {
736
+ const float* cents = centroids.data() + m * ksub * dsub;
737
+ float* dis_tab = sdc_table.data() + m * ksub * ksub;
738
+ pairwise_L2sqr(
739
+ dsub, ksub, cents, ksub, cents, dis_tab, dsub, dsub, ksub);
725
740
  }
726
741
  }
727
742
  }
728
743
 
729
- void ProductQuantizer::search_sdc (const uint8_t * qcodes,
730
- size_t nq,
731
- const uint8_t * bcodes,
732
- const size_t nb,
733
- float_maxheap_array_t * res,
734
- bool init_finalize_heap) const
735
- {
736
- FAISS_THROW_IF_NOT (sdc_table.size() == M * ksub * ksub);
737
- FAISS_THROW_IF_NOT (nbits == 8);
744
+ void ProductQuantizer::search_sdc(
745
+ const uint8_t* qcodes,
746
+ size_t nq,
747
+ const uint8_t* bcodes,
748
+ const size_t nb,
749
+ float_maxheap_array_t* res,
750
+ bool init_finalize_heap) const {
751
+ FAISS_THROW_IF_NOT(sdc_table.size() == M * ksub * ksub);
752
+ FAISS_THROW_IF_NOT(nbits == 8);
738
753
  size_t k = res->k;
739
754
 
740
-
741
755
  #pragma omp parallel for
742
756
  for (int64_t i = 0; i < nq; i++) {
743
-
744
757
  /* Compute distances and keep smallest values */
745
- idx_t * heap_ids = res->ids + i * k;
746
- float * heap_dis = res->val + i * k;
747
- const uint8_t * qcode = qcodes + i * code_size;
758
+ idx_t* heap_ids = res->ids + i * k;
759
+ float* heap_dis = res->val + i * k;
760
+ const uint8_t* qcode = qcodes + i * code_size;
748
761
 
749
762
  if (init_finalize_heap)
750
- maxheap_heapify (k, heap_dis, heap_ids);
763
+ maxheap_heapify(k, heap_dis, heap_ids);
751
764
 
752
- const uint8_t * bcode = bcodes;
765
+ const uint8_t* bcode = bcodes;
753
766
  for (size_t j = 0; j < nb; j++) {
754
767
  float dis = 0;
755
- const float * tab = sdc_table.data();
768
+ const float* tab = sdc_table.data();
756
769
  for (int m = 0; m < M; m++) {
757
770
  dis += tab[bcode[m] + qcode[m] * ksub];
758
771
  tab += ksub * ksub;
759
772
  }
760
773
  if (dis < heap_dis[0]) {
761
- maxheap_replace_top (k, heap_dis, heap_ids, dis, j);
774
+ maxheap_replace_top(k, heap_dis, heap_ids, dis, j);
762
775
  }
763
776
  bcode += code_size;
764
777
  }
765
778
 
766
779
  if (init_finalize_heap)
767
- maxheap_reorder (k, heap_dis, heap_ids);
780
+ maxheap_reorder(k, heap_dis, heap_ids);
768
781
  }
769
-
770
782
  }
771
783
 
772
-
773
-
774
- } // namespace faiss
784
+ } // namespace faiss