faiss 0.2.4 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -278,13 +278,15 @@ double kmeans1d(const float* x, size_t n, size_t nclusters, float* centroids) {
278
278
  ****************************************************/
279
279
 
280
280
  // for imbalance factor
281
- double tot = 0.0, uf = 0.0;
281
+ double tot = 0.0;
282
+ double uf = 0.0;
282
283
 
283
284
  idx_t end = n;
284
285
  for (idx_t k = nclusters - 1; k >= 0; k--) {
285
- idx_t start = T.at(k, end - 1);
286
- float sum = std::accumulate(&arr[start], &arr[end], 0.0f);
287
- idx_t size = end - start;
286
+ const idx_t start = T.at(k, end - 1);
287
+ const float sum =
288
+ std::accumulate(arr.data() + start, arr.data() + end, 0.0f);
289
+ const idx_t size = end - start;
288
290
  FAISS_THROW_IF_NOT_FMT(
289
291
  size > 0, "Cluster %d: size %d", int(k), int(size));
290
292
  centroids[k] = sum / size;
@@ -122,30 +122,70 @@ void pq4_pack_codes_range(
122
122
  }
123
123
  }
124
124
 
125
+ namespace {
126
+
127
+ // get the specific address of the vector inside a block
128
+ // shift is used for determine the if the saved in bits 0..3 (false) or
129
+ // bits 4..7 (true)
130
+ uint8_t get_vector_specific_address(
131
+ size_t bbs,
132
+ size_t vector_id,
133
+ size_t sq,
134
+ bool& shift) {
135
+ // get the vector_id inside the block
136
+ vector_id = vector_id % bbs;
137
+ shift = vector_id > 15;
138
+ vector_id = vector_id & 15;
139
+
140
+ // get the address of the vector in sq
141
+ size_t address;
142
+ if (vector_id < 8) {
143
+ address = vector_id << 1;
144
+ } else {
145
+ address = ((vector_id - 8) << 1) + 1;
146
+ }
147
+ if (sq & 1) {
148
+ address += 16;
149
+ }
150
+ return (sq >> 1) * bbs + address;
151
+ }
152
+
153
+ } // anonymous namespace
154
+
125
155
  uint8_t pq4_get_packed_element(
126
156
  const uint8_t* data,
127
157
  size_t bbs,
128
158
  size_t nsq,
129
- size_t i,
159
+ size_t vector_id,
130
160
  size_t sq) {
131
161
  // move to correct bbs-sized block
132
- data += (i / bbs * (nsq / 2) + sq / 2) * bbs;
133
- sq = sq & 1;
134
- i = i % bbs;
135
-
136
- // another step
137
- data += (i / 32) * 32;
138
- i = i % 32;
139
-
140
- if (sq == 1) {
141
- data += 16;
162
+ // number of blocks * block size
163
+ data += (vector_id / bbs) * (((nsq + 1) / 2) * bbs);
164
+ bool shift;
165
+ size_t address = get_vector_specific_address(bbs, vector_id, sq, shift);
166
+ if (shift) {
167
+ return data[address] >> 4;
168
+ } else {
169
+ return data[address] & 15;
142
170
  }
143
- const uint8_t iperm0[16] = {
144
- 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
145
- if (i < 16) {
146
- return data[iperm0[i]] & 15;
171
+ }
172
+
173
+ void pq4_set_packed_element(
174
+ uint8_t* data,
175
+ uint8_t code,
176
+ size_t bbs,
177
+ size_t nsq,
178
+ size_t vector_id,
179
+ size_t sq) {
180
+ // move to correct bbs-sized block
181
+ // number of blocks * block size
182
+ data += (vector_id / bbs) * (((nsq + 1) / 2) * bbs);
183
+ bool shift;
184
+ size_t address = get_vector_specific_address(bbs, vector_id, sq, shift);
185
+ if (shift) {
186
+ data[address] = (code << 4) | (data[address] & 15);
147
187
  } else {
148
- return data[iperm0[i - 16]] >> 4;
188
+ data[address] = code | (data[address] & ~15);
149
189
  }
150
190
  }
151
191
 
@@ -26,7 +26,7 @@ namespace faiss {
26
26
  * The unused bytes are set to 0.
27
27
  *
28
28
  * @param codes input codes, size (ntotal, ceil(M / 2))
29
- * @param nototal number of input codes
29
+ * @param ntotal number of input codes
30
30
  * @param nb output number of codes (ntotal rounded up to a multiple of
31
31
  * bbs)
32
32
  * @param M2 number of sub-quantizers (=M rounded up to a muliple of 2)
@@ -61,14 +61,27 @@ void pq4_pack_codes_range(
61
61
 
62
62
  /** get a single element from a packed codes table
63
63
  *
64
- * @param i vector id
64
+ * @param vector_id vector id
65
65
  * @param sq subquantizer (< nsq)
66
66
  */
67
67
  uint8_t pq4_get_packed_element(
68
68
  const uint8_t* data,
69
69
  size_t bbs,
70
70
  size_t nsq,
71
- size_t i,
71
+ size_t vector_id,
72
+ size_t sq);
73
+
74
+ /** set a single element "code" into a packed codes table
75
+ *
76
+ * @param vector_id vector id
77
+ * @param sq subquantizer (< nsq)
78
+ */
79
+ void pq4_set_packed_element(
80
+ uint8_t* data,
81
+ uint8_t code,
82
+ size_t bbs,
83
+ size_t nsq,
84
+ size_t vector_id,
72
85
  size_t sq);
73
86
 
74
87
  /** Pack Look-up table for consumption by the kernel.
@@ -88,8 +101,9 @@ void pq4_pack_LUT(int nq, int nsq, const uint8_t* src, uint8_t* dest);
88
101
  * @param nsq number of sub-quantizers (muliple of 2)
89
102
  * @param codes packed codes array
90
103
  * @param LUT packed look-up table
104
+ * @param scaler scaler to scale the encoded norm
91
105
  */
92
- template <class ResultHandler>
106
+ template <class ResultHandler, class Scaler>
93
107
  void pq4_accumulate_loop(
94
108
  int nq,
95
109
  size_t nb,
@@ -97,7 +111,8 @@ void pq4_accumulate_loop(
97
111
  int nsq,
98
112
  const uint8_t* codes,
99
113
  const uint8_t* LUT,
100
- ResultHandler& res);
114
+ ResultHandler& res,
115
+ const Scaler& scaler);
101
116
 
102
117
  /* qbs versions, supported only for bbs=32.
103
118
  *
@@ -141,20 +156,22 @@ int pq4_pack_LUT_qbs_q_map(
141
156
 
142
157
  /** Run accumulation loop.
143
158
  *
144
- * @param qbs 4-bit encded number of queries
159
+ * @param qbs 4-bit encoded number of queries
145
160
  * @param nb number of database codes (mutliple of bbs)
146
161
  * @param nsq number of sub-quantizers
147
162
  * @param codes encoded database vectors (packed)
148
163
  * @param LUT look-up table (packed)
149
164
  * @param res call-back for the resutls
165
+ * @param scaler scaler to scale the encoded norm
150
166
  */
151
- template <class ResultHandler>
167
+ template <class ResultHandler, class Scaler>
152
168
  void pq4_accumulate_loop_qbs(
153
169
  int qbs,
154
170
  size_t nb,
155
171
  int nsq,
156
172
  const uint8_t* codes,
157
173
  const uint8_t* LUT,
158
- ResultHandler& res);
174
+ ResultHandler& res,
175
+ const Scaler& scaler);
159
176
 
160
177
  } // namespace faiss
@@ -8,6 +8,7 @@
8
8
  #include <faiss/impl/pq4_fast_scan.h>
9
9
 
10
10
  #include <faiss/impl/FaissAssert.h>
11
+ #include <faiss/impl/LookupTableScaler.h>
11
12
  #include <faiss/impl/simd_result_handlers.h>
12
13
 
13
14
  namespace faiss {
@@ -26,12 +27,13 @@ namespace {
26
27
  * writes results in a ResultHandler
27
28
  */
28
29
 
29
- template <int NQ, int BB, class ResultHandler>
30
+ template <int NQ, int BB, class ResultHandler, class Scaler>
30
31
  void kernel_accumulate_block(
31
32
  int nsq,
32
33
  const uint8_t* codes,
33
34
  const uint8_t* LUT,
34
- ResultHandler& res) {
35
+ ResultHandler& res,
36
+ const Scaler& scaler) {
35
37
  // distance accumulators
36
38
  simd16uint16 accu[NQ][BB][4];
37
39
 
@@ -44,7 +46,7 @@ void kernel_accumulate_block(
44
46
  }
45
47
  }
46
48
 
47
- for (int sq = 0; sq < nsq; sq += 2) {
49
+ for (int sq = 0; sq < nsq - scaler.nscale; sq += 2) {
48
50
  simd32uint8 lut_cache[NQ];
49
51
  for (int q = 0; q < NQ; q++) {
50
52
  lut_cache[q] = simd32uint8(LUT);
@@ -72,6 +74,35 @@ void kernel_accumulate_block(
72
74
  }
73
75
  }
74
76
 
77
+ for (int sq = 0; sq < scaler.nscale; sq += 2) {
78
+ simd32uint8 lut_cache[NQ];
79
+ for (int q = 0; q < NQ; q++) {
80
+ lut_cache[q] = simd32uint8(LUT);
81
+ LUT += 32;
82
+ }
83
+
84
+ for (int b = 0; b < BB; b++) {
85
+ simd32uint8 c = simd32uint8(codes);
86
+ codes += 32;
87
+ simd32uint8 mask(15);
88
+ simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
89
+ simd32uint8 clo = c & mask;
90
+
91
+ for (int q = 0; q < NQ; q++) {
92
+ simd32uint8 lut = lut_cache[q];
93
+
94
+ simd32uint8 res0 = scaler.lookup(lut, clo);
95
+ accu[q][b][0] += scaler.scale_lo(res0); // handle vectors 0..7
96
+ accu[q][b][1] += scaler.scale_hi(res0); // handle vectors 8..15
97
+
98
+ simd32uint8 res1 = scaler.lookup(lut, chi);
99
+ accu[q][b][2] += scaler.scale_lo(res1); // handle vectors 16..23
100
+ accu[q][b][3] +=
101
+ scaler.scale_hi(res1); // handle vectors 24..31
102
+ }
103
+ }
104
+ }
105
+
75
106
  for (int q = 0; q < NQ; q++) {
76
107
  for (int b = 0; b < BB; b++) {
77
108
  accu[q][b][0] -= accu[q][b][1] << 8;
@@ -85,17 +116,18 @@ void kernel_accumulate_block(
85
116
  }
86
117
  }
87
118
 
88
- template <int NQ, int BB, class ResultHandler>
119
+ template <int NQ, int BB, class ResultHandler, class Scaler>
89
120
  void accumulate_fixed_blocks(
90
121
  size_t nb,
91
122
  int nsq,
92
123
  const uint8_t* codes,
93
124
  const uint8_t* LUT,
94
- ResultHandler& res) {
125
+ ResultHandler& res,
126
+ const Scaler& scaler) {
95
127
  constexpr int bbs = 32 * BB;
96
128
  for (int64_t j0 = 0; j0 < nb; j0 += bbs) {
97
129
  FixedStorageHandler<NQ, 2 * BB> res2;
98
- kernel_accumulate_block<NQ, BB>(nsq, codes, LUT, res2);
130
+ kernel_accumulate_block<NQ, BB>(nsq, codes, LUT, res2, scaler);
99
131
  res.set_block_origin(0, j0);
100
132
  res2.to_other_handler(res);
101
133
  codes += bbs * nsq / 2;
@@ -104,7 +136,7 @@ void accumulate_fixed_blocks(
104
136
 
105
137
  } // anonymous namespace
106
138
 
107
- template <class ResultHandler>
139
+ template <class ResultHandler, class Scaler>
108
140
  void pq4_accumulate_loop(
109
141
  int nq,
110
142
  size_t nb,
@@ -112,15 +144,16 @@ void pq4_accumulate_loop(
112
144
  int nsq,
113
145
  const uint8_t* codes,
114
146
  const uint8_t* LUT,
115
- ResultHandler& res) {
147
+ ResultHandler& res,
148
+ const Scaler& scaler) {
116
149
  FAISS_THROW_IF_NOT(is_aligned_pointer(codes));
117
150
  FAISS_THROW_IF_NOT(is_aligned_pointer(LUT));
118
151
  FAISS_THROW_IF_NOT(bbs % 32 == 0);
119
152
  FAISS_THROW_IF_NOT(nb % bbs == 0);
120
153
 
121
- #define DISPATCH(NQ, BB) \
122
- case NQ * 1000 + BB: \
123
- accumulate_fixed_blocks<NQ, BB>(nb, nsq, codes, LUT, res); \
154
+ #define DISPATCH(NQ, BB) \
155
+ case NQ * 1000 + BB: \
156
+ accumulate_fixed_blocks<NQ, BB>(nb, nsq, codes, LUT, res, scaler); \
124
157
  break
125
158
 
126
159
  switch (nq * 1000 + bbs / 32) {
@@ -141,20 +174,28 @@ void pq4_accumulate_loop(
141
174
 
142
175
  // explicit template instantiations
143
176
 
144
- #define INSTANTIATE_ACCUMULATE(TH, C, with_id_map) \
145
- template void pq4_accumulate_loop<TH<C, with_id_map>>( \
146
- int, \
147
- size_t, \
148
- int, \
149
- int, \
150
- const uint8_t*, \
151
- const uint8_t*, \
152
- TH<C, with_id_map>&);
153
-
154
- #define INSTANTIATE_3(C, with_id_map) \
155
- INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map) \
156
- INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map) \
157
- INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map)
177
+ #define INSTANTIATE_ACCUMULATE(TH, C, with_id_map, S) \
178
+ template void pq4_accumulate_loop<TH<C, with_id_map>, S>( \
179
+ int, \
180
+ size_t, \
181
+ int, \
182
+ int, \
183
+ const uint8_t*, \
184
+ const uint8_t*, \
185
+ TH<C, with_id_map>&, \
186
+ const S&);
187
+
188
+ using DS = DummyScaler;
189
+ using NS = NormTableScaler;
190
+
191
+ #define INSTANTIATE_3(C, with_id_map) \
192
+ INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map, DS) \
193
+ INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map, DS) \
194
+ INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map, DS) \
195
+ \
196
+ INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map, NS) \
197
+ INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map, NS) \
198
+ INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map, NS)
158
199
 
159
200
  using Csi = CMax<uint16_t, int>;
160
201
  INSTANTIATE_3(Csi, false);
@@ -8,6 +8,7 @@
8
8
  #include <faiss/impl/pq4_fast_scan.h>
9
9
 
10
10
  #include <faiss/impl/FaissAssert.h>
11
+ #include <faiss/impl/LookupTableScaler.h>
11
12
  #include <faiss/impl/simd_result_handlers.h>
12
13
  #include <faiss/utils/simdlib.h>
13
14
 
@@ -27,15 +28,17 @@ namespace {
27
28
  * writes results in a ResultHandler
28
29
  */
29
30
 
30
- template <int NQ, class ResultHandler>
31
+ template <int NQ, class ResultHandler, class Scaler>
31
32
  void kernel_accumulate_block(
32
33
  int nsq,
33
34
  const uint8_t* codes,
34
35
  const uint8_t* LUT,
35
- ResultHandler& res) {
36
+ ResultHandler& res,
37
+ const Scaler& scaler) {
36
38
  // dummy alloc to keep the windows compiler happy
37
39
  constexpr int NQA = NQ > 0 ? NQ : 1;
38
40
  // distance accumulators
41
+ // layout: accu[q][b]: distance accumulator for vectors 8*b..8*b+7
39
42
  simd16uint16 accu[NQA][4];
40
43
 
41
44
  for (int q = 0; q < NQ; q++) {
@@ -45,7 +48,7 @@ void kernel_accumulate_block(
45
48
  }
46
49
 
47
50
  // _mm_prefetch(codes + 768, 0);
48
- for (int sq = 0; sq < nsq; sq += 2) {
51
+ for (int sq = 0; sq < nsq - scaler.nscale; sq += 2) {
49
52
  // prefetch
50
53
  simd32uint8 c(codes);
51
54
  codes += 32;
@@ -71,6 +74,31 @@ void kernel_accumulate_block(
71
74
  }
72
75
  }
73
76
 
77
+ for (int sq = 0; sq < scaler.nscale; sq += 2) {
78
+ // prefetch
79
+ simd32uint8 c(codes);
80
+ codes += 32;
81
+
82
+ simd32uint8 mask(0xf);
83
+ // shift op does not exist for int8...
84
+ simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
85
+ simd32uint8 clo = c & mask;
86
+
87
+ for (int q = 0; q < NQ; q++) {
88
+ // load LUTs for 2 quantizers
89
+ simd32uint8 lut(LUT);
90
+ LUT += 32;
91
+
92
+ simd32uint8 res0 = scaler.lookup(lut, clo);
93
+ accu[q][0] += scaler.scale_lo(res0); // handle vectors 0..7
94
+ accu[q][1] += scaler.scale_hi(res0); // handle vectors 8..15
95
+
96
+ simd32uint8 res1 = scaler.lookup(lut, chi);
97
+ accu[q][2] += scaler.scale_lo(res1); // handle vectors 16..23
98
+ accu[q][3] += scaler.scale_hi(res1); // handle vectors 24..31
99
+ }
100
+ }
101
+
74
102
  for (int q = 0; q < NQ; q++) {
75
103
  accu[q][0] -= accu[q][1] << 8;
76
104
  simd16uint16 dis0 = combine2x2(accu[q][0], accu[q][1]);
@@ -81,13 +109,14 @@ void kernel_accumulate_block(
81
109
  }
82
110
 
83
111
  // handle at most 4 blocks of queries
84
- template <int QBS, class ResultHandler>
112
+ template <int QBS, class ResultHandler, class Scaler>
85
113
  void accumulate_q_4step(
86
114
  size_t ntotal2,
87
115
  int nsq,
88
116
  const uint8_t* codes,
89
117
  const uint8_t* LUT0,
90
- ResultHandler& res) {
118
+ ResultHandler& res,
119
+ const Scaler& scaler) {
91
120
  constexpr int Q1 = QBS & 15;
92
121
  constexpr int Q2 = (QBS >> 4) & 15;
93
122
  constexpr int Q3 = (QBS >> 8) & 15;
@@ -97,21 +126,21 @@ void accumulate_q_4step(
97
126
  for (int64_t j0 = 0; j0 < ntotal2; j0 += 32) {
98
127
  FixedStorageHandler<SQ, 2> res2;
99
128
  const uint8_t* LUT = LUT0;
100
- kernel_accumulate_block<Q1>(nsq, codes, LUT, res2);
129
+ kernel_accumulate_block<Q1>(nsq, codes, LUT, res2, scaler);
101
130
  LUT += Q1 * nsq * 16;
102
131
  if (Q2 > 0) {
103
132
  res2.set_block_origin(Q1, 0);
104
- kernel_accumulate_block<Q2>(nsq, codes, LUT, res2);
133
+ kernel_accumulate_block<Q2>(nsq, codes, LUT, res2, scaler);
105
134
  LUT += Q2 * nsq * 16;
106
135
  }
107
136
  if (Q3 > 0) {
108
137
  res2.set_block_origin(Q1 + Q2, 0);
109
- kernel_accumulate_block<Q3>(nsq, codes, LUT, res2);
138
+ kernel_accumulate_block<Q3>(nsq, codes, LUT, res2, scaler);
110
139
  LUT += Q3 * nsq * 16;
111
140
  }
112
141
  if (Q4 > 0) {
113
142
  res2.set_block_origin(Q1 + Q2 + Q3, 0);
114
- kernel_accumulate_block<Q4>(nsq, codes, LUT, res2);
143
+ kernel_accumulate_block<Q4>(nsq, codes, LUT, res2, scaler);
115
144
  }
116
145
  res.set_block_origin(0, j0);
117
146
  res2.to_other_handler(res);
@@ -119,29 +148,31 @@ void accumulate_q_4step(
119
148
  }
120
149
  }
121
150
 
122
- template <int NQ, class ResultHandler>
151
+ template <int NQ, class ResultHandler, class Scaler>
123
152
  void kernel_accumulate_block_loop(
124
153
  size_t ntotal2,
125
154
  int nsq,
126
155
  const uint8_t* codes,
127
156
  const uint8_t* LUT,
128
- ResultHandler& res) {
157
+ ResultHandler& res,
158
+ const Scaler& scaler) {
129
159
  for (int64_t j0 = 0; j0 < ntotal2; j0 += 32) {
130
160
  res.set_block_origin(0, j0);
131
161
  kernel_accumulate_block<NQ, ResultHandler>(
132
- nsq, codes + j0 * nsq / 2, LUT, res);
162
+ nsq, codes + j0 * nsq / 2, LUT, res, scaler);
133
163
  }
134
164
  }
135
165
 
136
166
  // non-template version of accumulate kernel -- dispatches dynamically
137
- template <class ResultHandler>
167
+ template <class ResultHandler, class Scaler>
138
168
  void accumulate(
139
169
  int nq,
140
170
  size_t ntotal2,
141
171
  int nsq,
142
172
  const uint8_t* codes,
143
173
  const uint8_t* LUT,
144
- ResultHandler& res) {
174
+ ResultHandler& res,
175
+ const Scaler& scaler) {
145
176
  assert(nsq % 2 == 0);
146
177
  assert(is_aligned_pointer(codes));
147
178
  assert(is_aligned_pointer(LUT));
@@ -149,7 +180,7 @@ void accumulate(
149
180
  #define DISPATCH(NQ) \
150
181
  case NQ: \
151
182
  kernel_accumulate_block_loop<NQ, ResultHandler>( \
152
- ntotal2, nsq, codes, LUT, res); \
183
+ ntotal2, nsq, codes, LUT, res, scaler); \
153
184
  return
154
185
 
155
186
  switch (nq) {
@@ -165,23 +196,24 @@ void accumulate(
165
196
 
166
197
  } // namespace
167
198
 
168
- template <class ResultHandler>
199
+ template <class ResultHandler, class Scaler>
169
200
  void pq4_accumulate_loop_qbs(
170
201
  int qbs,
171
202
  size_t ntotal2,
172
203
  int nsq,
173
204
  const uint8_t* codes,
174
205
  const uint8_t* LUT0,
175
- ResultHandler& res) {
206
+ ResultHandler& res,
207
+ const Scaler& scaler) {
176
208
  assert(nsq % 2 == 0);
177
209
  assert(is_aligned_pointer(codes));
178
210
  assert(is_aligned_pointer(LUT0));
179
211
 
180
212
  // try out optimized versions
181
213
  switch (qbs) {
182
- #define DISPATCH(QBS) \
183
- case QBS: \
184
- accumulate_q_4step<QBS>(ntotal2, nsq, codes, LUT0, res); \
214
+ #define DISPATCH(QBS) \
215
+ case QBS: \
216
+ accumulate_q_4step<QBS>(ntotal2, nsq, codes, LUT0, res, scaler); \
185
217
  return;
186
218
  DISPATCH(0x3333); // 12
187
219
  DISPATCH(0x2333); // 11
@@ -219,9 +251,10 @@ void pq4_accumulate_loop_qbs(
219
251
  int nq = qi & 15;
220
252
  qi >>= 4;
221
253
  res.set_block_origin(i0, j0);
222
- #define DISPATCH(NQ) \
223
- case NQ: \
224
- kernel_accumulate_block<NQ, ResultHandler>(nsq, codes, LUT, res); \
254
+ #define DISPATCH(NQ) \
255
+ case NQ: \
256
+ kernel_accumulate_block<NQ, ResultHandler>( \
257
+ nsq, codes, LUT, res, scaler); \
225
258
  break
226
259
  switch (nq) {
227
260
  DISPATCH(1);
@@ -241,9 +274,23 @@ void pq4_accumulate_loop_qbs(
241
274
 
242
275
  // explicit template instantiations
243
276
 
244
- #define INSTANTIATE_ACCUMULATE_Q(RH) \
245
- template void pq4_accumulate_loop_qbs<RH>( \
246
- int, size_t, int, const uint8_t*, const uint8_t*, RH&);
277
+ #define INSTANTIATE_ACCUMULATE_Q(RH) \
278
+ template void pq4_accumulate_loop_qbs<RH, DummyScaler>( \
279
+ int, \
280
+ size_t, \
281
+ int, \
282
+ const uint8_t*, \
283
+ const uint8_t*, \
284
+ RH&, \
285
+ const DummyScaler&); \
286
+ template void pq4_accumulate_loop_qbs<RH, NormTableScaler>( \
287
+ int, \
288
+ size_t, \
289
+ int, \
290
+ const uint8_t*, \
291
+ const uint8_t*, \
292
+ RH&, \
293
+ const NormTableScaler&);
247
294
 
248
295
  using Csi = CMax<uint16_t, int>;
249
296
  INSTANTIATE_ACCUMULATE_Q(SingleResultHandler<Csi>)
@@ -293,7 +340,8 @@ void accumulate_to_mem(
293
340
  uint16_t* accu) {
294
341
  FAISS_THROW_IF_NOT(ntotal2 % 32 == 0);
295
342
  StoreResultHandler handler(accu, ntotal2);
296
- accumulate(nq, ntotal2, nsq, codes, LUT, handler);
343
+ DummyScaler scaler;
344
+ accumulate(nq, ntotal2, nsq, codes, LUT, handler, scaler);
297
345
  }
298
346
 
299
347
  int pq4_preferred_qbs(int n) {