faiss 0.3.0 → 0.3.2

Sign up to get free protection for your applications and to get access to all the features.
Files changed (216) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -67,7 +67,7 @@ void IndexPreTransform::train(idx_t n, const float* x) {
67
67
  }
68
68
  }
69
69
  const float* prev_x = x;
70
- ScopeDeleter<float> del;
70
+ std::unique_ptr<const float[]> del;
71
71
 
72
72
  if (verbose) {
73
73
  printf("IndexPreTransform::train: training chain 0 to %d\n",
@@ -102,10 +102,12 @@ void IndexPreTransform::train(idx_t n, const float* x) {
102
102
 
103
103
  float* xt = chain[i]->apply(n, prev_x);
104
104
 
105
- if (prev_x != x)
106
- delete[] prev_x;
105
+ if (prev_x != x) {
106
+ del.reset();
107
+ }
108
+
107
109
  prev_x = xt;
108
- del.set(xt);
110
+ del.reset(xt);
109
111
  }
110
112
 
111
113
  is_trained = true;
@@ -113,11 +115,11 @@ void IndexPreTransform::train(idx_t n, const float* x) {
113
115
 
114
116
  const float* IndexPreTransform::apply_chain(idx_t n, const float* x) const {
115
117
  const float* prev_x = x;
116
- ScopeDeleter<float> del;
118
+ std::unique_ptr<const float[]> del;
117
119
 
118
120
  for (int i = 0; i < chain.size(); i++) {
119
121
  float* xt = chain[i]->apply(n, prev_x);
120
- ScopeDeleter<float> del2(xt);
122
+ std::unique_ptr<const float[]> del2(xt);
121
123
  del2.swap(del);
122
124
  prev_x = xt;
123
125
  }
@@ -128,11 +130,11 @@ const float* IndexPreTransform::apply_chain(idx_t n, const float* x) const {
128
130
  void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)
129
131
  const {
130
132
  const float* next_x = xt;
131
- ScopeDeleter<float> del;
133
+ std::unique_ptr<const float[]> del;
132
134
 
133
135
  for (int i = chain.size() - 1; i >= 0; i--) {
134
136
  float* prev_x = (i == 0) ? x : new float[n * chain[i]->d_in];
135
- ScopeDeleter<float> del2((prev_x == x) ? nullptr : prev_x);
137
+ std::unique_ptr<const float[]> del2((prev_x == x) ? nullptr : prev_x);
136
138
  chain[i]->reverse_transform(n, next_x, prev_x);
137
139
  del2.swap(del);
138
140
  next_x = prev_x;
@@ -141,9 +143,8 @@ void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)
141
143
 
142
144
  void IndexPreTransform::add(idx_t n, const float* x) {
143
145
  FAISS_THROW_IF_NOT(is_trained);
144
- const float* xt = apply_chain(n, x);
145
- ScopeDeleter<float> del(xt == x ? nullptr : xt);
146
- index->add(n, xt);
146
+ TransformedVectors tv(x, apply_chain(n, x));
147
+ index->add(n, tv.x);
147
148
  ntotal = index->ntotal;
148
149
  }
149
150
 
@@ -152,9 +153,8 @@ void IndexPreTransform::add_with_ids(
152
153
  const float* x,
153
154
  const idx_t* xids) {
154
155
  FAISS_THROW_IF_NOT(is_trained);
155
- const float* xt = apply_chain(n, x);
156
- ScopeDeleter<float> del(xt == x ? nullptr : xt);
157
- index->add_with_ids(n, xt, xids);
156
+ TransformedVectors tv(x, apply_chain(n, x));
157
+ index->add_with_ids(n, tv.x, xids);
158
158
  ntotal = index->ntotal;
159
159
  }
160
160
 
@@ -178,7 +178,7 @@ void IndexPreTransform::search(
178
178
  FAISS_THROW_IF_NOT(k > 0);
179
179
  FAISS_THROW_IF_NOT(is_trained);
180
180
  const float* xt = apply_chain(n, x);
181
- ScopeDeleter<float> del(xt == x ? nullptr : xt);
181
+ std::unique_ptr<const float[]> del(xt == x ? nullptr : xt);
182
182
  index->search(
183
183
  n, xt, k, distances, labels, extract_index_search_params(params));
184
184
  }
@@ -190,10 +190,9 @@ void IndexPreTransform::range_search(
190
190
  RangeSearchResult* result,
191
191
  const SearchParameters* params) const {
192
192
  FAISS_THROW_IF_NOT(is_trained);
193
- const float* xt = apply_chain(n, x);
194
- ScopeDeleter<float> del(xt == x ? nullptr : xt);
193
+ TransformedVectors tv(x, apply_chain(n, x));
195
194
  index->range_search(
196
- n, xt, radius, result, extract_index_search_params(params));
195
+ n, tv.x, radius, result, extract_index_search_params(params));
197
196
  }
198
197
 
199
198
  void IndexPreTransform::reset() {
@@ -209,7 +208,7 @@ size_t IndexPreTransform::remove_ids(const IDSelector& sel) {
209
208
 
210
209
  void IndexPreTransform::reconstruct(idx_t key, float* recons) const {
211
210
  float* x = chain.empty() ? recons : new float[index->d];
212
- ScopeDeleter<float> del(recons == x ? nullptr : x);
211
+ std::unique_ptr<float[]> del(recons == x ? nullptr : x);
213
212
  // Initial reconstruction
214
213
  index->reconstruct(key, x);
215
214
 
@@ -219,7 +218,7 @@ void IndexPreTransform::reconstruct(idx_t key, float* recons) const {
219
218
 
220
219
  void IndexPreTransform::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
221
220
  float* x = chain.empty() ? recons : new float[ni * index->d];
222
- ScopeDeleter<float> del(recons == x ? nullptr : x);
221
+ std::unique_ptr<float[]> del(recons == x ? nullptr : x);
223
222
  // Initial reconstruction
224
223
  index->reconstruct_n(i0, ni, x);
225
224
 
@@ -238,14 +237,14 @@ void IndexPreTransform::search_and_reconstruct(
238
237
  FAISS_THROW_IF_NOT(k > 0);
239
238
  FAISS_THROW_IF_NOT(is_trained);
240
239
 
241
- const float* xt = apply_chain(n, x);
242
- ScopeDeleter<float> del((xt == x) ? nullptr : xt);
240
+ TransformedVectors trans(x, apply_chain(n, x));
243
241
 
244
242
  float* recons_temp = chain.empty() ? recons : new float[n * k * index->d];
245
- ScopeDeleter<float> del2((recons_temp == recons) ? nullptr : recons_temp);
243
+ std::unique_ptr<float[]> del2(
244
+ (recons_temp == recons) ? nullptr : recons_temp);
246
245
  index->search_and_reconstruct(
247
246
  n,
248
- xt,
247
+ trans.x,
249
248
  k,
250
249
  distances,
251
250
  labels,
@@ -262,13 +261,8 @@ size_t IndexPreTransform::sa_code_size() const {
262
261
 
263
262
  void IndexPreTransform::sa_encode(idx_t n, const float* x, uint8_t* bytes)
264
263
  const {
265
- if (chain.empty()) {
266
- index->sa_encode(n, x, bytes);
267
- } else {
268
- const float* xt = apply_chain(n, x);
269
- ScopeDeleter<float> del(xt == x ? nullptr : xt);
270
- index->sa_encode(n, xt, bytes);
271
- }
264
+ TransformedVectors tv(x, apply_chain(n, x));
265
+ index->sa_encode(n, tv.x, bytes);
272
266
  }
273
267
 
274
268
  void IndexPreTransform::sa_decode(idx_t n, const uint8_t* bytes, float* x)
@@ -23,7 +23,7 @@ struct SearchParametersPreTransform : SearchParameters {
23
23
  /** Index that applies a LinearTransform transform on vectors before
24
24
  * handing them over to a sub-index */
25
25
  struct IndexPreTransform : Index {
26
- std::vector<VectorTransform*> chain; ///! chain of tranforms
26
+ std::vector<VectorTransform*> chain; ///! chain of transforms
27
27
  Index* index; ///! the sub-index
28
28
 
29
29
  bool own_fields; ///! whether pointers are deleted in destructor
@@ -62,18 +62,18 @@ void IndexRefine::reset() {
62
62
 
63
63
  namespace {
64
64
 
65
- typedef faiss::idx_t idx_t;
65
+ using idx_t = faiss::idx_t;
66
66
 
67
67
  template <class C>
68
68
  static void reorder_2_heaps(
69
69
  idx_t n,
70
70
  idx_t k,
71
- idx_t* labels,
72
- float* distances,
71
+ idx_t* __restrict labels,
72
+ float* __restrict distances,
73
73
  idx_t k_base,
74
- const idx_t* base_labels,
75
- const float* base_distances) {
76
- #pragma omp parallel for
74
+ const idx_t* __restrict base_labels,
75
+ const float* __restrict base_distances) {
76
+ #pragma omp parallel for if (n > 1)
77
77
  for (idx_t i = 0; i < n; i++) {
78
78
  idx_t* idxo = labels + i * k;
79
79
  float* diso = distances + i * k;
@@ -96,25 +96,40 @@ void IndexRefine::search(
96
96
  idx_t k,
97
97
  float* distances,
98
98
  idx_t* labels,
99
- const SearchParameters* params) const {
100
- FAISS_THROW_IF_NOT_MSG(
101
- !params, "search params not supported for this index");
99
+ const SearchParameters* params_in) const {
100
+ const IndexRefineSearchParameters* params = nullptr;
101
+ if (params_in) {
102
+ params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
103
+ FAISS_THROW_IF_NOT_MSG(
104
+ params, "IndexRefine params have incorrect type");
105
+ }
106
+
107
+ idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
108
+ : idx_t(k * k_factor);
109
+ SearchParameters* base_index_params =
110
+ (params != nullptr) ? params->base_index_params : nullptr;
111
+
112
+ FAISS_THROW_IF_NOT(k_base >= k);
113
+
114
+ FAISS_THROW_IF_NOT(base_index);
115
+ FAISS_THROW_IF_NOT(refine_index);
116
+
102
117
  FAISS_THROW_IF_NOT(k > 0);
103
118
  FAISS_THROW_IF_NOT(is_trained);
104
- idx_t k_base = idx_t(k * k_factor);
105
119
  idx_t* base_labels = labels;
106
120
  float* base_distances = distances;
107
- ScopeDeleter<idx_t> del1;
108
- ScopeDeleter<float> del2;
121
+ std::unique_ptr<idx_t[]> del1;
122
+ std::unique_ptr<float[]> del2;
109
123
 
110
124
  if (k != k_base) {
111
125
  base_labels = new idx_t[n * k_base];
112
- del1.set(base_labels);
126
+ del1.reset(base_labels);
113
127
  base_distances = new float[n * k_base];
114
- del2.set(base_distances);
128
+ del2.reset(base_distances);
115
129
  }
116
130
 
117
- base_index->search(n, x, k_base, base_distances, base_labels);
131
+ base_index->search(
132
+ n, x, k_base, base_distances, base_labels, base_index_params);
118
133
 
119
134
  for (int i = 0; i < n * k_base; i++)
120
135
  assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
@@ -225,25 +240,40 @@ void IndexRefineFlat::search(
225
240
  idx_t k,
226
241
  float* distances,
227
242
  idx_t* labels,
228
- const SearchParameters* params) const {
229
- FAISS_THROW_IF_NOT_MSG(
230
- !params, "search params not supported for this index");
243
+ const SearchParameters* params_in) const {
244
+ const IndexRefineSearchParameters* params = nullptr;
245
+ if (params_in) {
246
+ params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
247
+ FAISS_THROW_IF_NOT_MSG(
248
+ params, "IndexRefineFlat params have incorrect type");
249
+ }
250
+
251
+ idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
252
+ : idx_t(k * k_factor);
253
+ SearchParameters* base_index_params =
254
+ (params != nullptr) ? params->base_index_params : nullptr;
255
+
256
+ FAISS_THROW_IF_NOT(k_base >= k);
257
+
258
+ FAISS_THROW_IF_NOT(base_index);
259
+ FAISS_THROW_IF_NOT(refine_index);
260
+
231
261
  FAISS_THROW_IF_NOT(k > 0);
232
262
  FAISS_THROW_IF_NOT(is_trained);
233
- idx_t k_base = idx_t(k * k_factor);
234
263
  idx_t* base_labels = labels;
235
264
  float* base_distances = distances;
236
- ScopeDeleter<idx_t> del1;
237
- ScopeDeleter<float> del2;
265
+ std::unique_ptr<idx_t[]> del1;
266
+ std::unique_ptr<float[]> del2;
238
267
 
239
268
  if (k != k_base) {
240
269
  base_labels = new idx_t[n * k_base];
241
- del1.set(base_labels);
270
+ del1.reset(base_labels);
242
271
  base_distances = new float[n * k_base];
243
- del2.set(base_distances);
272
+ del2.reset(base_distances);
244
273
  }
245
274
 
246
- base_index->search(n, x, k_base, base_distances, base_labels);
275
+ base_index->search(
276
+ n, x, k_base, base_distances, base_labels, base_index_params);
247
277
 
248
278
  for (int i = 0; i < n * k_base; i++)
249
279
  assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
@@ -11,6 +11,13 @@
11
11
 
12
12
  namespace faiss {
13
13
 
14
+ struct IndexRefineSearchParameters : SearchParameters {
15
+ float k_factor = 1;
16
+ SearchParameters* base_index_params = nullptr; // non-owning
17
+
18
+ virtual ~IndexRefineSearchParameters() = default;
19
+ };
20
+
14
21
  /** Index that queries in a base_index (a fast one) and refines the
15
22
  * results with an exact search, hopefully improving the results.
16
23
  */
@@ -12,17 +12,34 @@
12
12
 
13
13
  namespace faiss {
14
14
 
15
+ namespace {
16
+
17
+ // IndexBinary needs to update the code_size when d is set...
18
+
19
+ void sync_d(Index* index) {}
20
+
21
+ void sync_d(IndexBinary* index) {
22
+ FAISS_THROW_IF_NOT(index->d % 8 == 0);
23
+ index->code_size = index->d / 8;
24
+ }
25
+
26
+ } // anonymous namespace
27
+
15
28
  template <typename IndexT>
16
29
  IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(bool threaded)
17
30
  : ThreadedIndex<IndexT>(threaded) {}
18
31
 
19
32
  template <typename IndexT>
20
33
  IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(idx_t d, bool threaded)
21
- : ThreadedIndex<IndexT>(d, threaded) {}
34
+ : ThreadedIndex<IndexT>(d, threaded) {
35
+ sync_d(this);
36
+ }
22
37
 
23
38
  template <typename IndexT>
24
39
  IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(int d, bool threaded)
25
- : ThreadedIndex<IndexT>(d, threaded) {}
40
+ : ThreadedIndex<IndexT>(d, threaded) {
41
+ sync_d(this);
42
+ }
26
43
 
27
44
  template <typename IndexT>
28
45
  void IndexReplicasTemplate<IndexT>::onAfterAddIndex(IndexT* index) {
@@ -168,6 +185,8 @@ void IndexReplicasTemplate<IndexT>::syncWithSubIndexes() {
168
185
  }
169
186
 
170
187
  auto firstIndex = this->at(0);
188
+ this->d = firstIndex->d;
189
+ sync_d(this);
171
190
  this->metric_type = firstIndex->metric_type;
172
191
  this->is_trained = firstIndex->is_trained;
173
192
  this->ntotal = firstIndex->ntotal;
@@ -181,30 +200,8 @@ void IndexReplicasTemplate<IndexT>::syncWithSubIndexes() {
181
200
  }
182
201
  }
183
202
 
184
- // No metric_type for IndexBinary
185
- template <>
186
- void IndexReplicasTemplate<IndexBinary>::syncWithSubIndexes() {
187
- if (!this->count()) {
188
- this->is_trained = false;
189
- this->ntotal = 0;
190
-
191
- return;
192
- }
193
-
194
- auto firstIndex = this->at(0);
195
- this->is_trained = firstIndex->is_trained;
196
- this->ntotal = firstIndex->ntotal;
197
-
198
- for (int i = 1; i < this->count(); ++i) {
199
- auto index = this->at(i);
200
- FAISS_THROW_IF_NOT(this->d == index->d);
201
- FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
202
- FAISS_THROW_IF_NOT(this->ntotal == index->ntotal);
203
- }
204
- }
205
-
206
203
  // explicit instantiations
207
- template struct IndexReplicasTemplate<Index>;
208
- template struct IndexReplicasTemplate<IndexBinary>;
204
+ template class IndexReplicasTemplate<Index>;
205
+ template class IndexReplicasTemplate<IndexBinary>;
209
206
 
210
207
  } // namespace faiss
@@ -32,7 +32,9 @@ IndexScalarQuantizer::IndexScalarQuantizer(
32
32
  MetricType metric)
33
33
  : IndexFlatCodes(0, d, metric), sq(d, qtype) {
34
34
  is_trained = qtype == ScalarQuantizer::QT_fp16 ||
35
- qtype == ScalarQuantizer::QT_8bit_direct;
35
+ qtype == ScalarQuantizer::QT_8bit_direct ||
36
+ qtype == ScalarQuantizer::QT_bf16 ||
37
+ qtype == ScalarQuantizer::QT_8bit_direct_signed;
36
38
  code_size = sq.code_size;
37
39
  }
38
40
 
@@ -60,10 +62,9 @@ void IndexScalarQuantizer::search(
60
62
 
61
63
  #pragma omp parallel
62
64
  {
63
- InvertedListScanner* scanner =
64
- sq.select_InvertedListScanner(metric_type, nullptr, true, sel);
65
+ std::unique_ptr<InvertedListScanner> scanner(
66
+ sq.select_InvertedListScanner(metric_type, nullptr, true, sel));
65
67
 
66
- ScopeDeleter1<InvertedListScanner> del(scanner);
67
68
  scanner->list_no = 0; // directly the list number
68
69
 
69
70
  #pragma omp for
@@ -122,21 +123,28 @@ IndexIVFScalarQuantizer::IndexIVFScalarQuantizer(
122
123
  size_t nlist,
123
124
  ScalarQuantizer::QuantizerType qtype,
124
125
  MetricType metric,
125
- bool encode_residual)
126
- : IndexIVF(quantizer, d, nlist, 0, metric),
127
- sq(d, qtype),
128
- by_residual(encode_residual) {
126
+ bool by_residual)
127
+ : IndexIVF(quantizer, d, nlist, 0, metric), sq(d, qtype) {
129
128
  code_size = sq.code_size;
129
+ this->by_residual = by_residual;
130
130
  // was not known at construction time
131
131
  invlists->code_size = code_size;
132
132
  is_trained = false;
133
133
  }
134
134
 
135
- IndexIVFScalarQuantizer::IndexIVFScalarQuantizer()
136
- : IndexIVF(), by_residual(true) {}
135
+ IndexIVFScalarQuantizer::IndexIVFScalarQuantizer() : IndexIVF() {
136
+ by_residual = true;
137
+ }
137
138
 
138
- void IndexIVFScalarQuantizer::train_residual(idx_t n, const float* x) {
139
- sq.train_residual(n, x, quantizer, by_residual, verbose);
139
+ void IndexIVFScalarQuantizer::train_encoder(
140
+ idx_t n,
141
+ const float* x,
142
+ const idx_t* assign) {
143
+ sq.train(n, x);
144
+ }
145
+
146
+ idx_t IndexIVFScalarQuantizer::train_encoder_num_vectors() const {
147
+ return 100000;
140
148
  }
141
149
 
142
150
  void IndexIVFScalarQuantizer::encode_vectors(
@@ -201,15 +209,15 @@ void IndexIVFScalarQuantizer::add_core(
201
209
  idx_t n,
202
210
  const float* x,
203
211
  const idx_t* xids,
204
- const idx_t* coarse_idx) {
212
+ const idx_t* coarse_idx,
213
+ void* inverted_list_context) {
205
214
  FAISS_THROW_IF_NOT(is_trained);
206
215
 
207
- size_t nadd = 0;
208
216
  std::unique_ptr<ScalarQuantizer::SQuantizer> squant(sq.select_quantizer());
209
217
 
210
218
  DirectMapAdd dm_add(direct_map, n, xids);
211
219
 
212
- #pragma omp parallel reduction(+ : nadd)
220
+ #pragma omp parallel
213
221
  {
214
222
  std::vector<float> residual(d);
215
223
  std::vector<uint8_t> one_code(code_size);
@@ -231,10 +239,10 @@ void IndexIVFScalarQuantizer::add_core(
231
239
  memset(one_code.data(), 0, code_size);
232
240
  squant->encode_vector(xi, one_code.data());
233
241
 
234
- size_t ofs = invlists->add_entry(list_no, id, one_code.data());
242
+ size_t ofs = invlists->add_entry(
243
+ list_no, id, one_code.data(), inverted_list_context);
235
244
 
236
245
  dm_add.add(i, list_no, ofs);
237
- nadd++;
238
246
 
239
247
  } else if (rank == 0 && list_no == -1) {
240
248
  dm_add.add(i, -1, 0);
@@ -65,7 +65,6 @@ struct IndexScalarQuantizer : IndexFlatCodes {
65
65
 
66
66
  struct IndexIVFScalarQuantizer : IndexIVF {
67
67
  ScalarQuantizer sq;
68
- bool by_residual;
69
68
 
70
69
  IndexIVFScalarQuantizer(
71
70
  Index* quantizer,
@@ -73,11 +72,13 @@ struct IndexIVFScalarQuantizer : IndexIVF {
73
72
  size_t nlist,
74
73
  ScalarQuantizer::QuantizerType qtype,
75
74
  MetricType metric = METRIC_L2,
76
- bool encode_residual = true);
75
+ bool by_residual = true);
77
76
 
78
77
  IndexIVFScalarQuantizer();
79
78
 
80
- void train_residual(idx_t n, const float* x) override;
79
+ void train_encoder(idx_t n, const float* x, const idx_t* assign) override;
80
+
81
+ idx_t train_encoder_num_vectors() const override;
81
82
 
82
83
  void encode_vectors(
83
84
  idx_t n,
@@ -90,7 +91,8 @@ struct IndexIVFScalarQuantizer : IndexIVF {
90
91
  idx_t n,
91
92
  const float* x,
92
93
  const idx_t* xids,
93
- const idx_t* precomputed_idx) override;
94
+ const idx_t* precomputed_idx,
95
+ void* inverted_list_context = nullptr) override;
94
96
 
95
97
  InvertedListScanner* get_InvertedListScanner(
96
98
  bool store_pairs,
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/IndexShards.h>
11
9
 
12
10
  #include <cinttypes>
@@ -22,6 +20,15 @@ namespace faiss {
22
20
  // subroutines
23
21
  namespace {
24
22
 
23
+ // IndexBinary needs to update the code_size when d is set...
24
+
25
+ void sync_d(Index* index) {}
26
+
27
+ void sync_d(IndexBinary* index) {
28
+ FAISS_THROW_IF_NOT(index->d % 8 == 0);
29
+ index->code_size = index->d / 8;
30
+ }
31
+
25
32
  // add translation to all valid labels
26
33
  void translate_labels(int64_t n, idx_t* labels, int64_t translation) {
27
34
  if (translation == 0)
@@ -40,20 +47,26 @@ IndexShardsTemplate<IndexT>::IndexShardsTemplate(
40
47
  idx_t d,
41
48
  bool threaded,
42
49
  bool successive_ids)
43
- : ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {}
50
+ : ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {
51
+ sync_d(this);
52
+ }
44
53
 
45
54
  template <typename IndexT>
46
55
  IndexShardsTemplate<IndexT>::IndexShardsTemplate(
47
56
  int d,
48
57
  bool threaded,
49
58
  bool successive_ids)
50
- : ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {}
59
+ : ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {
60
+ sync_d(this);
61
+ }
51
62
 
52
63
  template <typename IndexT>
53
64
  IndexShardsTemplate<IndexT>::IndexShardsTemplate(
54
65
  bool threaded,
55
66
  bool successive_ids)
56
- : ThreadedIndex<IndexT>(threaded), successive_ids(successive_ids) {}
67
+ : ThreadedIndex<IndexT>(threaded), successive_ids(successive_ids) {
68
+ sync_d(this);
69
+ }
57
70
 
58
71
  template <typename IndexT>
59
72
  void IndexShardsTemplate<IndexT>::onAfterAddIndex(IndexT* index /* unused */) {
@@ -78,6 +91,8 @@ void IndexShardsTemplate<IndexT>::syncWithSubIndexes() {
78
91
  }
79
92
 
80
93
  auto firstIndex = this->at(0);
94
+ this->d = firstIndex->d;
95
+ sync_d(this);
81
96
  this->metric_type = firstIndex->metric_type;
82
97
  this->is_trained = firstIndex->is_trained;
83
98
  this->ntotal = firstIndex->ntotal;
@@ -92,29 +107,6 @@ void IndexShardsTemplate<IndexT>::syncWithSubIndexes() {
92
107
  }
93
108
  }
94
109
 
95
- // No metric_type for IndexBinary
96
- template <>
97
- void IndexShardsTemplate<IndexBinary>::syncWithSubIndexes() {
98
- if (!this->count()) {
99
- this->is_trained = false;
100
- this->ntotal = 0;
101
-
102
- return;
103
- }
104
-
105
- auto firstIndex = this->at(0);
106
- this->is_trained = firstIndex->is_trained;
107
- this->ntotal = firstIndex->ntotal;
108
-
109
- for (int i = 1; i < this->count(); ++i) {
110
- auto index = this->at(i);
111
- FAISS_THROW_IF_NOT(this->d == index->d);
112
- FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
113
-
114
- this->ntotal += index->ntotal;
115
- }
116
- }
117
-
118
110
  template <typename IndexT>
119
111
  void IndexShardsTemplate<IndexT>::train(idx_t n, const component_t* x) {
120
112
  auto fn = [n, x](int no, IndexT* index) {
@@ -155,7 +147,7 @@ void IndexShardsTemplate<IndexT>::add_with_ids(
155
147
  "request them to be shifted");
156
148
  FAISS_THROW_IF_NOT_MSG(
157
149
  this->ntotal == 0,
158
- "when adding to IndexShards with sucessive_ids, "
150
+ "when adding to IndexShards with successive_ids, "
159
151
  "only add() in a single pass is supported");
160
152
  }
161
153
 
@@ -111,7 +111,7 @@ void IndexShardsIVF::add_with_ids(
111
111
  "request them to be shifted");
112
112
  FAISS_THROW_IF_NOT_MSG(
113
113
  this->ntotal == 0,
114
- "when adding to IndexShards with sucessive_ids, "
114
+ "when adding to IndexShards with successive_ids, "
115
115
  "only add() in a single pass is supported");
116
116
  }
117
117
 
@@ -137,7 +137,6 @@ void IndexShardsIVF::add_with_ids(
137
137
  auto fn = [n, ids, x, nshard, d, Iq](int no, Index* index) {
138
138
  idx_t i0 = (idx_t)no * n / nshard;
139
139
  idx_t i1 = ((idx_t)no + 1) * n / nshard;
140
- const float* x0 = x + i0 * d;
141
140
  auto index_ivf = dynamic_cast<IndexIVF*>(index);
142
141
 
143
142
  if (index->verbose) {