faiss 0.1.7 → 0.2.3

Files changed (219)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +18 -0
  3. data/README.md +7 -7
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +8 -2
  6. data/ext/faiss/index.cpp +102 -69
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +0 -5
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +26 -12
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22

data/vendor/faiss/faiss/impl/ProductQuantizer.cpp
@@ -9,113 +9,118 @@

  #include <faiss/impl/ProductQuantizer.h>

-
  #include <cstddef>
- #include <cstring>
  #include <cstdio>
+ #include <cstring>
  #include <memory>

  #include <algorithm>

- #include <faiss/impl/FaissAssert.h>
- #include <faiss/VectorTransform.h>
  #include <faiss/IndexFlat.h>
+ #include <faiss/VectorTransform.h>
+ #include <faiss/impl/FaissAssert.h>
  #include <faiss/utils/distances.h>

-
  extern "C" {

  /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */

- int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
- n, FINTEGER *k, const float *alpha, const float *a,
- FINTEGER *lda, const float *b, FINTEGER *
- ldb, float *beta, float *c, FINTEGER *ldc);
-
+ int sgemm_(
+ const char* transa,
+ const char* transb,
+ FINTEGER* m,
+ FINTEGER* n,
+ FINTEGER* k,
+ const float* alpha,
+ const float* a,
+ FINTEGER* lda,
+ const float* b,
+ FINTEGER* ldb,
+ float* beta,
+ float* c,
+ FINTEGER* ldc);
  }

-
  namespace faiss {

-
  /* compute an estimator using look-up tables for typical values of M */
  template <typename CT, class C>
- void pq_estimators_from_tables_Mmul4 (int M, const CT * codes,
- size_t ncodes,
- const float * __restrict dis_table,
- size_t ksub,
- size_t k,
- float * heap_dis,
- int64_t * heap_ids)
- {
-
+ void pq_estimators_from_tables_Mmul4(
+ int M,
+ const CT* codes,
+ size_t ncodes,
+ const float* __restrict dis_table,
+ size_t ksub,
+ size_t k,
+ float* heap_dis,
+ int64_t* heap_ids) {
  for (size_t j = 0; j < ncodes; j++) {
  float dis = 0;
- const float *dt = dis_table;
+ const float* dt = dis_table;

- for (size_t m = 0; m < M; m+=4) {
+ for (size_t m = 0; m < M; m += 4) {
  float dism = 0;
- dism = dt[*codes++]; dt += ksub;
- dism += dt[*codes++]; dt += ksub;
- dism += dt[*codes++]; dt += ksub;
- dism += dt[*codes++]; dt += ksub;
+ dism = dt[*codes++];
+ dt += ksub;
+ dism += dt[*codes++];
+ dt += ksub;
+ dism += dt[*codes++];
+ dt += ksub;
+ dism += dt[*codes++];
+ dt += ksub;
  dis += dism;
  }

- if (C::cmp (heap_dis[0], dis)) {
- heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
+ if (C::cmp(heap_dis[0], dis)) {
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
  }
  }
  }

-
  template <typename CT, class C>
- void pq_estimators_from_tables_M4 (const CT * codes,
- size_t ncodes,
- const float * __restrict dis_table,
- size_t ksub,
- size_t k,
- float * heap_dis,
- int64_t * heap_ids)
- {
-
+ void pq_estimators_from_tables_M4(
+ const CT* codes,
+ size_t ncodes,
+ const float* __restrict dis_table,
+ size_t ksub,
+ size_t k,
+ float* heap_dis,
+ int64_t* heap_ids) {
  for (size_t j = 0; j < ncodes; j++) {
  float dis = 0;
- const float *dt = dis_table;
- dis = dt[*codes++]; dt += ksub;
- dis += dt[*codes++]; dt += ksub;
- dis += dt[*codes++]; dt += ksub;
+ const float* dt = dis_table;
+ dis = dt[*codes++];
+ dt += ksub;
+ dis += dt[*codes++];
+ dt += ksub;
+ dis += dt[*codes++];
+ dt += ksub;
  dis += dt[*codes++];

- if (C::cmp (heap_dis[0], dis)) {
- heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
+ if (C::cmp(heap_dis[0], dis)) {
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
  }
  }
  }

-
  template <typename CT, class C>
- static inline void pq_estimators_from_tables (const ProductQuantizer& pq,
- const CT * codes,
- size_t ncodes,
- const float * dis_table,
- size_t k,
- float * heap_dis,
- int64_t * heap_ids)
- {
-
- if (pq.M == 4) {
-
- pq_estimators_from_tables_M4<CT, C> (codes, ncodes,
- dis_table, pq.ksub, k,
- heap_dis, heap_ids);
+ static inline void pq_estimators_from_tables(
+ const ProductQuantizer& pq,
+ const CT* codes,
+ size_t ncodes,
+ const float* dis_table,
+ size_t k,
+ float* heap_dis,
+ int64_t* heap_ids) {
+ if (pq.M == 4) {
+ pq_estimators_from_tables_M4<CT, C>(
+ codes, ncodes, dis_table, pq.ksub, k, heap_dis, heap_ids);
  return;
  }

  if (pq.M % 4 == 0) {
- pq_estimators_from_tables_Mmul4<CT, C> (pq.M, codes, ncodes,
- dis_table, pq.ksub, k,
- heap_dis, heap_ids);
+ pq_estimators_from_tables_Mmul4<CT, C>(
+ pq.M, codes, ncodes, dis_table, pq.ksub, k, heap_dis, heap_ids);
  return;
  }

@@ -124,132 +129,124 @@ static inline void pq_estimators_from_tables (const ProductQuantizer& pq,
  const size_t ksub = pq.ksub;
  for (size_t j = 0; j < ncodes; j++) {
  float dis = 0;
- const float * __restrict dt = dis_table;
+ const float* __restrict dt = dis_table;
  for (int m = 0; m < M; m++) {
  dis += dt[*codes++];
  dt += ksub;
  }
- if (C::cmp (heap_dis[0], dis)) {
- heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
+ if (C::cmp(heap_dis[0], dis)) {
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
  }
  }
  }

  template <class C>
- static inline void pq_estimators_from_tables_generic(const ProductQuantizer& pq,
- size_t nbits,
- const uint8_t *codes,
- size_t ncodes,
- const float *dis_table,
- size_t k,
- float *heap_dis,
- int64_t *heap_ids)
- {
- const size_t M = pq.M;
- const size_t ksub = pq.ksub;
- for (size_t j = 0; j < ncodes; ++j) {
- PQDecoderGeneric decoder(
- codes + j * pq.code_size, nbits
- );
- float dis = 0;
- const float * __restrict dt = dis_table;
- for (size_t m = 0; m < M; m++) {
- uint64_t c = decoder.decode();
- dis += dt[c];
- dt += ksub;
- }
+ static inline void pq_estimators_from_tables_generic(
+ const ProductQuantizer& pq,
+ size_t nbits,
+ const uint8_t* codes,
+ size_t ncodes,
+ const float* dis_table,
+ size_t k,
+ float* heap_dis,
+ int64_t* heap_ids) {
+ const size_t M = pq.M;
+ const size_t ksub = pq.ksub;
+ for (size_t j = 0; j < ncodes; ++j) {
+ PQDecoderGeneric decoder(codes + j * pq.code_size, nbits);
+ float dis = 0;
+ const float* __restrict dt = dis_table;
+ for (size_t m = 0; m < M; m++) {
+ uint64_t c = decoder.decode();
+ dis += dt[c];
+ dt += ksub;
+ }

- if (C::cmp(heap_dis[0], dis)) {
- heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
+ if (C::cmp(heap_dis[0], dis)) {
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
+ }
  }
- }
  }

  /*********************************************
  * PQ implementation
  *********************************************/

-
-
- ProductQuantizer::ProductQuantizer (size_t d, size_t M, size_t nbits):
- d(d), M(M), nbits(nbits), assign_index(nullptr)
- {
- set_derived_values ();
+ ProductQuantizer::ProductQuantizer(size_t d, size_t M, size_t nbits)
+ : d(d), M(M), nbits(nbits), assign_index(nullptr) {
+ set_derived_values();
  }

- ProductQuantizer::ProductQuantizer ()
- : ProductQuantizer(0, 1, 0) {}
+ ProductQuantizer::ProductQuantizer() : ProductQuantizer(0, 1, 0) {}

- void ProductQuantizer::set_derived_values () {
+ void ProductQuantizer::set_derived_values() {
  // quite a few derived values
- FAISS_THROW_IF_NOT_MSG (d % M == 0, "The dimension of the vector (d) should be a multiple of the number of subquantizers (M)");
+ FAISS_THROW_IF_NOT_MSG(
+ d % M == 0,
+ "The dimension of the vector (d) should be a multiple of the number of subquantizers (M)");
  dsub = d / M;
  code_size = (nbits * M + 7) / 8;
  ksub = 1 << nbits;
- centroids.resize (d * ksub);
+ centroids.resize(d * ksub);
  verbose = false;
  train_type = Train_default;
  }

- void ProductQuantizer::set_params (const float * centroids_, int m)
- {
- memcpy (get_centroids(m, 0), centroids_,
- ksub * dsub * sizeof (centroids_[0]));
+ void ProductQuantizer::set_params(const float* centroids_, int m) {
+ memcpy(get_centroids(m, 0),
+ centroids_,
+ ksub * dsub * sizeof(centroids_[0]));
  }

-
- static void init_hypercube (int d, int nbits,
- int n, const float * x,
- float *centroids)
- {
-
- std::vector<float> mean (d);
+ static void init_hypercube(
+ int d,
+ int nbits,
+ int n,
+ const float* x,
+ float* centroids) {
+ std::vector<float> mean(d);
  for (int i = 0; i < n; i++)
  for (int j = 0; j < d; j++)
- mean [j] += x[i * d + j];
+ mean[j] += x[i * d + j];

  float maxm = 0;
  for (int j = 0; j < d; j++) {
- mean [j] /= n;
- if (fabs(mean[j]) > maxm) maxm = fabs(mean[j]);
+ mean[j] /= n;
+ if (fabs(mean[j]) > maxm)
+ maxm = fabs(mean[j]);
  }

  for (int i = 0; i < (1 << nbits); i++) {
- float * cent = centroids + i * d;
+ float* cent = centroids + i * d;
  for (int j = 0; j < nbits; j++)
- cent[j] = mean [j] + (((i >> j) & 1) ? 1 : -1) * maxm;
+ cent[j] = mean[j] + (((i >> j) & 1) ? 1 : -1) * maxm;
  for (int j = nbits; j < d; j++)
- cent[j] = mean [j];
+ cent[j] = mean[j];
  }
-
-
  }

- static void init_hypercube_pca (int d, int nbits,
- int n, const float * x,
- float *centroids)
- {
- PCAMatrix pca (d, nbits);
- pca.train (n, x);
-
+ static void init_hypercube_pca(
+ int d,
+ int nbits,
+ int n,
+ const float* x,
+ float* centroids) {
+ PCAMatrix pca(d, nbits);
+ pca.train(n, x);

  for (int i = 0; i < (1 << nbits); i++) {
- float * cent = centroids + i * d;
+ float* cent = centroids + i * d;
  for (int j = 0; j < d; j++) {
  cent[j] = pca.mean[j];
  float f = 1.0;
  for (int k = 0; k < nbits; k++)
- cent[j] += f *
- sqrt (pca.eigenvalues [k]) *
- (((i >> k) & 1) ? 1 : -1) *
- pca.PCAMat [j + k * d];
+ cent[j] += f * sqrt(pca.eigenvalues[k]) *
+ (((i >> k) & 1) ? 1 : -1) * pca.PCAMat[j + k * d];
  }
  }
-
  }

- void ProductQuantizer::train (int n, const float * x)
- {
+ void ProductQuantizer::train(int n, const float* x) {
  if (train_type != Train_shared) {
  train_type_t final_train_type;
  final_train_type = train_type;
@@ -257,234 +254,229 @@ void ProductQuantizer::train (int n, const float * x)
  train_type == Train_hypercube_pca) {
  if (dsub < nbits) {
  final_train_type = Train_default;
- printf ("cannot train hypercube: nbits=%zd > log2(d=%zd)\n",
- nbits, dsub);
+ printf("cannot train hypercube: nbits=%zd > log2(d=%zd)\n",
+ nbits,
+ dsub);
  }
  }

- float * xslice = new float[n * dsub];
- ScopeDeleter<float> del (xslice);
+ float* xslice = new float[n * dsub];
+ ScopeDeleter<float> del(xslice);
  for (int m = 0; m < M; m++) {
  for (int j = 0; j < n; j++)
- memcpy (xslice + j * dsub,
- x + j * d + m * dsub,
- dsub * sizeof(float));
+ memcpy(xslice + j * dsub,
+ x + j * d + m * dsub,
+ dsub * sizeof(float));

- Clustering clus (dsub, ksub, cp);
+ Clustering clus(dsub, ksub, cp);

  // we have some initialization for the centroids
  if (final_train_type != Train_default) {
- clus.centroids.resize (dsub * ksub);
+ clus.centroids.resize(dsub * ksub);
  }

  switch (final_train_type) {
- case Train_hypercube:
- init_hypercube (dsub, nbits, n, xslice,
- clus.centroids.data ());
- break;
- case Train_hypercube_pca:
- init_hypercube_pca (dsub, nbits, n, xslice,
- clus.centroids.data ());
- break;
- case Train_hot_start:
- memcpy (clus.centroids.data(),
- get_centroids (m, 0),
- dsub * ksub * sizeof (float));
- break;
- default: ;
+ case Train_hypercube:
+ init_hypercube(
+ dsub, nbits, n, xslice, clus.centroids.data());
+ break;
+ case Train_hypercube_pca:
+ init_hypercube_pca(
+ dsub, nbits, n, xslice, clus.centroids.data());
+ break;
+ case Train_hot_start:
+ memcpy(clus.centroids.data(),
+ get_centroids(m, 0),
+ dsub * ksub * sizeof(float));
+ break;
+ default:;
  }

- if(verbose) {
+ if (verbose) {
  clus.verbose = true;
- printf ("Training PQ slice %d/%zd\n", m, M);
+ printf("Training PQ slice %d/%zd\n", m, M);
  }
- IndexFlatL2 index (dsub);
- clus.train (n, xslice, assign_index ? *assign_index : index);
- set_params (clus.centroids.data(), m);
+ IndexFlatL2 index(dsub);
+ clus.train(n, xslice, assign_index ? *assign_index : index);
+ set_params(clus.centroids.data(), m);
  }

-
  } else {
+ Clustering clus(dsub, ksub, cp);

- Clustering clus (dsub, ksub, cp);
-
- if(verbose) {
+ if (verbose) {
  clus.verbose = true;
- printf ("Training all PQ slices at once\n");
+ printf("Training all PQ slices at once\n");
  }

- IndexFlatL2 index (dsub);
+ IndexFlatL2 index(dsub);

- clus.train (n * M, x, assign_index ? *assign_index : index);
+ clus.train(n * M, x, assign_index ? *assign_index : index);
  for (int m = 0; m < M; m++) {
- set_params (clus.centroids.data(), m);
+ set_params(clus.centroids.data(), m);
  }
-
  }
  }

- template<class PQEncoder>
- void compute_code(const ProductQuantizer& pq, const float *x, uint8_t *code) {
- std::vector<float> distances(pq.ksub);
- PQEncoder encoder(code, pq.nbits);
- for (size_t m = 0; m < pq.M; m++) {
- float mindis = 1e20;
- uint64_t idxm = 0;
- const float * xsub = x + m * pq.dsub;
-
- fvec_L2sqr_ny(distances.data(), xsub, pq.get_centroids(m, 0), pq.dsub, pq.ksub);
-
- /* Find best centroid */
- for (size_t i = 0; i < pq.ksub; i++) {
- float dis = distances[i];
- if (dis < mindis) {
- mindis = dis;
- idxm = i;
- }
- }
+ template <class PQEncoder>
+ void compute_code(const ProductQuantizer& pq, const float* x, uint8_t* code) {
+ std::vector<float> distances(pq.ksub);
+ PQEncoder encoder(code, pq.nbits);
+ for (size_t m = 0; m < pq.M; m++) {
+ float mindis = 1e20;
+ uint64_t idxm = 0;
+ const float* xsub = x + m * pq.dsub;
+
+ fvec_L2sqr_ny(
+ distances.data(),
+ xsub,
+ pq.get_centroids(m, 0),
+ pq.dsub,
+ pq.ksub);
+
+ /* Find best centroid */
+ for (size_t i = 0; i < pq.ksub; i++) {
+ float dis = distances[i];
+ if (dis < mindis) {
+ mindis = dis;
+ idxm = i;
+ }
+ }

- encoder.encode(idxm);
- }
+ encoder.encode(idxm);
+ }
  }

- void ProductQuantizer::compute_code(const float * x, uint8_t * code) const {
- switch (nbits) {
- case 8:
- faiss::compute_code<PQEncoder8>(*this, x, code);
- break;
+ void ProductQuantizer::compute_code(const float* x, uint8_t* code) const {
+ switch (nbits) {
+ case 8:
+ faiss::compute_code<PQEncoder8>(*this, x, code);
+ break;

- case 16:
- faiss::compute_code<PQEncoder16>(*this, x, code);
- break;
+ case 16:
+ faiss::compute_code<PQEncoder16>(*this, x, code);
+ break;

- default:
- faiss::compute_code<PQEncoderGeneric>(*this, x, code);
- break;
- }
+ default:
+ faiss::compute_code<PQEncoderGeneric>(*this, x, code);
+ break;
+ }
  }

- template<class PQDecoder>
- void decode(const ProductQuantizer& pq, const uint8_t *code, float *x)
- {
- PQDecoder decoder(code, pq.nbits);
- for (size_t m = 0; m < pq.M; m++) {
- uint64_t c = decoder.decode();
- memcpy(x + m * pq.dsub, pq.get_centroids(m, c), sizeof(float) * pq.dsub);
- }
+ template <class PQDecoder>
+ void decode(const ProductQuantizer& pq, const uint8_t* code, float* x) {
+ PQDecoder decoder(code, pq.nbits);
+ for (size_t m = 0; m < pq.M; m++) {
+ uint64_t c = decoder.decode();
+ memcpy(x + m * pq.dsub,
+ pq.get_centroids(m, c),
+ sizeof(float) * pq.dsub);
+ }
  }

- void ProductQuantizer::decode (const uint8_t *code, float *x) const
- {
- switch (nbits) {
- case 8:
- faiss::decode<PQDecoder8>(*this, code, x);
- break;
-
- case 16:
- faiss::decode<PQDecoder16>(*this, code, x);
- break;
-
- default:
- faiss::decode<PQDecoderGeneric>(*this, code, x);
- break;
- }
- }
+ void ProductQuantizer::decode(const uint8_t* code, float* x) const {
+ switch (nbits) {
+ case 8:
+ faiss::decode<PQDecoder8>(*this, code, x);
+ break;
+
+ case 16:
+ faiss::decode<PQDecoder16>(*this, code, x);
+ break;

+ default:
+ faiss::decode<PQDecoderGeneric>(*this, code, x);
+ break;
+ }
+ }

- void ProductQuantizer::decode (const uint8_t *code, float *x, size_t n) const
- {
+ void ProductQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
  for (size_t i = 0; i < n; i++) {
- this->decode (code + code_size * i, x + d * i);
+ this->decode(code + code_size * i, x + d * i);
  }
  }

+ void ProductQuantizer::compute_code_from_distance_table(
+ const float* tab,
+ uint8_t* code) const {
+ PQEncoderGeneric encoder(code, nbits);
+ for (size_t m = 0; m < M; m++) {
+ float mindis = 1e20;
+ uint64_t idxm = 0;
+
+ /* Find best centroid */
+ for (size_t j = 0; j < ksub; j++) {
+ float dis = *tab++;
+ if (dis < mindis) {
+ mindis = dis;
+ idxm = j;
+ }
+ }

- void ProductQuantizer::compute_code_from_distance_table (const float *tab,
- uint8_t *code) const
- {
- PQEncoderGeneric encoder(code, nbits);
- for (size_t m = 0; m < M; m++) {
- float mindis = 1e20;
- uint64_t idxm = 0;
-
- /* Find best centroid */
- for (size_t j = 0; j < ksub; j++) {
- float dis = *tab++;
- if (dis < mindis) {
- mindis = dis;
- idxm = j;
- }
+ encoder.encode(idxm);
  }
-
- encoder.encode(idxm);
- }
  }

- void ProductQuantizer::compute_codes_with_assign_index (
- const float * x,
- uint8_t * codes,
- size_t n)
- {
- FAISS_THROW_IF_NOT (assign_index && assign_index->d == dsub);
+ void ProductQuantizer::compute_codes_with_assign_index(
+ const float* x,
+ uint8_t* codes,
+ size_t n) {
+ FAISS_THROW_IF_NOT(assign_index && assign_index->d == dsub);

  for (size_t m = 0; m < M; m++) {
- assign_index->reset ();
- assign_index->add (ksub, get_centroids (m, 0));
+ assign_index->reset();
+ assign_index->add(ksub, get_centroids(m, 0));
  size_t bs = 65536;
- float * xslice = new float[bs * dsub];
- ScopeDeleter<float> del (xslice);
- idx_t *assign = new idx_t[bs];
- ScopeDeleter<idx_t> del2 (assign);
+ float* xslice = new float[bs * dsub];
+ ScopeDeleter<float> del(xslice);
+ idx_t* assign = new idx_t[bs];
+ ScopeDeleter<idx_t> del2(assign);

  for (size_t i0 = 0; i0 < n; i0 += bs) {
  size_t i1 = std::min(i0 + bs, n);

  for (size_t i = i0; i < i1; i++) {
- memcpy (xslice + (i - i0) * dsub,
- x + i * d + m * dsub,
- dsub * sizeof(float));
+ memcpy(xslice + (i - i0) * dsub,
+ x + i * d + m * dsub,
+ dsub * sizeof(float));
  }

- assign_index->assign (i1 - i0, xslice, assign);
+ assign_index->assign(i1 - i0, xslice, assign);

  if (nbits == 8) {
- uint8_t *c = codes + code_size * i0 + m;
- for (size_t i = i0; i < i1; i++) {
- *c = assign[i - i0];
- c += M;
- }
+ uint8_t* c = codes + code_size * i0 + m;
+ for (size_t i = i0; i < i1; i++) {
+ *c = assign[i - i0];
+ c += M;
+ }
  } else if (nbits == 16) {
- uint16_t *c = (uint16_t*)(codes + code_size * i0 + m * 2);
- for (size_t i = i0; i < i1; i++) {
- *c = assign[i - i0];
- c += M;
- }
+ uint16_t* c = (uint16_t*)(codes + code_size * i0 + m * 2);
+ for (size_t i = i0; i < i1; i++) {
+ *c = assign[i - i0];
+ c += M;
+ }
  } else {
- for (size_t i = i0; i < i1; ++i) {
- uint8_t *c = codes + code_size * i + ((m * nbits) / 8);
- uint8_t offset = (m * nbits) % 8;
- uint64_t ass = assign[i - i0];
-
- PQEncoderGeneric encoder(c, nbits, offset);
- encoder.encode(ass);
- }
+ for (size_t i = i0; i < i1; ++i) {
+ uint8_t* c = codes + code_size * i + ((m * nbits) / 8);
+ uint8_t offset = (m * nbits) % 8;
+ uint64_t ass = assign[i - i0];
+
+ PQEncoderGeneric encoder(c, nbits, offset);
+ encoder.encode(ass);
+ }
  }
-
  }
  }
-
  }

- void ProductQuantizer::compute_codes (const float * x,
- uint8_t * codes,
- size_t n) const
- {
- // process by blocks to avoid using too much RAM
+ void ProductQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
+ const {
+ // process by blocks to avoid using too much RAM
  size_t bs = 256 * 1024;
  if (n > bs) {
  for (size_t i0 = 0; i0 < n; i0 += bs) {
  size_t i1 = std::min(i0 + bs, n);
- compute_codes (x + d * i0, codes + code_size * i0, i1 - i0);
+ compute_codes(x + d * i0, codes + code_size * i0, i1 - i0);
  }
  return;
  }
@@ -493,282 +485,300 @@ void ProductQuantizer::compute_codes (const float * x,

  #pragma omp parallel for
  for (int64_t i = 0; i < n; i++)
- compute_code (x + i * d, codes + i * code_size);
+ compute_code(x + i * d, codes + i * code_size);

  } else { // worthwile to use BLAS
- float *dis_tables = new float [n * ksub * M];
- ScopeDeleter<float> del (dis_tables);
- compute_distance_tables (n, x, dis_tables);
+ float* dis_tables = new float[n * ksub * M];
+ ScopeDeleter<float> del(dis_tables);
+ compute_distance_tables(n, x, dis_tables);

  #pragma omp parallel for
  for (int64_t i = 0; i < n; i++) {
- uint8_t * code = codes + i * code_size;
- const float * tab = dis_tables + i * ksub * M;
- compute_code_from_distance_table (tab, code);
+ uint8_t* code = codes + i * code_size;
+ const float* tab = dis_tables + i * ksub * M;
+ compute_code_from_distance_table(tab, code);
  }
  }
  }

-
- void ProductQuantizer::compute_distance_table (const float * x,
- float * dis_table) const
- {
+ void ProductQuantizer::compute_distance_table(const float* x, float* dis_table)
+ const {
  size_t m;

  for (m = 0; m < M; m++) {
- fvec_L2sqr_ny (dis_table + m * ksub,
- x + m * dsub,
- get_centroids(m, 0),
- dsub,
- ksub);
+ fvec_L2sqr_ny(
+ dis_table + m * ksub,
+ x + m * dsub,
+ get_centroids(m, 0),
+ dsub,
+ ksub);
  }
  }

- void ProductQuantizer::compute_inner_prod_table (const float * x,
- float * dis_table) const
- {
+ void ProductQuantizer::compute_inner_prod_table(
+ const float* x,
+ float* dis_table) const {
  size_t m;

  for (m = 0; m < M; m++) {
- fvec_inner_products_ny (dis_table + m * ksub,
- x + m * dsub,
- get_centroids(m, 0),
- dsub,
- ksub);
+ fvec_inner_products_ny(
+ dis_table + m * ksub,
+ x + m * dsub,
+ get_centroids(m, 0),
+ dsub,
+ ksub);
  }
  }

-
- void ProductQuantizer::compute_distance_tables (
- size_t nx,
- const float * x,
- float * dis_tables) const
- {
-
- #ifdef __AVX2__
+ void ProductQuantizer::compute_distance_tables(
+ size_t nx,
+ const float* x,
+ float* dis_tables) const {
+ #if defined(__AVX2__) || defined(__aarch64__)
  if (dsub == 2 && nbits < 8) { // interesting for a narrow range of settings
  compute_PQ_dis_tables_dsub2(
- d, ksub, centroids.data(),
- nx, x, false, dis_tables
- );
+ d, ksub, centroids.data(), nx, x, false, dis_tables);
  } else
  #endif
- if (dsub < 16) {
+ if (dsub < 16) {

  #pragma omp parallel for
  for (int64_t i = 0; i < nx; i++) {
- compute_distance_table (x + i * d, dis_tables + i * ksub * M);
+ compute_distance_table(x + i * d, dis_tables + i * ksub * M);
  }

  } else { // use BLAS

  for (int m = 0; m < M; m++) {
- pairwise_L2sqr (dsub,
- nx, x + dsub * m,
- ksub, centroids.data() + m * dsub * ksub,
- dis_tables + ksub * m,
- d, dsub, ksub * M);
+ pairwise_L2sqr(
+ dsub,
+ nx,
+ x + dsub * m,
+ ksub,
+ centroids.data() + m * dsub * ksub,
+ dis_tables + ksub * m,
+ d,
+ dsub,
+ ksub * M);
  }
  }
  }

- void ProductQuantizer::compute_inner_prod_tables (
- size_t nx,
- const float * x,
- float * dis_tables) const
- {
- #ifdef __AVX2__
+ void ProductQuantizer::compute_inner_prod_tables(
+ size_t nx,
+ const float* x,
+ float* dis_tables) const {
+ #if defined(__AVX2__) || defined(__aarch64__)
  if (dsub == 2 && nbits < 8) {
  compute_PQ_dis_tables_dsub2(
- d, ksub, centroids.data(),
- nx, x, true, dis_tables
- );
+ d, ksub, centroids.data(), nx, x, true, dis_tables);
  } else
  #endif
- if (dsub < 16) {
+ if (dsub < 16) {

  #pragma omp parallel for
  for (int64_t i = 0; i < nx; i++) {
- compute_inner_prod_table (x + i * d, dis_tables + i * ksub * M);
+ compute_inner_prod_table(x + i * d, dis_tables + i * ksub * M);
  }

  } else { // use BLAS

  // compute distance tables
  for (int m = 0; m < M; m++) {
- FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub,
- dsubi = dsub, di = d;
+ FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub, dsubi = dsub,
+ di = d;
  float one = 1.0, zero = 0;

- sgemm_ ("Transposed", "Not transposed",
- &ksubi, &nxi, &dsubi,
- &one, &centroids [m * dsub * ksub], &dsubi,
- x + dsub * m, &di,
- &zero, dis_tables + ksub * m, &ldc);
+ sgemm_("Transposed",
+ "Not transposed",
+ &ksubi,
+ &nxi,
+ &dsubi,
+ &one,
+ &centroids[m * dsub * ksub],
+ &dsubi,
+ x + dsub * m,
+ &di,
+ &zero,
+ dis_tables + ksub * m,
+ &ldc);
  }
-
  }
  }

  template <class C>
- static void pq_knn_search_with_tables (
- const ProductQuantizer& pq,
- size_t nbits,
- const float *dis_tables,
- const uint8_t * codes,
- const size_t ncodes,
- HeapArray<C> * res,
- bool init_finalize_heap)
- {
+ static void pq_knn_search_with_tables(
+ const ProductQuantizer& pq,
+ size_t nbits,
+ const float* dis_tables,
+ const uint8_t* codes,
+ const size_t ncodes,
+ HeapArray<C>* res,
+ bool init_finalize_heap) {
  size_t k = res->k, nx = res->nh;
  size_t ksub = pq.ksub, M = pq.M;

-
  #pragma omp parallel for
  for (int64_t i = 0; i < nx; i++) {
  /* query preparation for asymmetric search: compute look-up tables */
  const float* dis_table = dis_tables + i * ksub * M;

  /* Compute distances and keep smallest values */
- int64_t * __restrict heap_ids = res->ids + i * k;
- float * __restrict heap_dis = res->val + i * k;
+ int64_t* __restrict heap_ids = res->ids + i * k;
+ float* __restrict heap_dis = res->val + i * k;

  if (init_finalize_heap) {
- heap_heapify<C> (k, heap_dis, heap_ids);
+ heap_heapify<C>(k, heap_dis, heap_ids);
  }

  switch (nbits) {
- case 8:
- pq_estimators_from_tables<uint8_t, C> (pq,
- codes, ncodes,
- dis_table,
- k, heap_dis, heap_ids);
- break;
-
- case 16:
- pq_estimators_from_tables<uint16_t, C> (pq,
- (uint16_t*)codes, ncodes,
- dis_table,
- k, heap_dis, heap_ids);
- break;
-
- default:
- pq_estimators_from_tables_generic<C> (pq,
- nbits,
- codes, ncodes,
- dis_table,
- k, heap_dis, heap_ids);
- break;
+ case 8:
+ pq_estimators_from_tables<uint8_t, C>(
+ pq, codes, ncodes, dis_table, k, heap_dis, heap_ids);
+ break;
+
+ case 16:
+ pq_estimators_from_tables<uint16_t, C>(
+ pq,
+ (uint16_t*)codes,
+ ncodes,
+ dis_table,
+ k,
+ heap_dis,
+ heap_ids);
+ break;
+
+ default:
+ pq_estimators_from_tables_generic<C>(
+ pq,
+ nbits,
+ codes,
+ ncodes,
+ dis_table,
+ k,
+ heap_dis,
+ heap_ids);
+ break;
  }

  if (init_finalize_heap) {
- heap_reorder<C> (k, heap_dis, heap_ids);
+ heap_reorder<C>(k, heap_dis, heap_ids);
  }
  }
  }

- void ProductQuantizer::search (const float * __restrict x,
- size_t nx,
- const uint8_t * codes,
- const size_t ncodes,
- float_maxheap_array_t * res,
- bool init_finalize_heap) const
- {
- FAISS_THROW_IF_NOT (nx == res->nh);
- std::unique_ptr<float[]> dis_tables(new float [nx * ksub * M]);
- compute_distance_tables (nx, x, dis_tables.get());
-
- pq_knn_search_with_tables<CMax<float, int64_t>> (
- *this, nbits, dis_tables.get(), codes, ncodes, res, init_finalize_heap);
+ void ProductQuantizer::search(
+ const float* __restrict x,
+ size_t nx,
+ const uint8_t* codes,
+ const size_t ncodes,
+ float_maxheap_array_t* res,
+ bool init_finalize_heap) const {
+ FAISS_THROW_IF_NOT(nx == res->nh);
+ std::unique_ptr<float[]> dis_tables(new float[nx * ksub * M]);
+ compute_distance_tables(nx, x, dis_tables.get());
+
+ pq_knn_search_with_tables<CMax<float, int64_t>>(
+ *this,
+ nbits,
+ dis_tables.get(),
+ codes,
+ ncodes,
+ res,
+ init_finalize_heap);
  }

- void ProductQuantizer::search_ip (const float * __restrict x,
- size_t nx,
- const uint8_t * codes,
- const size_t ncodes,
- float_minheap_array_t * res,
- bool init_finalize_heap) const
- {
- FAISS_THROW_IF_NOT (nx == res->nh);
- std::unique_ptr<float[]> dis_tables(new float [nx * ksub * M]);
- compute_inner_prod_tables (nx, x, dis_tables.get());
-
- pq_knn_search_with_tables<CMin<float, int64_t> > (
- *this, nbits, dis_tables.get(), codes, ncodes, res, init_finalize_heap);
+ void ProductQuantizer::search_ip(
+ const float* __restrict x,
+ size_t nx,
+ const uint8_t* codes,
+ const size_t ncodes,
+ float_minheap_array_t* res,
+ bool init_finalize_heap) const {
+ FAISS_THROW_IF_NOT(nx == res->nh);
+ std::unique_ptr<float[]> dis_tables(new float[nx * ksub * M]);
+ compute_inner_prod_tables(nx, x, dis_tables.get());
+
+ pq_knn_search_with_tables<CMin<float, int64_t>>(
+ *this,
+ nbits,
+ dis_tables.get(),
+ codes,
+ ncodes,
+ res,
+ init_finalize_heap);
  }

-
-
- static float sqr (float x) {
+ static float sqr(float x) {
  return x * x;
  }

- void ProductQuantizer::compute_sdc_table ()
- {
- sdc_table.resize (M * ksub * ksub);
-
- for (int m = 0; m < M; m++) {
+ void ProductQuantizer::compute_sdc_table() {
+ sdc_table.resize(M * ksub * ksub);

- const float *cents = centroids.data() + m * ksub * dsub;
- float * dis_tab = sdc_table.data() + m * ksub * ksub;
-
- // TODO optimize with BLAS
- for (int i = 0; i < ksub; i++) {
- const float *centi = cents + i * dsub;
- for (int j = 0; j < ksub; j++) {
- float accu = 0;
- const float *centj = cents + j * dsub;
- for (int k = 0; k < dsub; k++)
- accu += sqr (centi[k] - centj[k]);
- dis_tab [i + j * ksub] = accu;
- }
+ if (dsub < 4) {
+ #pragma omp parallel for
+ for (int mk = 0; mk < M * ksub; mk++) {
+ // allow omp to schedule in a more fine-grained way
+ // `collapse` is not supported in OpenMP 2.x
+ int m = mk / ksub;
+ int k = mk % ksub;
+ const float* cents = centroids.data() + m * ksub * dsub;
+ const float* centi = cents + k * dsub;
+ float* dis_tab = sdc_table.data() + m * ksub * ksub;
+ fvec_L2sqr_ny(dis_tab + k * ksub, centi, cents, dsub, ksub);
+ }
+ } else {
+ // NOTE: it would disable the omp loop in pairwise_L2sqr
+ // but still accelerate especially when M >= 4
+ #pragma omp parallel for
+ for (int m = 0; m < M; m++) {
+ const float* cents = centroids.data() + m * ksub * dsub;
+ float* dis_tab = sdc_table.data() + m * ksub * ksub;
+ pairwise_L2sqr(
+ dsub, ksub, cents, ksub, cents, dis_tab, dsub, dsub, ksub);
  }
  }
  }

- void ProductQuantizer::search_sdc (const uint8_t * qcodes,
- size_t nq,
- const uint8_t * bcodes,
- const size_t nb,
- float_maxheap_array_t * res,
- bool init_finalize_heap) const
- {
- FAISS_THROW_IF_NOT (sdc_table.size() == M * ksub * ksub);
- FAISS_THROW_IF_NOT (nbits == 8);
+ void ProductQuantizer::search_sdc(
+ const uint8_t* qcodes,
+ size_t nq,
+ const uint8_t* bcodes,
+ const size_t nb,
+ float_maxheap_array_t* res,
+ bool init_finalize_heap) const {
+ FAISS_THROW_IF_NOT(sdc_table.size() == M * ksub * ksub);
+ FAISS_THROW_IF_NOT(nbits == 8);
  size_t k = res->k;

-
  #pragma omp parallel for
  for (int64_t i = 0; i < nq; i++) {
-
  /* Compute distances and keep smallest values */
- idx_t * heap_ids = res->ids + i * k;
- float * heap_dis = res->val + i * k;
- const uint8_t * qcode = qcodes + i * code_size;
+ idx_t* heap_ids = res->ids + i * k;
+ float* heap_dis = res->val + i * k;
+ const uint8_t* qcode = qcodes + i * code_size;

  if (init_finalize_heap)
- maxheap_heapify (k, heap_dis, heap_ids);
+ maxheap_heapify(k, heap_dis, heap_ids);

- const uint8_t * bcode = bcodes;
+ const uint8_t* bcode = bcodes;
  for (size_t j = 0; j < nb; j++) {
  float dis = 0;
- const float * tab = sdc_table.data();
+ const float* tab = sdc_table.data();
  for (int m = 0; m < M; m++) {
  dis += tab[bcode[m] + qcode[m] * ksub];
  tab += ksub * ksub;
  }
  if (dis < heap_dis[0]) {
- maxheap_replace_top (k, heap_dis, heap_ids, dis, j);
+ maxheap_replace_top(k, heap_dis, heap_ids, dis, j);
  }
  bcode += code_size;
  }

  if (init_finalize_heap)
- maxheap_reorder (k, heap_dis, heap_ids);
+ maxheap_reorder(k, heap_dis, heap_ids);
  }
-
  }

-
-
- } // namespace faiss
+ } // namespace faiss