faiss 0.2.4 → 0.2.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -278,13 +278,15 @@ double kmeans1d(const float* x, size_t n, size_t nclusters, float* centroids) {
278
278
  ****************************************************/
279
279
 
280
280
  // for imbalance factor
281
- double tot = 0.0, uf = 0.0;
281
+ double tot = 0.0;
282
+ double uf = 0.0;
282
283
 
283
284
  idx_t end = n;
284
285
  for (idx_t k = nclusters - 1; k >= 0; k--) {
285
- idx_t start = T.at(k, end - 1);
286
- float sum = std::accumulate(&arr[start], &arr[end], 0.0f);
287
- idx_t size = end - start;
286
+ const idx_t start = T.at(k, end - 1);
287
+ const float sum =
288
+ std::accumulate(arr.data() + start, arr.data() + end, 0.0f);
289
+ const idx_t size = end - start;
288
290
  FAISS_THROW_IF_NOT_FMT(
289
291
  size > 0, "Cluster %d: size %d", int(k), int(size));
290
292
  centroids[k] = sum / size;
@@ -122,30 +122,70 @@ void pq4_pack_codes_range(
122
122
  }
123
123
  }
124
124
 
125
+ namespace {
126
+
127
+ // get the specific address of the vector inside a block
128
+ // shift is used for determine the if the saved in bits 0..3 (false) or
129
+ // bits 4..7 (true)
130
+ uint8_t get_vector_specific_address(
131
+ size_t bbs,
132
+ size_t vector_id,
133
+ size_t sq,
134
+ bool& shift) {
135
+ // get the vector_id inside the block
136
+ vector_id = vector_id % bbs;
137
+ shift = vector_id > 15;
138
+ vector_id = vector_id & 15;
139
+
140
+ // get the address of the vector in sq
141
+ size_t address;
142
+ if (vector_id < 8) {
143
+ address = vector_id << 1;
144
+ } else {
145
+ address = ((vector_id - 8) << 1) + 1;
146
+ }
147
+ if (sq & 1) {
148
+ address += 16;
149
+ }
150
+ return (sq >> 1) * bbs + address;
151
+ }
152
+
153
+ } // anonymous namespace
154
+
125
155
  uint8_t pq4_get_packed_element(
126
156
  const uint8_t* data,
127
157
  size_t bbs,
128
158
  size_t nsq,
129
- size_t i,
159
+ size_t vector_id,
130
160
  size_t sq) {
131
161
  // move to correct bbs-sized block
132
- data += (i / bbs * (nsq / 2) + sq / 2) * bbs;
133
- sq = sq & 1;
134
- i = i % bbs;
135
-
136
- // another step
137
- data += (i / 32) * 32;
138
- i = i % 32;
139
-
140
- if (sq == 1) {
141
- data += 16;
162
+ // number of blocks * block size
163
+ data += (vector_id / bbs) * (((nsq + 1) / 2) * bbs);
164
+ bool shift;
165
+ size_t address = get_vector_specific_address(bbs, vector_id, sq, shift);
166
+ if (shift) {
167
+ return data[address] >> 4;
168
+ } else {
169
+ return data[address] & 15;
142
170
  }
143
- const uint8_t iperm0[16] = {
144
- 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
145
- if (i < 16) {
146
- return data[iperm0[i]] & 15;
171
+ }
172
+
173
+ void pq4_set_packed_element(
174
+ uint8_t* data,
175
+ uint8_t code,
176
+ size_t bbs,
177
+ size_t nsq,
178
+ size_t vector_id,
179
+ size_t sq) {
180
+ // move to correct bbs-sized block
181
+ // number of blocks * block size
182
+ data += (vector_id / bbs) * (((nsq + 1) / 2) * bbs);
183
+ bool shift;
184
+ size_t address = get_vector_specific_address(bbs, vector_id, sq, shift);
185
+ if (shift) {
186
+ data[address] = (code << 4) | (data[address] & 15);
147
187
  } else {
148
- return data[iperm0[i - 16]] >> 4;
188
+ data[address] = code | (data[address] & ~15);
149
189
  }
150
190
  }
151
191
 
@@ -26,7 +26,7 @@ namespace faiss {
26
26
  * The unused bytes are set to 0.
27
27
  *
28
28
  * @param codes input codes, size (ntotal, ceil(M / 2))
29
- * @param nototal number of input codes
29
+ * @param ntotal number of input codes
30
30
  * @param nb output number of codes (ntotal rounded up to a multiple of
31
31
  * bbs)
32
32
  * @param M2 number of sub-quantizers (=M rounded up to a muliple of 2)
@@ -61,14 +61,27 @@ void pq4_pack_codes_range(
61
61
 
62
62
  /** get a single element from a packed codes table
63
63
  *
64
- * @param i vector id
64
+ * @param vector_id vector id
65
65
  * @param sq subquantizer (< nsq)
66
66
  */
67
67
  uint8_t pq4_get_packed_element(
68
68
  const uint8_t* data,
69
69
  size_t bbs,
70
70
  size_t nsq,
71
- size_t i,
71
+ size_t vector_id,
72
+ size_t sq);
73
+
74
+ /** set a single element "code" into a packed codes table
75
+ *
76
+ * @param vector_id vector id
77
+ * @param sq subquantizer (< nsq)
78
+ */
79
+ void pq4_set_packed_element(
80
+ uint8_t* data,
81
+ uint8_t code,
82
+ size_t bbs,
83
+ size_t nsq,
84
+ size_t vector_id,
72
85
  size_t sq);
73
86
 
74
87
  /** Pack Look-up table for consumption by the kernel.
@@ -88,8 +101,9 @@ void pq4_pack_LUT(int nq, int nsq, const uint8_t* src, uint8_t* dest);
88
101
  * @param nsq number of sub-quantizers (muliple of 2)
89
102
  * @param codes packed codes array
90
103
  * @param LUT packed look-up table
104
+ * @param scaler scaler to scale the encoded norm
91
105
  */
92
- template <class ResultHandler>
106
+ template <class ResultHandler, class Scaler>
93
107
  void pq4_accumulate_loop(
94
108
  int nq,
95
109
  size_t nb,
@@ -97,7 +111,8 @@ void pq4_accumulate_loop(
97
111
  int nsq,
98
112
  const uint8_t* codes,
99
113
  const uint8_t* LUT,
100
- ResultHandler& res);
114
+ ResultHandler& res,
115
+ const Scaler& scaler);
101
116
 
102
117
  /* qbs versions, supported only for bbs=32.
103
118
  *
@@ -141,20 +156,22 @@ int pq4_pack_LUT_qbs_q_map(
141
156
 
142
157
  /** Run accumulation loop.
143
158
  *
144
- * @param qbs 4-bit encded number of queries
159
+ * @param qbs 4-bit encoded number of queries
145
160
  * @param nb number of database codes (mutliple of bbs)
146
161
  * @param nsq number of sub-quantizers
147
162
  * @param codes encoded database vectors (packed)
148
163
  * @param LUT look-up table (packed)
149
164
  * @param res call-back for the resutls
165
+ * @param scaler scaler to scale the encoded norm
150
166
  */
151
- template <class ResultHandler>
167
+ template <class ResultHandler, class Scaler>
152
168
  void pq4_accumulate_loop_qbs(
153
169
  int qbs,
154
170
  size_t nb,
155
171
  int nsq,
156
172
  const uint8_t* codes,
157
173
  const uint8_t* LUT,
158
- ResultHandler& res);
174
+ ResultHandler& res,
175
+ const Scaler& scaler);
159
176
 
160
177
  } // namespace faiss
@@ -8,6 +8,7 @@
8
8
  #include <faiss/impl/pq4_fast_scan.h>
9
9
 
10
10
  #include <faiss/impl/FaissAssert.h>
11
+ #include <faiss/impl/LookupTableScaler.h>
11
12
  #include <faiss/impl/simd_result_handlers.h>
12
13
 
13
14
  namespace faiss {
@@ -26,12 +27,13 @@ namespace {
26
27
  * writes results in a ResultHandler
27
28
  */
28
29
 
29
- template <int NQ, int BB, class ResultHandler>
30
+ template <int NQ, int BB, class ResultHandler, class Scaler>
30
31
  void kernel_accumulate_block(
31
32
  int nsq,
32
33
  const uint8_t* codes,
33
34
  const uint8_t* LUT,
34
- ResultHandler& res) {
35
+ ResultHandler& res,
36
+ const Scaler& scaler) {
35
37
  // distance accumulators
36
38
  simd16uint16 accu[NQ][BB][4];
37
39
 
@@ -44,7 +46,7 @@ void kernel_accumulate_block(
44
46
  }
45
47
  }
46
48
 
47
- for (int sq = 0; sq < nsq; sq += 2) {
49
+ for (int sq = 0; sq < nsq - scaler.nscale; sq += 2) {
48
50
  simd32uint8 lut_cache[NQ];
49
51
  for (int q = 0; q < NQ; q++) {
50
52
  lut_cache[q] = simd32uint8(LUT);
@@ -72,6 +74,35 @@ void kernel_accumulate_block(
72
74
  }
73
75
  }
74
76
 
77
+ for (int sq = 0; sq < scaler.nscale; sq += 2) {
78
+ simd32uint8 lut_cache[NQ];
79
+ for (int q = 0; q < NQ; q++) {
80
+ lut_cache[q] = simd32uint8(LUT);
81
+ LUT += 32;
82
+ }
83
+
84
+ for (int b = 0; b < BB; b++) {
85
+ simd32uint8 c = simd32uint8(codes);
86
+ codes += 32;
87
+ simd32uint8 mask(15);
88
+ simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
89
+ simd32uint8 clo = c & mask;
90
+
91
+ for (int q = 0; q < NQ; q++) {
92
+ simd32uint8 lut = lut_cache[q];
93
+
94
+ simd32uint8 res0 = scaler.lookup(lut, clo);
95
+ accu[q][b][0] += scaler.scale_lo(res0); // handle vectors 0..7
96
+ accu[q][b][1] += scaler.scale_hi(res0); // handle vectors 8..15
97
+
98
+ simd32uint8 res1 = scaler.lookup(lut, chi);
99
+ accu[q][b][2] += scaler.scale_lo(res1); // handle vectors 16..23
100
+ accu[q][b][3] +=
101
+ scaler.scale_hi(res1); // handle vectors 24..31
102
+ }
103
+ }
104
+ }
105
+
75
106
  for (int q = 0; q < NQ; q++) {
76
107
  for (int b = 0; b < BB; b++) {
77
108
  accu[q][b][0] -= accu[q][b][1] << 8;
@@ -85,17 +116,18 @@ void kernel_accumulate_block(
85
116
  }
86
117
  }
87
118
 
88
- template <int NQ, int BB, class ResultHandler>
119
+ template <int NQ, int BB, class ResultHandler, class Scaler>
89
120
  void accumulate_fixed_blocks(
90
121
  size_t nb,
91
122
  int nsq,
92
123
  const uint8_t* codes,
93
124
  const uint8_t* LUT,
94
- ResultHandler& res) {
125
+ ResultHandler& res,
126
+ const Scaler& scaler) {
95
127
  constexpr int bbs = 32 * BB;
96
128
  for (int64_t j0 = 0; j0 < nb; j0 += bbs) {
97
129
  FixedStorageHandler<NQ, 2 * BB> res2;
98
- kernel_accumulate_block<NQ, BB>(nsq, codes, LUT, res2);
130
+ kernel_accumulate_block<NQ, BB>(nsq, codes, LUT, res2, scaler);
99
131
  res.set_block_origin(0, j0);
100
132
  res2.to_other_handler(res);
101
133
  codes += bbs * nsq / 2;
@@ -104,7 +136,7 @@ void accumulate_fixed_blocks(
104
136
 
105
137
  } // anonymous namespace
106
138
 
107
- template <class ResultHandler>
139
+ template <class ResultHandler, class Scaler>
108
140
  void pq4_accumulate_loop(
109
141
  int nq,
110
142
  size_t nb,
@@ -112,15 +144,16 @@ void pq4_accumulate_loop(
112
144
  int nsq,
113
145
  const uint8_t* codes,
114
146
  const uint8_t* LUT,
115
- ResultHandler& res) {
147
+ ResultHandler& res,
148
+ const Scaler& scaler) {
116
149
  FAISS_THROW_IF_NOT(is_aligned_pointer(codes));
117
150
  FAISS_THROW_IF_NOT(is_aligned_pointer(LUT));
118
151
  FAISS_THROW_IF_NOT(bbs % 32 == 0);
119
152
  FAISS_THROW_IF_NOT(nb % bbs == 0);
120
153
 
121
- #define DISPATCH(NQ, BB) \
122
- case NQ * 1000 + BB: \
123
- accumulate_fixed_blocks<NQ, BB>(nb, nsq, codes, LUT, res); \
154
+ #define DISPATCH(NQ, BB) \
155
+ case NQ * 1000 + BB: \
156
+ accumulate_fixed_blocks<NQ, BB>(nb, nsq, codes, LUT, res, scaler); \
124
157
  break
125
158
 
126
159
  switch (nq * 1000 + bbs / 32) {
@@ -141,20 +174,28 @@ void pq4_accumulate_loop(
141
174
 
142
175
  // explicit template instantiations
143
176
 
144
- #define INSTANTIATE_ACCUMULATE(TH, C, with_id_map) \
145
- template void pq4_accumulate_loop<TH<C, with_id_map>>( \
146
- int, \
147
- size_t, \
148
- int, \
149
- int, \
150
- const uint8_t*, \
151
- const uint8_t*, \
152
- TH<C, with_id_map>&);
153
-
154
- #define INSTANTIATE_3(C, with_id_map) \
155
- INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map) \
156
- INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map) \
157
- INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map)
177
+ #define INSTANTIATE_ACCUMULATE(TH, C, with_id_map, S) \
178
+ template void pq4_accumulate_loop<TH<C, with_id_map>, S>( \
179
+ int, \
180
+ size_t, \
181
+ int, \
182
+ int, \
183
+ const uint8_t*, \
184
+ const uint8_t*, \
185
+ TH<C, with_id_map>&, \
186
+ const S&);
187
+
188
+ using DS = DummyScaler;
189
+ using NS = NormTableScaler;
190
+
191
+ #define INSTANTIATE_3(C, with_id_map) \
192
+ INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map, DS) \
193
+ INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map, DS) \
194
+ INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map, DS) \
195
+ \
196
+ INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map, NS) \
197
+ INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map, NS) \
198
+ INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map, NS)
158
199
 
159
200
  using Csi = CMax<uint16_t, int>;
160
201
  INSTANTIATE_3(Csi, false);
@@ -8,6 +8,7 @@
8
8
  #include <faiss/impl/pq4_fast_scan.h>
9
9
 
10
10
  #include <faiss/impl/FaissAssert.h>
11
+ #include <faiss/impl/LookupTableScaler.h>
11
12
  #include <faiss/impl/simd_result_handlers.h>
12
13
  #include <faiss/utils/simdlib.h>
13
14
 
@@ -27,15 +28,17 @@ namespace {
27
28
  * writes results in a ResultHandler
28
29
  */
29
30
 
30
- template <int NQ, class ResultHandler>
31
+ template <int NQ, class ResultHandler, class Scaler>
31
32
  void kernel_accumulate_block(
32
33
  int nsq,
33
34
  const uint8_t* codes,
34
35
  const uint8_t* LUT,
35
- ResultHandler& res) {
36
+ ResultHandler& res,
37
+ const Scaler& scaler) {
36
38
  // dummy alloc to keep the windows compiler happy
37
39
  constexpr int NQA = NQ > 0 ? NQ : 1;
38
40
  // distance accumulators
41
+ // layout: accu[q][b]: distance accumulator for vectors 8*b..8*b+7
39
42
  simd16uint16 accu[NQA][4];
40
43
 
41
44
  for (int q = 0; q < NQ; q++) {
@@ -45,7 +48,7 @@ void kernel_accumulate_block(
45
48
  }
46
49
 
47
50
  // _mm_prefetch(codes + 768, 0);
48
- for (int sq = 0; sq < nsq; sq += 2) {
51
+ for (int sq = 0; sq < nsq - scaler.nscale; sq += 2) {
49
52
  // prefetch
50
53
  simd32uint8 c(codes);
51
54
  codes += 32;
@@ -71,6 +74,31 @@ void kernel_accumulate_block(
71
74
  }
72
75
  }
73
76
 
77
+ for (int sq = 0; sq < scaler.nscale; sq += 2) {
78
+ // prefetch
79
+ simd32uint8 c(codes);
80
+ codes += 32;
81
+
82
+ simd32uint8 mask(0xf);
83
+ // shift op does not exist for int8...
84
+ simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
85
+ simd32uint8 clo = c & mask;
86
+
87
+ for (int q = 0; q < NQ; q++) {
88
+ // load LUTs for 2 quantizers
89
+ simd32uint8 lut(LUT);
90
+ LUT += 32;
91
+
92
+ simd32uint8 res0 = scaler.lookup(lut, clo);
93
+ accu[q][0] += scaler.scale_lo(res0); // handle vectors 0..7
94
+ accu[q][1] += scaler.scale_hi(res0); // handle vectors 8..15
95
+
96
+ simd32uint8 res1 = scaler.lookup(lut, chi);
97
+ accu[q][2] += scaler.scale_lo(res1); // handle vectors 16..23
98
+ accu[q][3] += scaler.scale_hi(res1); // handle vectors 24..31
99
+ }
100
+ }
101
+
74
102
  for (int q = 0; q < NQ; q++) {
75
103
  accu[q][0] -= accu[q][1] << 8;
76
104
  simd16uint16 dis0 = combine2x2(accu[q][0], accu[q][1]);
@@ -81,13 +109,14 @@ void kernel_accumulate_block(
81
109
  }
82
110
 
83
111
  // handle at most 4 blocks of queries
84
- template <int QBS, class ResultHandler>
112
+ template <int QBS, class ResultHandler, class Scaler>
85
113
  void accumulate_q_4step(
86
114
  size_t ntotal2,
87
115
  int nsq,
88
116
  const uint8_t* codes,
89
117
  const uint8_t* LUT0,
90
- ResultHandler& res) {
118
+ ResultHandler& res,
119
+ const Scaler& scaler) {
91
120
  constexpr int Q1 = QBS & 15;
92
121
  constexpr int Q2 = (QBS >> 4) & 15;
93
122
  constexpr int Q3 = (QBS >> 8) & 15;
@@ -97,21 +126,21 @@ void accumulate_q_4step(
97
126
  for (int64_t j0 = 0; j0 < ntotal2; j0 += 32) {
98
127
  FixedStorageHandler<SQ, 2> res2;
99
128
  const uint8_t* LUT = LUT0;
100
- kernel_accumulate_block<Q1>(nsq, codes, LUT, res2);
129
+ kernel_accumulate_block<Q1>(nsq, codes, LUT, res2, scaler);
101
130
  LUT += Q1 * nsq * 16;
102
131
  if (Q2 > 0) {
103
132
  res2.set_block_origin(Q1, 0);
104
- kernel_accumulate_block<Q2>(nsq, codes, LUT, res2);
133
+ kernel_accumulate_block<Q2>(nsq, codes, LUT, res2, scaler);
105
134
  LUT += Q2 * nsq * 16;
106
135
  }
107
136
  if (Q3 > 0) {
108
137
  res2.set_block_origin(Q1 + Q2, 0);
109
- kernel_accumulate_block<Q3>(nsq, codes, LUT, res2);
138
+ kernel_accumulate_block<Q3>(nsq, codes, LUT, res2, scaler);
110
139
  LUT += Q3 * nsq * 16;
111
140
  }
112
141
  if (Q4 > 0) {
113
142
  res2.set_block_origin(Q1 + Q2 + Q3, 0);
114
- kernel_accumulate_block<Q4>(nsq, codes, LUT, res2);
143
+ kernel_accumulate_block<Q4>(nsq, codes, LUT, res2, scaler);
115
144
  }
116
145
  res.set_block_origin(0, j0);
117
146
  res2.to_other_handler(res);
@@ -119,29 +148,31 @@ void accumulate_q_4step(
119
148
  }
120
149
  }
121
150
 
122
- template <int NQ, class ResultHandler>
151
+ template <int NQ, class ResultHandler, class Scaler>
123
152
  void kernel_accumulate_block_loop(
124
153
  size_t ntotal2,
125
154
  int nsq,
126
155
  const uint8_t* codes,
127
156
  const uint8_t* LUT,
128
- ResultHandler& res) {
157
+ ResultHandler& res,
158
+ const Scaler& scaler) {
129
159
  for (int64_t j0 = 0; j0 < ntotal2; j0 += 32) {
130
160
  res.set_block_origin(0, j0);
131
161
  kernel_accumulate_block<NQ, ResultHandler>(
132
- nsq, codes + j0 * nsq / 2, LUT, res);
162
+ nsq, codes + j0 * nsq / 2, LUT, res, scaler);
133
163
  }
134
164
  }
135
165
 
136
166
  // non-template version of accumulate kernel -- dispatches dynamically
137
- template <class ResultHandler>
167
+ template <class ResultHandler, class Scaler>
138
168
  void accumulate(
139
169
  int nq,
140
170
  size_t ntotal2,
141
171
  int nsq,
142
172
  const uint8_t* codes,
143
173
  const uint8_t* LUT,
144
- ResultHandler& res) {
174
+ ResultHandler& res,
175
+ const Scaler& scaler) {
145
176
  assert(nsq % 2 == 0);
146
177
  assert(is_aligned_pointer(codes));
147
178
  assert(is_aligned_pointer(LUT));
@@ -149,7 +180,7 @@ void accumulate(
149
180
  #define DISPATCH(NQ) \
150
181
  case NQ: \
151
182
  kernel_accumulate_block_loop<NQ, ResultHandler>( \
152
- ntotal2, nsq, codes, LUT, res); \
183
+ ntotal2, nsq, codes, LUT, res, scaler); \
153
184
  return
154
185
 
155
186
  switch (nq) {
@@ -165,23 +196,24 @@ void accumulate(
165
196
 
166
197
  } // namespace
167
198
 
168
- template <class ResultHandler>
199
+ template <class ResultHandler, class Scaler>
169
200
  void pq4_accumulate_loop_qbs(
170
201
  int qbs,
171
202
  size_t ntotal2,
172
203
  int nsq,
173
204
  const uint8_t* codes,
174
205
  const uint8_t* LUT0,
175
- ResultHandler& res) {
206
+ ResultHandler& res,
207
+ const Scaler& scaler) {
176
208
  assert(nsq % 2 == 0);
177
209
  assert(is_aligned_pointer(codes));
178
210
  assert(is_aligned_pointer(LUT0));
179
211
 
180
212
  // try out optimized versions
181
213
  switch (qbs) {
182
- #define DISPATCH(QBS) \
183
- case QBS: \
184
- accumulate_q_4step<QBS>(ntotal2, nsq, codes, LUT0, res); \
214
+ #define DISPATCH(QBS) \
215
+ case QBS: \
216
+ accumulate_q_4step<QBS>(ntotal2, nsq, codes, LUT0, res, scaler); \
185
217
  return;
186
218
  DISPATCH(0x3333); // 12
187
219
  DISPATCH(0x2333); // 11
@@ -219,9 +251,10 @@ void pq4_accumulate_loop_qbs(
219
251
  int nq = qi & 15;
220
252
  qi >>= 4;
221
253
  res.set_block_origin(i0, j0);
222
- #define DISPATCH(NQ) \
223
- case NQ: \
224
- kernel_accumulate_block<NQ, ResultHandler>(nsq, codes, LUT, res); \
254
+ #define DISPATCH(NQ) \
255
+ case NQ: \
256
+ kernel_accumulate_block<NQ, ResultHandler>( \
257
+ nsq, codes, LUT, res, scaler); \
225
258
  break
226
259
  switch (nq) {
227
260
  DISPATCH(1);
@@ -241,9 +274,23 @@ void pq4_accumulate_loop_qbs(
241
274
 
242
275
  // explicit template instantiations
243
276
 
244
- #define INSTANTIATE_ACCUMULATE_Q(RH) \
245
- template void pq4_accumulate_loop_qbs<RH>( \
246
- int, size_t, int, const uint8_t*, const uint8_t*, RH&);
277
+ #define INSTANTIATE_ACCUMULATE_Q(RH) \
278
+ template void pq4_accumulate_loop_qbs<RH, DummyScaler>( \
279
+ int, \
280
+ size_t, \
281
+ int, \
282
+ const uint8_t*, \
283
+ const uint8_t*, \
284
+ RH&, \
285
+ const DummyScaler&); \
286
+ template void pq4_accumulate_loop_qbs<RH, NormTableScaler>( \
287
+ int, \
288
+ size_t, \
289
+ int, \
290
+ const uint8_t*, \
291
+ const uint8_t*, \
292
+ RH&, \
293
+ const NormTableScaler&);
247
294
 
248
295
  using Csi = CMax<uint16_t, int>;
249
296
  INSTANTIATE_ACCUMULATE_Q(SingleResultHandler<Csi>)
@@ -293,7 +340,8 @@ void accumulate_to_mem(
293
340
  uint16_t* accu) {
294
341
  FAISS_THROW_IF_NOT(ntotal2 % 32 == 0);
295
342
  StoreResultHandler handler(accu, ntotal2);
296
- accumulate(nq, ntotal2, nsq, codes, LUT, handler);
343
+ DummyScaler scaler;
344
+ accumulate(nq, ntotal2, nsq, codes, LUT, handler, scaler);
297
345
  }
298
346
 
299
347
  int pq4_preferred_qbs(int n) {