faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -9,113 +9,118 @@
9
9
 
10
10
  #include <faiss/impl/ProductQuantizer.h>
11
11
 
12
-
13
12
  #include <cstddef>
14
- #include <cstring>
15
13
  #include <cstdio>
14
+ #include <cstring>
16
15
  #include <memory>
17
16
 
18
17
  #include <algorithm>
19
18
 
20
- #include <faiss/impl/FaissAssert.h>
21
- #include <faiss/VectorTransform.h>
22
19
  #include <faiss/IndexFlat.h>
20
+ #include <faiss/VectorTransform.h>
21
+ #include <faiss/impl/FaissAssert.h>
23
22
  #include <faiss/utils/distances.h>
24
23
 
25
-
26
24
  extern "C" {
27
25
 
28
26
  /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
29
27
 
30
- int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
31
- n, FINTEGER *k, const float *alpha, const float *a,
32
- FINTEGER *lda, const float *b, FINTEGER *
33
- ldb, float *beta, float *c, FINTEGER *ldc);
34
-
28
+ int sgemm_(
29
+ const char* transa,
30
+ const char* transb,
31
+ FINTEGER* m,
32
+ FINTEGER* n,
33
+ FINTEGER* k,
34
+ const float* alpha,
35
+ const float* a,
36
+ FINTEGER* lda,
37
+ const float* b,
38
+ FINTEGER* ldb,
39
+ float* beta,
40
+ float* c,
41
+ FINTEGER* ldc);
35
42
  }
36
43
 
37
-
38
44
  namespace faiss {
39
45
 
40
-
41
46
  /* compute an estimator using look-up tables for typical values of M */
42
47
  template <typename CT, class C>
43
- void pq_estimators_from_tables_Mmul4 (int M, const CT * codes,
44
- size_t ncodes,
45
- const float * __restrict dis_table,
46
- size_t ksub,
47
- size_t k,
48
- float * heap_dis,
49
- int64_t * heap_ids)
50
- {
51
-
48
+ void pq_estimators_from_tables_Mmul4(
49
+ int M,
50
+ const CT* codes,
51
+ size_t ncodes,
52
+ const float* __restrict dis_table,
53
+ size_t ksub,
54
+ size_t k,
55
+ float* heap_dis,
56
+ int64_t* heap_ids) {
52
57
  for (size_t j = 0; j < ncodes; j++) {
53
58
  float dis = 0;
54
- const float *dt = dis_table;
59
+ const float* dt = dis_table;
55
60
 
56
- for (size_t m = 0; m < M; m+=4) {
61
+ for (size_t m = 0; m < M; m += 4) {
57
62
  float dism = 0;
58
- dism = dt[*codes++]; dt += ksub;
59
- dism += dt[*codes++]; dt += ksub;
60
- dism += dt[*codes++]; dt += ksub;
61
- dism += dt[*codes++]; dt += ksub;
63
+ dism = dt[*codes++];
64
+ dt += ksub;
65
+ dism += dt[*codes++];
66
+ dt += ksub;
67
+ dism += dt[*codes++];
68
+ dt += ksub;
69
+ dism += dt[*codes++];
70
+ dt += ksub;
62
71
  dis += dism;
63
72
  }
64
73
 
65
- if (C::cmp (heap_dis[0], dis)) {
66
- heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
74
+ if (C::cmp(heap_dis[0], dis)) {
75
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
67
76
  }
68
77
  }
69
78
  }
70
79
 
71
-
72
80
  template <typename CT, class C>
73
- void pq_estimators_from_tables_M4 (const CT * codes,
74
- size_t ncodes,
75
- const float * __restrict dis_table,
76
- size_t ksub,
77
- size_t k,
78
- float * heap_dis,
79
- int64_t * heap_ids)
80
- {
81
-
81
+ void pq_estimators_from_tables_M4(
82
+ const CT* codes,
83
+ size_t ncodes,
84
+ const float* __restrict dis_table,
85
+ size_t ksub,
86
+ size_t k,
87
+ float* heap_dis,
88
+ int64_t* heap_ids) {
82
89
  for (size_t j = 0; j < ncodes; j++) {
83
90
  float dis = 0;
84
- const float *dt = dis_table;
85
- dis = dt[*codes++]; dt += ksub;
86
- dis += dt[*codes++]; dt += ksub;
87
- dis += dt[*codes++]; dt += ksub;
91
+ const float* dt = dis_table;
92
+ dis = dt[*codes++];
93
+ dt += ksub;
94
+ dis += dt[*codes++];
95
+ dt += ksub;
96
+ dis += dt[*codes++];
97
+ dt += ksub;
88
98
  dis += dt[*codes++];
89
99
 
90
- if (C::cmp (heap_dis[0], dis)) {
91
- heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
100
+ if (C::cmp(heap_dis[0], dis)) {
101
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
92
102
  }
93
103
  }
94
104
  }
95
105
 
96
-
97
106
  template <typename CT, class C>
98
- static inline void pq_estimators_from_tables (const ProductQuantizer& pq,
99
- const CT * codes,
100
- size_t ncodes,
101
- const float * dis_table,
102
- size_t k,
103
- float * heap_dis,
104
- int64_t * heap_ids)
105
- {
106
-
107
- if (pq.M == 4) {
108
-
109
- pq_estimators_from_tables_M4<CT, C> (codes, ncodes,
110
- dis_table, pq.ksub, k,
111
- heap_dis, heap_ids);
107
+ static inline void pq_estimators_from_tables(
108
+ const ProductQuantizer& pq,
109
+ const CT* codes,
110
+ size_t ncodes,
111
+ const float* dis_table,
112
+ size_t k,
113
+ float* heap_dis,
114
+ int64_t* heap_ids) {
115
+ if (pq.M == 4) {
116
+ pq_estimators_from_tables_M4<CT, C>(
117
+ codes, ncodes, dis_table, pq.ksub, k, heap_dis, heap_ids);
112
118
  return;
113
119
  }
114
120
 
115
121
  if (pq.M % 4 == 0) {
116
- pq_estimators_from_tables_Mmul4<CT, C> (pq.M, codes, ncodes,
117
- dis_table, pq.ksub, k,
118
- heap_dis, heap_ids);
122
+ pq_estimators_from_tables_Mmul4<CT, C>(
123
+ pq.M, codes, ncodes, dis_table, pq.ksub, k, heap_dis, heap_ids);
119
124
  return;
120
125
  }
121
126
 
@@ -124,132 +129,124 @@ static inline void pq_estimators_from_tables (const ProductQuantizer& pq,
124
129
  const size_t ksub = pq.ksub;
125
130
  for (size_t j = 0; j < ncodes; j++) {
126
131
  float dis = 0;
127
- const float * __restrict dt = dis_table;
132
+ const float* __restrict dt = dis_table;
128
133
  for (int m = 0; m < M; m++) {
129
134
  dis += dt[*codes++];
130
135
  dt += ksub;
131
136
  }
132
- if (C::cmp (heap_dis[0], dis)) {
133
- heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
137
+ if (C::cmp(heap_dis[0], dis)) {
138
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
134
139
  }
135
140
  }
136
141
  }
137
142
 
138
143
  template <class C>
139
- static inline void pq_estimators_from_tables_generic(const ProductQuantizer& pq,
140
- size_t nbits,
141
- const uint8_t *codes,
142
- size_t ncodes,
143
- const float *dis_table,
144
- size_t k,
145
- float *heap_dis,
146
- int64_t *heap_ids)
147
- {
148
- const size_t M = pq.M;
149
- const size_t ksub = pq.ksub;
150
- for (size_t j = 0; j < ncodes; ++j) {
151
- PQDecoderGeneric decoder(
152
- codes + j * pq.code_size, nbits
153
- );
154
- float dis = 0;
155
- const float * __restrict dt = dis_table;
156
- for (size_t m = 0; m < M; m++) {
157
- uint64_t c = decoder.decode();
158
- dis += dt[c];
159
- dt += ksub;
160
- }
144
+ static inline void pq_estimators_from_tables_generic(
145
+ const ProductQuantizer& pq,
146
+ size_t nbits,
147
+ const uint8_t* codes,
148
+ size_t ncodes,
149
+ const float* dis_table,
150
+ size_t k,
151
+ float* heap_dis,
152
+ int64_t* heap_ids) {
153
+ const size_t M = pq.M;
154
+ const size_t ksub = pq.ksub;
155
+ for (size_t j = 0; j < ncodes; ++j) {
156
+ PQDecoderGeneric decoder(codes + j * pq.code_size, nbits);
157
+ float dis = 0;
158
+ const float* __restrict dt = dis_table;
159
+ for (size_t m = 0; m < M; m++) {
160
+ uint64_t c = decoder.decode();
161
+ dis += dt[c];
162
+ dt += ksub;
163
+ }
161
164
 
162
- if (C::cmp(heap_dis[0], dis)) {
163
- heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
165
+ if (C::cmp(heap_dis[0], dis)) {
166
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
167
+ }
164
168
  }
165
- }
166
169
  }
167
170
 
168
171
  /*********************************************
169
172
  * PQ implementation
170
173
  *********************************************/
171
174
 
172
-
173
-
174
- ProductQuantizer::ProductQuantizer (size_t d, size_t M, size_t nbits):
175
- d(d), M(M), nbits(nbits), assign_index(nullptr)
176
- {
177
- set_derived_values ();
175
+ ProductQuantizer::ProductQuantizer(size_t d, size_t M, size_t nbits)
176
+ : d(d), M(M), nbits(nbits), assign_index(nullptr) {
177
+ set_derived_values();
178
178
  }
179
179
 
180
- ProductQuantizer::ProductQuantizer ()
181
- : ProductQuantizer(0, 1, 0) {}
180
+ ProductQuantizer::ProductQuantizer() : ProductQuantizer(0, 1, 0) {}
182
181
 
183
- void ProductQuantizer::set_derived_values () {
182
+ void ProductQuantizer::set_derived_values() {
184
183
  // quite a few derived values
185
- FAISS_THROW_IF_NOT_MSG (d % M == 0, "The dimension of the vector (d) should be a multiple of the number of subquantizers (M)");
184
+ FAISS_THROW_IF_NOT_MSG(
185
+ d % M == 0,
186
+ "The dimension of the vector (d) should be a multiple of the number of subquantizers (M)");
186
187
  dsub = d / M;
187
188
  code_size = (nbits * M + 7) / 8;
188
189
  ksub = 1 << nbits;
189
- centroids.resize (d * ksub);
190
+ centroids.resize(d * ksub);
190
191
  verbose = false;
191
192
  train_type = Train_default;
192
193
  }
193
194
 
194
- void ProductQuantizer::set_params (const float * centroids_, int m)
195
- {
196
- memcpy (get_centroids(m, 0), centroids_,
197
- ksub * dsub * sizeof (centroids_[0]));
195
+ void ProductQuantizer::set_params(const float* centroids_, int m) {
196
+ memcpy(get_centroids(m, 0),
197
+ centroids_,
198
+ ksub * dsub * sizeof(centroids_[0]));
198
199
  }
199
200
 
200
-
201
- static void init_hypercube (int d, int nbits,
202
- int n, const float * x,
203
- float *centroids)
204
- {
205
-
206
- std::vector<float> mean (d);
201
+ static void init_hypercube(
202
+ int d,
203
+ int nbits,
204
+ int n,
205
+ const float* x,
206
+ float* centroids) {
207
+ std::vector<float> mean(d);
207
208
  for (int i = 0; i < n; i++)
208
209
  for (int j = 0; j < d; j++)
209
- mean [j] += x[i * d + j];
210
+ mean[j] += x[i * d + j];
210
211
 
211
212
  float maxm = 0;
212
213
  for (int j = 0; j < d; j++) {
213
- mean [j] /= n;
214
- if (fabs(mean[j]) > maxm) maxm = fabs(mean[j]);
214
+ mean[j] /= n;
215
+ if (fabs(mean[j]) > maxm)
216
+ maxm = fabs(mean[j]);
215
217
  }
216
218
 
217
219
  for (int i = 0; i < (1 << nbits); i++) {
218
- float * cent = centroids + i * d;
220
+ float* cent = centroids + i * d;
219
221
  for (int j = 0; j < nbits; j++)
220
- cent[j] = mean [j] + (((i >> j) & 1) ? 1 : -1) * maxm;
222
+ cent[j] = mean[j] + (((i >> j) & 1) ? 1 : -1) * maxm;
221
223
  for (int j = nbits; j < d; j++)
222
- cent[j] = mean [j];
224
+ cent[j] = mean[j];
223
225
  }
224
-
225
-
226
226
  }
227
227
 
228
- static void init_hypercube_pca (int d, int nbits,
229
- int n, const float * x,
230
- float *centroids)
231
- {
232
- PCAMatrix pca (d, nbits);
233
- pca.train (n, x);
234
-
228
+ static void init_hypercube_pca(
229
+ int d,
230
+ int nbits,
231
+ int n,
232
+ const float* x,
233
+ float* centroids) {
234
+ PCAMatrix pca(d, nbits);
235
+ pca.train(n, x);
235
236
 
236
237
  for (int i = 0; i < (1 << nbits); i++) {
237
- float * cent = centroids + i * d;
238
+ float* cent = centroids + i * d;
238
239
  for (int j = 0; j < d; j++) {
239
240
  cent[j] = pca.mean[j];
240
241
  float f = 1.0;
241
242
  for (int k = 0; k < nbits; k++)
242
- cent[j] += f *
243
- sqrt (pca.eigenvalues [k]) *
244
- (((i >> k) & 1) ? 1 : -1) *
245
- pca.PCAMat [j + k * d];
243
+ cent[j] += f * sqrt(pca.eigenvalues[k]) *
244
+ (((i >> k) & 1) ? 1 : -1) * pca.PCAMat[j + k * d];
246
245
  }
247
246
  }
248
-
249
247
  }
250
248
 
251
- void ProductQuantizer::train (int n, const float * x)
252
- {
249
+ void ProductQuantizer::train(int n, const float* x) {
253
250
  if (train_type != Train_shared) {
254
251
  train_type_t final_train_type;
255
252
  final_train_type = train_type;
@@ -257,234 +254,229 @@ void ProductQuantizer::train (int n, const float * x)
257
254
  train_type == Train_hypercube_pca) {
258
255
  if (dsub < nbits) {
259
256
  final_train_type = Train_default;
260
- printf ("cannot train hypercube: nbits=%zd > log2(d=%zd)\n",
261
- nbits, dsub);
257
+ printf("cannot train hypercube: nbits=%zd > log2(d=%zd)\n",
258
+ nbits,
259
+ dsub);
262
260
  }
263
261
  }
264
262
 
265
- float * xslice = new float[n * dsub];
266
- ScopeDeleter<float> del (xslice);
263
+ float* xslice = new float[n * dsub];
264
+ ScopeDeleter<float> del(xslice);
267
265
  for (int m = 0; m < M; m++) {
268
266
  for (int j = 0; j < n; j++)
269
- memcpy (xslice + j * dsub,
270
- x + j * d + m * dsub,
271
- dsub * sizeof(float));
267
+ memcpy(xslice + j * dsub,
268
+ x + j * d + m * dsub,
269
+ dsub * sizeof(float));
272
270
 
273
- Clustering clus (dsub, ksub, cp);
271
+ Clustering clus(dsub, ksub, cp);
274
272
 
275
273
  // we have some initialization for the centroids
276
274
  if (final_train_type != Train_default) {
277
- clus.centroids.resize (dsub * ksub);
275
+ clus.centroids.resize(dsub * ksub);
278
276
  }
279
277
 
280
278
  switch (final_train_type) {
281
- case Train_hypercube:
282
- init_hypercube (dsub, nbits, n, xslice,
283
- clus.centroids.data ());
284
- break;
285
- case Train_hypercube_pca:
286
- init_hypercube_pca (dsub, nbits, n, xslice,
287
- clus.centroids.data ());
288
- break;
289
- case Train_hot_start:
290
- memcpy (clus.centroids.data(),
291
- get_centroids (m, 0),
292
- dsub * ksub * sizeof (float));
293
- break;
294
- default: ;
279
+ case Train_hypercube:
280
+ init_hypercube(
281
+ dsub, nbits, n, xslice, clus.centroids.data());
282
+ break;
283
+ case Train_hypercube_pca:
284
+ init_hypercube_pca(
285
+ dsub, nbits, n, xslice, clus.centroids.data());
286
+ break;
287
+ case Train_hot_start:
288
+ memcpy(clus.centroids.data(),
289
+ get_centroids(m, 0),
290
+ dsub * ksub * sizeof(float));
291
+ break;
292
+ default:;
295
293
  }
296
294
 
297
- if(verbose) {
295
+ if (verbose) {
298
296
  clus.verbose = true;
299
- printf ("Training PQ slice %d/%zd\n", m, M);
297
+ printf("Training PQ slice %d/%zd\n", m, M);
300
298
  }
301
- IndexFlatL2 index (dsub);
302
- clus.train (n, xslice, assign_index ? *assign_index : index);
303
- set_params (clus.centroids.data(), m);
299
+ IndexFlatL2 index(dsub);
300
+ clus.train(n, xslice, assign_index ? *assign_index : index);
301
+ set_params(clus.centroids.data(), m);
304
302
  }
305
303
 
306
-
307
304
  } else {
305
+ Clustering clus(dsub, ksub, cp);
308
306
 
309
- Clustering clus (dsub, ksub, cp);
310
-
311
- if(verbose) {
307
+ if (verbose) {
312
308
  clus.verbose = true;
313
- printf ("Training all PQ slices at once\n");
309
+ printf("Training all PQ slices at once\n");
314
310
  }
315
311
 
316
- IndexFlatL2 index (dsub);
312
+ IndexFlatL2 index(dsub);
317
313
 
318
- clus.train (n * M, x, assign_index ? *assign_index : index);
314
+ clus.train(n * M, x, assign_index ? *assign_index : index);
319
315
  for (int m = 0; m < M; m++) {
320
- set_params (clus.centroids.data(), m);
316
+ set_params(clus.centroids.data(), m);
321
317
  }
322
-
323
318
  }
324
319
  }
325
320
 
326
- template<class PQEncoder>
327
- void compute_code(const ProductQuantizer& pq, const float *x, uint8_t *code) {
328
- std::vector<float> distances(pq.ksub);
329
- PQEncoder encoder(code, pq.nbits);
330
- for (size_t m = 0; m < pq.M; m++) {
331
- float mindis = 1e20;
332
- uint64_t idxm = 0;
333
- const float * xsub = x + m * pq.dsub;
334
-
335
- fvec_L2sqr_ny(distances.data(), xsub, pq.get_centroids(m, 0), pq.dsub, pq.ksub);
336
-
337
- /* Find best centroid */
338
- for (size_t i = 0; i < pq.ksub; i++) {
339
- float dis = distances[i];
340
- if (dis < mindis) {
341
- mindis = dis;
342
- idxm = i;
343
- }
344
- }
321
+ template <class PQEncoder>
322
+ void compute_code(const ProductQuantizer& pq, const float* x, uint8_t* code) {
323
+ std::vector<float> distances(pq.ksub);
324
+ PQEncoder encoder(code, pq.nbits);
325
+ for (size_t m = 0; m < pq.M; m++) {
326
+ float mindis = 1e20;
327
+ uint64_t idxm = 0;
328
+ const float* xsub = x + m * pq.dsub;
329
+
330
+ fvec_L2sqr_ny(
331
+ distances.data(),
332
+ xsub,
333
+ pq.get_centroids(m, 0),
334
+ pq.dsub,
335
+ pq.ksub);
336
+
337
+ /* Find best centroid */
338
+ for (size_t i = 0; i < pq.ksub; i++) {
339
+ float dis = distances[i];
340
+ if (dis < mindis) {
341
+ mindis = dis;
342
+ idxm = i;
343
+ }
344
+ }
345
345
 
346
- encoder.encode(idxm);
347
- }
346
+ encoder.encode(idxm);
347
+ }
348
348
  }
349
349
 
350
- void ProductQuantizer::compute_code(const float * x, uint8_t * code) const {
351
- switch (nbits) {
352
- case 8:
353
- faiss::compute_code<PQEncoder8>(*this, x, code);
354
- break;
350
+ void ProductQuantizer::compute_code(const float* x, uint8_t* code) const {
351
+ switch (nbits) {
352
+ case 8:
353
+ faiss::compute_code<PQEncoder8>(*this, x, code);
354
+ break;
355
355
 
356
- case 16:
357
- faiss::compute_code<PQEncoder16>(*this, x, code);
358
- break;
356
+ case 16:
357
+ faiss::compute_code<PQEncoder16>(*this, x, code);
358
+ break;
359
359
 
360
- default:
361
- faiss::compute_code<PQEncoderGeneric>(*this, x, code);
362
- break;
363
- }
360
+ default:
361
+ faiss::compute_code<PQEncoderGeneric>(*this, x, code);
362
+ break;
363
+ }
364
364
  }
365
365
 
366
- template<class PQDecoder>
367
- void decode(const ProductQuantizer& pq, const uint8_t *code, float *x)
368
- {
369
- PQDecoder decoder(code, pq.nbits);
370
- for (size_t m = 0; m < pq.M; m++) {
371
- uint64_t c = decoder.decode();
372
- memcpy(x + m * pq.dsub, pq.get_centroids(m, c), sizeof(float) * pq.dsub);
373
- }
366
+ template <class PQDecoder>
367
+ void decode(const ProductQuantizer& pq, const uint8_t* code, float* x) {
368
+ PQDecoder decoder(code, pq.nbits);
369
+ for (size_t m = 0; m < pq.M; m++) {
370
+ uint64_t c = decoder.decode();
371
+ memcpy(x + m * pq.dsub,
372
+ pq.get_centroids(m, c),
373
+ sizeof(float) * pq.dsub);
374
+ }
374
375
  }
375
376
 
376
- void ProductQuantizer::decode (const uint8_t *code, float *x) const
377
- {
378
- switch (nbits) {
379
- case 8:
380
- faiss::decode<PQDecoder8>(*this, code, x);
381
- break;
382
-
383
- case 16:
384
- faiss::decode<PQDecoder16>(*this, code, x);
385
- break;
386
-
387
- default:
388
- faiss::decode<PQDecoderGeneric>(*this, code, x);
389
- break;
390
- }
391
- }
377
+ void ProductQuantizer::decode(const uint8_t* code, float* x) const {
378
+ switch (nbits) {
379
+ case 8:
380
+ faiss::decode<PQDecoder8>(*this, code, x);
381
+ break;
382
+
383
+ case 16:
384
+ faiss::decode<PQDecoder16>(*this, code, x);
385
+ break;
392
386
 
387
+ default:
388
+ faiss::decode<PQDecoderGeneric>(*this, code, x);
389
+ break;
390
+ }
391
+ }
393
392
 
394
- void ProductQuantizer::decode (const uint8_t *code, float *x, size_t n) const
395
- {
393
+ void ProductQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
396
394
  for (size_t i = 0; i < n; i++) {
397
- this->decode (code + code_size * i, x + d * i);
395
+ this->decode(code + code_size * i, x + d * i);
398
396
  }
399
397
  }
400
398
 
399
+ void ProductQuantizer::compute_code_from_distance_table(
400
+ const float* tab,
401
+ uint8_t* code) const {
402
+ PQEncoderGeneric encoder(code, nbits);
403
+ for (size_t m = 0; m < M; m++) {
404
+ float mindis = 1e20;
405
+ uint64_t idxm = 0;
406
+
407
+ /* Find best centroid */
408
+ for (size_t j = 0; j < ksub; j++) {
409
+ float dis = *tab++;
410
+ if (dis < mindis) {
411
+ mindis = dis;
412
+ idxm = j;
413
+ }
414
+ }
401
415
 
402
- void ProductQuantizer::compute_code_from_distance_table (const float *tab,
403
- uint8_t *code) const
404
- {
405
- PQEncoderGeneric encoder(code, nbits);
406
- for (size_t m = 0; m < M; m++) {
407
- float mindis = 1e20;
408
- uint64_t idxm = 0;
409
-
410
- /* Find best centroid */
411
- for (size_t j = 0; j < ksub; j++) {
412
- float dis = *tab++;
413
- if (dis < mindis) {
414
- mindis = dis;
415
- idxm = j;
416
- }
416
+ encoder.encode(idxm);
417
417
  }
418
-
419
- encoder.encode(idxm);
420
- }
421
418
  }
422
419
 
423
- void ProductQuantizer::compute_codes_with_assign_index (
424
- const float * x,
425
- uint8_t * codes,
426
- size_t n)
427
- {
428
- FAISS_THROW_IF_NOT (assign_index && assign_index->d == dsub);
420
+ void ProductQuantizer::compute_codes_with_assign_index(
421
+ const float* x,
422
+ uint8_t* codes,
423
+ size_t n) {
424
+ FAISS_THROW_IF_NOT(assign_index && assign_index->d == dsub);
429
425
 
430
426
  for (size_t m = 0; m < M; m++) {
431
- assign_index->reset ();
432
- assign_index->add (ksub, get_centroids (m, 0));
427
+ assign_index->reset();
428
+ assign_index->add(ksub, get_centroids(m, 0));
433
429
  size_t bs = 65536;
434
- float * xslice = new float[bs * dsub];
435
- ScopeDeleter<float> del (xslice);
436
- idx_t *assign = new idx_t[bs];
437
- ScopeDeleter<idx_t> del2 (assign);
430
+ float* xslice = new float[bs * dsub];
431
+ ScopeDeleter<float> del(xslice);
432
+ idx_t* assign = new idx_t[bs];
433
+ ScopeDeleter<idx_t> del2(assign);
438
434
 
439
435
  for (size_t i0 = 0; i0 < n; i0 += bs) {
440
436
  size_t i1 = std::min(i0 + bs, n);
441
437
 
442
438
  for (size_t i = i0; i < i1; i++) {
443
- memcpy (xslice + (i - i0) * dsub,
444
- x + i * d + m * dsub,
445
- dsub * sizeof(float));
439
+ memcpy(xslice + (i - i0) * dsub,
440
+ x + i * d + m * dsub,
441
+ dsub * sizeof(float));
446
442
  }
447
443
 
448
- assign_index->assign (i1 - i0, xslice, assign);
444
+ assign_index->assign(i1 - i0, xslice, assign);
449
445
 
450
446
  if (nbits == 8) {
451
- uint8_t *c = codes + code_size * i0 + m;
452
- for (size_t i = i0; i < i1; i++) {
453
- *c = assign[i - i0];
454
- c += M;
455
- }
447
+ uint8_t* c = codes + code_size * i0 + m;
448
+ for (size_t i = i0; i < i1; i++) {
449
+ *c = assign[i - i0];
450
+ c += M;
451
+ }
456
452
  } else if (nbits == 16) {
457
- uint16_t *c = (uint16_t*)(codes + code_size * i0 + m * 2);
458
- for (size_t i = i0; i < i1; i++) {
459
- *c = assign[i - i0];
460
- c += M;
461
- }
453
+ uint16_t* c = (uint16_t*)(codes + code_size * i0 + m * 2);
454
+ for (size_t i = i0; i < i1; i++) {
455
+ *c = assign[i - i0];
456
+ c += M;
457
+ }
462
458
  } else {
463
- for (size_t i = i0; i < i1; ++i) {
464
- uint8_t *c = codes + code_size * i + ((m * nbits) / 8);
465
- uint8_t offset = (m * nbits) % 8;
466
- uint64_t ass = assign[i - i0];
467
-
468
- PQEncoderGeneric encoder(c, nbits, offset);
469
- encoder.encode(ass);
470
- }
459
+ for (size_t i = i0; i < i1; ++i) {
460
+ uint8_t* c = codes + code_size * i + ((m * nbits) / 8);
461
+ uint8_t offset = (m * nbits) % 8;
462
+ uint64_t ass = assign[i - i0];
463
+
464
+ PQEncoderGeneric encoder(c, nbits, offset);
465
+ encoder.encode(ass);
466
+ }
471
467
  }
472
-
473
468
  }
474
469
  }
475
-
476
470
  }
477
471
 
478
- void ProductQuantizer::compute_codes (const float * x,
479
- uint8_t * codes,
480
- size_t n) const
481
- {
482
- // process by blocks to avoid using too much RAM
472
+ void ProductQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
473
+ const {
474
+ // process by blocks to avoid using too much RAM
483
475
  size_t bs = 256 * 1024;
484
476
  if (n > bs) {
485
477
  for (size_t i0 = 0; i0 < n; i0 += bs) {
486
478
  size_t i1 = std::min(i0 + bs, n);
487
- compute_codes (x + d * i0, codes + code_size * i0, i1 - i0);
479
+ compute_codes(x + d * i0, codes + code_size * i0, i1 - i0);
488
480
  }
489
481
  return;
490
482
  }
@@ -493,282 +485,300 @@ void ProductQuantizer::compute_codes (const float * x,
493
485
 
494
486
  #pragma omp parallel for
495
487
  for (int64_t i = 0; i < n; i++)
496
- compute_code (x + i * d, codes + i * code_size);
488
+ compute_code(x + i * d, codes + i * code_size);
497
489
 
498
490
  } else { // worthwile to use BLAS
499
- float *dis_tables = new float [n * ksub * M];
500
- ScopeDeleter<float> del (dis_tables);
501
- compute_distance_tables (n, x, dis_tables);
491
+ float* dis_tables = new float[n * ksub * M];
492
+ ScopeDeleter<float> del(dis_tables);
493
+ compute_distance_tables(n, x, dis_tables);
502
494
 
503
495
  #pragma omp parallel for
504
496
  for (int64_t i = 0; i < n; i++) {
505
- uint8_t * code = codes + i * code_size;
506
- const float * tab = dis_tables + i * ksub * M;
507
- compute_code_from_distance_table (tab, code);
497
+ uint8_t* code = codes + i * code_size;
498
+ const float* tab = dis_tables + i * ksub * M;
499
+ compute_code_from_distance_table(tab, code);
508
500
  }
509
501
  }
510
502
  }
511
503
 
512
-
513
- void ProductQuantizer::compute_distance_table (const float * x,
514
- float * dis_table) const
515
- {
504
+ void ProductQuantizer::compute_distance_table(const float* x, float* dis_table)
505
+ const {
516
506
  size_t m;
517
507
 
518
508
  for (m = 0; m < M; m++) {
519
- fvec_L2sqr_ny (dis_table + m * ksub,
520
- x + m * dsub,
521
- get_centroids(m, 0),
522
- dsub,
523
- ksub);
509
+ fvec_L2sqr_ny(
510
+ dis_table + m * ksub,
511
+ x + m * dsub,
512
+ get_centroids(m, 0),
513
+ dsub,
514
+ ksub);
524
515
  }
525
516
  }
526
517
 
527
- void ProductQuantizer::compute_inner_prod_table (const float * x,
528
- float * dis_table) const
529
- {
518
+ void ProductQuantizer::compute_inner_prod_table(
519
+ const float* x,
520
+ float* dis_table) const {
530
521
  size_t m;
531
522
 
532
523
  for (m = 0; m < M; m++) {
533
- fvec_inner_products_ny (dis_table + m * ksub,
534
- x + m * dsub,
535
- get_centroids(m, 0),
536
- dsub,
537
- ksub);
524
+ fvec_inner_products_ny(
525
+ dis_table + m * ksub,
526
+ x + m * dsub,
527
+ get_centroids(m, 0),
528
+ dsub,
529
+ ksub);
538
530
  }
539
531
  }
540
532
 
541
-
542
- void ProductQuantizer::compute_distance_tables (
543
- size_t nx,
544
- const float * x,
545
- float * dis_tables) const
546
- {
547
-
548
- #ifdef __AVX2__
533
+ void ProductQuantizer::compute_distance_tables(
534
+ size_t nx,
535
+ const float* x,
536
+ float* dis_tables) const {
537
+ #if defined(__AVX2__) || defined(__aarch64__)
549
538
  if (dsub == 2 && nbits < 8) { // interesting for a narrow range of settings
550
539
  compute_PQ_dis_tables_dsub2(
551
- d, ksub, centroids.data(),
552
- nx, x, false, dis_tables
553
- );
540
+ d, ksub, centroids.data(), nx, x, false, dis_tables);
554
541
  } else
555
542
  #endif
556
- if (dsub < 16) {
543
+ if (dsub < 16) {
557
544
 
558
545
  #pragma omp parallel for
559
546
  for (int64_t i = 0; i < nx; i++) {
560
- compute_distance_table (x + i * d, dis_tables + i * ksub * M);
547
+ compute_distance_table(x + i * d, dis_tables + i * ksub * M);
561
548
  }
562
549
 
563
550
  } else { // use BLAS
564
551
 
565
552
  for (int m = 0; m < M; m++) {
566
- pairwise_L2sqr (dsub,
567
- nx, x + dsub * m,
568
- ksub, centroids.data() + m * dsub * ksub,
569
- dis_tables + ksub * m,
570
- d, dsub, ksub * M);
553
+ pairwise_L2sqr(
554
+ dsub,
555
+ nx,
556
+ x + dsub * m,
557
+ ksub,
558
+ centroids.data() + m * dsub * ksub,
559
+ dis_tables + ksub * m,
560
+ d,
561
+ dsub,
562
+ ksub * M);
571
563
  }
572
564
  }
573
565
  }
574
566
 
575
- void ProductQuantizer::compute_inner_prod_tables (
576
- size_t nx,
577
- const float * x,
578
- float * dis_tables) const
579
- {
580
- #ifdef __AVX2__
567
+ void ProductQuantizer::compute_inner_prod_tables(
568
+ size_t nx,
569
+ const float* x,
570
+ float* dis_tables) const {
571
+ #if defined(__AVX2__) || defined(__aarch64__)
581
572
  if (dsub == 2 && nbits < 8) {
582
573
  compute_PQ_dis_tables_dsub2(
583
- d, ksub, centroids.data(),
584
- nx, x, true, dis_tables
585
- );
574
+ d, ksub, centroids.data(), nx, x, true, dis_tables);
586
575
  } else
587
576
  #endif
588
- if (dsub < 16) {
577
+ if (dsub < 16) {
589
578
 
590
579
  #pragma omp parallel for
591
580
  for (int64_t i = 0; i < nx; i++) {
592
- compute_inner_prod_table (x + i * d, dis_tables + i * ksub * M);
581
+ compute_inner_prod_table(x + i * d, dis_tables + i * ksub * M);
593
582
  }
594
583
 
595
584
  } else { // use BLAS
596
585
 
597
586
  // compute distance tables
598
587
  for (int m = 0; m < M; m++) {
599
- FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub,
600
- dsubi = dsub, di = d;
588
+ FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub, dsubi = dsub,
589
+ di = d;
601
590
  float one = 1.0, zero = 0;
602
591
 
603
- sgemm_ ("Transposed", "Not transposed",
604
- &ksubi, &nxi, &dsubi,
605
- &one, &centroids [m * dsub * ksub], &dsubi,
606
- x + dsub * m, &di,
607
- &zero, dis_tables + ksub * m, &ldc);
592
+ sgemm_("Transposed",
593
+ "Not transposed",
594
+ &ksubi,
595
+ &nxi,
596
+ &dsubi,
597
+ &one,
598
+ &centroids[m * dsub * ksub],
599
+ &dsubi,
600
+ x + dsub * m,
601
+ &di,
602
+ &zero,
603
+ dis_tables + ksub * m,
604
+ &ldc);
608
605
  }
609
-
610
606
  }
611
607
  }
612
608
 
613
609
  template <class C>
614
- static void pq_knn_search_with_tables (
615
- const ProductQuantizer& pq,
616
- size_t nbits,
617
- const float *dis_tables,
618
- const uint8_t * codes,
619
- const size_t ncodes,
620
- HeapArray<C> * res,
621
- bool init_finalize_heap)
622
- {
610
+ static void pq_knn_search_with_tables(
611
+ const ProductQuantizer& pq,
612
+ size_t nbits,
613
+ const float* dis_tables,
614
+ const uint8_t* codes,
615
+ const size_t ncodes,
616
+ HeapArray<C>* res,
617
+ bool init_finalize_heap) {
623
618
  size_t k = res->k, nx = res->nh;
624
619
  size_t ksub = pq.ksub, M = pq.M;
625
620
 
626
-
627
621
  #pragma omp parallel for
628
622
  for (int64_t i = 0; i < nx; i++) {
629
623
  /* query preparation for asymmetric search: compute look-up tables */
630
624
  const float* dis_table = dis_tables + i * ksub * M;
631
625
 
632
626
  /* Compute distances and keep smallest values */
633
- int64_t * __restrict heap_ids = res->ids + i * k;
634
- float * __restrict heap_dis = res->val + i * k;
627
+ int64_t* __restrict heap_ids = res->ids + i * k;
628
+ float* __restrict heap_dis = res->val + i * k;
635
629
 
636
630
  if (init_finalize_heap) {
637
- heap_heapify<C> (k, heap_dis, heap_ids);
631
+ heap_heapify<C>(k, heap_dis, heap_ids);
638
632
  }
639
633
 
640
634
  switch (nbits) {
641
- case 8:
642
- pq_estimators_from_tables<uint8_t, C> (pq,
643
- codes, ncodes,
644
- dis_table,
645
- k, heap_dis, heap_ids);
646
- break;
647
-
648
- case 16:
649
- pq_estimators_from_tables<uint16_t, C> (pq,
650
- (uint16_t*)codes, ncodes,
651
- dis_table,
652
- k, heap_dis, heap_ids);
653
- break;
654
-
655
- default:
656
- pq_estimators_from_tables_generic<C> (pq,
657
- nbits,
658
- codes, ncodes,
659
- dis_table,
660
- k, heap_dis, heap_ids);
661
- break;
635
+ case 8:
636
+ pq_estimators_from_tables<uint8_t, C>(
637
+ pq, codes, ncodes, dis_table, k, heap_dis, heap_ids);
638
+ break;
639
+
640
+ case 16:
641
+ pq_estimators_from_tables<uint16_t, C>(
642
+ pq,
643
+ (uint16_t*)codes,
644
+ ncodes,
645
+ dis_table,
646
+ k,
647
+ heap_dis,
648
+ heap_ids);
649
+ break;
650
+
651
+ default:
652
+ pq_estimators_from_tables_generic<C>(
653
+ pq,
654
+ nbits,
655
+ codes,
656
+ ncodes,
657
+ dis_table,
658
+ k,
659
+ heap_dis,
660
+ heap_ids);
661
+ break;
662
662
  }
663
663
 
664
664
  if (init_finalize_heap) {
665
- heap_reorder<C> (k, heap_dis, heap_ids);
665
+ heap_reorder<C>(k, heap_dis, heap_ids);
666
666
  }
667
667
  }
668
668
  }
669
669
 
670
- void ProductQuantizer::search (const float * __restrict x,
671
- size_t nx,
672
- const uint8_t * codes,
673
- const size_t ncodes,
674
- float_maxheap_array_t * res,
675
- bool init_finalize_heap) const
676
- {
677
- FAISS_THROW_IF_NOT (nx == res->nh);
678
- std::unique_ptr<float[]> dis_tables(new float [nx * ksub * M]);
679
- compute_distance_tables (nx, x, dis_tables.get());
680
-
681
- pq_knn_search_with_tables<CMax<float, int64_t>> (
682
- *this, nbits, dis_tables.get(), codes, ncodes, res, init_finalize_heap);
670
+ void ProductQuantizer::search(
671
+ const float* __restrict x,
672
+ size_t nx,
673
+ const uint8_t* codes,
674
+ const size_t ncodes,
675
+ float_maxheap_array_t* res,
676
+ bool init_finalize_heap) const {
677
+ FAISS_THROW_IF_NOT(nx == res->nh);
678
+ std::unique_ptr<float[]> dis_tables(new float[nx * ksub * M]);
679
+ compute_distance_tables(nx, x, dis_tables.get());
680
+
681
+ pq_knn_search_with_tables<CMax<float, int64_t>>(
682
+ *this,
683
+ nbits,
684
+ dis_tables.get(),
685
+ codes,
686
+ ncodes,
687
+ res,
688
+ init_finalize_heap);
683
689
  }
684
690
 
685
- void ProductQuantizer::search_ip (const float * __restrict x,
686
- size_t nx,
687
- const uint8_t * codes,
688
- const size_t ncodes,
689
- float_minheap_array_t * res,
690
- bool init_finalize_heap) const
691
- {
692
- FAISS_THROW_IF_NOT (nx == res->nh);
693
- std::unique_ptr<float[]> dis_tables(new float [nx * ksub * M]);
694
- compute_inner_prod_tables (nx, x, dis_tables.get());
695
-
696
- pq_knn_search_with_tables<CMin<float, int64_t> > (
697
- *this, nbits, dis_tables.get(), codes, ncodes, res, init_finalize_heap);
691
+ void ProductQuantizer::search_ip(
692
+ const float* __restrict x,
693
+ size_t nx,
694
+ const uint8_t* codes,
695
+ const size_t ncodes,
696
+ float_minheap_array_t* res,
697
+ bool init_finalize_heap) const {
698
+ FAISS_THROW_IF_NOT(nx == res->nh);
699
+ std::unique_ptr<float[]> dis_tables(new float[nx * ksub * M]);
700
+ compute_inner_prod_tables(nx, x, dis_tables.get());
701
+
702
+ pq_knn_search_with_tables<CMin<float, int64_t>>(
703
+ *this,
704
+ nbits,
705
+ dis_tables.get(),
706
+ codes,
707
+ ncodes,
708
+ res,
709
+ init_finalize_heap);
698
710
  }
699
711
 
700
-
701
-
702
- static float sqr (float x) {
712
+ static float sqr(float x) {
703
713
  return x * x;
704
714
  }
705
715
 
706
- void ProductQuantizer::compute_sdc_table ()
707
- {
708
- sdc_table.resize (M * ksub * ksub);
709
-
710
- for (int m = 0; m < M; m++) {
716
+ void ProductQuantizer::compute_sdc_table() {
717
+ sdc_table.resize(M * ksub * ksub);
711
718
 
712
- const float *cents = centroids.data() + m * ksub * dsub;
713
- float * dis_tab = sdc_table.data() + m * ksub * ksub;
714
-
715
- // TODO optimize with BLAS
716
- for (int i = 0; i < ksub; i++) {
717
- const float *centi = cents + i * dsub;
718
- for (int j = 0; j < ksub; j++) {
719
- float accu = 0;
720
- const float *centj = cents + j * dsub;
721
- for (int k = 0; k < dsub; k++)
722
- accu += sqr (centi[k] - centj[k]);
723
- dis_tab [i + j * ksub] = accu;
724
- }
719
+ if (dsub < 4) {
720
+ #pragma omp parallel for
721
+ for (int mk = 0; mk < M * ksub; mk++) {
722
+ // allow omp to schedule in a more fine-grained way
723
+ // `collapse` is not supported in OpenMP 2.x
724
+ int m = mk / ksub;
725
+ int k = mk % ksub;
726
+ const float* cents = centroids.data() + m * ksub * dsub;
727
+ const float* centi = cents + k * dsub;
728
+ float* dis_tab = sdc_table.data() + m * ksub * ksub;
729
+ fvec_L2sqr_ny(dis_tab + k * ksub, centi, cents, dsub, ksub);
730
+ }
731
+ } else {
732
+ // NOTE: it would disable the omp loop in pairwise_L2sqr
733
+ // but still accelerate especially when M >= 4
734
+ #pragma omp parallel for
735
+ for (int m = 0; m < M; m++) {
736
+ const float* cents = centroids.data() + m * ksub * dsub;
737
+ float* dis_tab = sdc_table.data() + m * ksub * ksub;
738
+ pairwise_L2sqr(
739
+ dsub, ksub, cents, ksub, cents, dis_tab, dsub, dsub, ksub);
725
740
  }
726
741
  }
727
742
  }
728
743
 
729
- void ProductQuantizer::search_sdc (const uint8_t * qcodes,
730
- size_t nq,
731
- const uint8_t * bcodes,
732
- const size_t nb,
733
- float_maxheap_array_t * res,
734
- bool init_finalize_heap) const
735
- {
736
- FAISS_THROW_IF_NOT (sdc_table.size() == M * ksub * ksub);
737
- FAISS_THROW_IF_NOT (nbits == 8);
744
+ void ProductQuantizer::search_sdc(
745
+ const uint8_t* qcodes,
746
+ size_t nq,
747
+ const uint8_t* bcodes,
748
+ const size_t nb,
749
+ float_maxheap_array_t* res,
750
+ bool init_finalize_heap) const {
751
+ FAISS_THROW_IF_NOT(sdc_table.size() == M * ksub * ksub);
752
+ FAISS_THROW_IF_NOT(nbits == 8);
738
753
  size_t k = res->k;
739
754
 
740
-
741
755
  #pragma omp parallel for
742
756
  for (int64_t i = 0; i < nq; i++) {
743
-
744
757
  /* Compute distances and keep smallest values */
745
- idx_t * heap_ids = res->ids + i * k;
746
- float * heap_dis = res->val + i * k;
747
- const uint8_t * qcode = qcodes + i * code_size;
758
+ idx_t* heap_ids = res->ids + i * k;
759
+ float* heap_dis = res->val + i * k;
760
+ const uint8_t* qcode = qcodes + i * code_size;
748
761
 
749
762
  if (init_finalize_heap)
750
- maxheap_heapify (k, heap_dis, heap_ids);
763
+ maxheap_heapify(k, heap_dis, heap_ids);
751
764
 
752
- const uint8_t * bcode = bcodes;
765
+ const uint8_t* bcode = bcodes;
753
766
  for (size_t j = 0; j < nb; j++) {
754
767
  float dis = 0;
755
- const float * tab = sdc_table.data();
768
+ const float* tab = sdc_table.data();
756
769
  for (int m = 0; m < M; m++) {
757
770
  dis += tab[bcode[m] + qcode[m] * ksub];
758
771
  tab += ksub * ksub;
759
772
  }
760
773
  if (dis < heap_dis[0]) {
761
- maxheap_replace_top (k, heap_dis, heap_ids, dis, j);
774
+ maxheap_replace_top(k, heap_dis, heap_ids, dis, j);
762
775
  }
763
776
  bcode += code_size;
764
777
  }
765
778
 
766
779
  if (init_finalize_heap)
767
- maxheap_reorder (k, heap_dis, heap_ids);
780
+ maxheap_reorder(k, heap_dis, heap_ids);
768
781
  }
769
-
770
782
  }
771
783
 
772
-
773
-
774
- } // namespace faiss
784
+ } // namespace faiss