faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -67,7 +67,7 @@ void IndexPreTransform::train(idx_t n, const float* x) {
67
67
  }
68
68
  }
69
69
  const float* prev_x = x;
70
- ScopeDeleter<float> del;
70
+ std::unique_ptr<const float[]> del;
71
71
 
72
72
  if (verbose) {
73
73
  printf("IndexPreTransform::train: training chain 0 to %d\n",
@@ -102,10 +102,12 @@ void IndexPreTransform::train(idx_t n, const float* x) {
102
102
 
103
103
  float* xt = chain[i]->apply(n, prev_x);
104
104
 
105
- if (prev_x != x)
106
- delete[] prev_x;
105
+ if (prev_x != x) {
106
+ del.reset();
107
+ }
108
+
107
109
  prev_x = xt;
108
- del.set(xt);
110
+ del.reset(xt);
109
111
  }
110
112
 
111
113
  is_trained = true;
@@ -113,11 +115,11 @@ void IndexPreTransform::train(idx_t n, const float* x) {
113
115
 
114
116
  const float* IndexPreTransform::apply_chain(idx_t n, const float* x) const {
115
117
  const float* prev_x = x;
116
- ScopeDeleter<float> del;
118
+ std::unique_ptr<const float[]> del;
117
119
 
118
120
  for (int i = 0; i < chain.size(); i++) {
119
121
  float* xt = chain[i]->apply(n, prev_x);
120
- ScopeDeleter<float> del2(xt);
122
+ std::unique_ptr<const float[]> del2(xt);
121
123
  del2.swap(del);
122
124
  prev_x = xt;
123
125
  }
@@ -128,11 +130,11 @@ const float* IndexPreTransform::apply_chain(idx_t n, const float* x) const {
128
130
  void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)
129
131
  const {
130
132
  const float* next_x = xt;
131
- ScopeDeleter<float> del;
133
+ std::unique_ptr<const float[]> del;
132
134
 
133
135
  for (int i = chain.size() - 1; i >= 0; i--) {
134
136
  float* prev_x = (i == 0) ? x : new float[n * chain[i]->d_in];
135
- ScopeDeleter<float> del2((prev_x == x) ? nullptr : prev_x);
137
+ std::unique_ptr<const float[]> del2((prev_x == x) ? nullptr : prev_x);
136
138
  chain[i]->reverse_transform(n, next_x, prev_x);
137
139
  del2.swap(del);
138
140
  next_x = prev_x;
@@ -141,9 +143,8 @@ void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)
141
143
 
142
144
  void IndexPreTransform::add(idx_t n, const float* x) {
143
145
  FAISS_THROW_IF_NOT(is_trained);
144
- const float* xt = apply_chain(n, x);
145
- ScopeDeleter<float> del(xt == x ? nullptr : xt);
146
- index->add(n, xt);
146
+ TransformedVectors tv(x, apply_chain(n, x));
147
+ index->add(n, tv.x);
147
148
  ntotal = index->ntotal;
148
149
  }
149
150
 
@@ -152,9 +153,8 @@ void IndexPreTransform::add_with_ids(
152
153
  const float* x,
153
154
  const idx_t* xids) {
154
155
  FAISS_THROW_IF_NOT(is_trained);
155
- const float* xt = apply_chain(n, x);
156
- ScopeDeleter<float> del(xt == x ? nullptr : xt);
157
- index->add_with_ids(n, xt, xids);
156
+ TransformedVectors tv(x, apply_chain(n, x));
157
+ index->add_with_ids(n, tv.x, xids);
158
158
  ntotal = index->ntotal;
159
159
  }
160
160
 
@@ -178,7 +178,7 @@ void IndexPreTransform::search(
178
178
  FAISS_THROW_IF_NOT(k > 0);
179
179
  FAISS_THROW_IF_NOT(is_trained);
180
180
  const float* xt = apply_chain(n, x);
181
- ScopeDeleter<float> del(xt == x ? nullptr : xt);
181
+ std::unique_ptr<const float[]> del(xt == x ? nullptr : xt);
182
182
  index->search(
183
183
  n, xt, k, distances, labels, extract_index_search_params(params));
184
184
  }
@@ -190,10 +190,9 @@ void IndexPreTransform::range_search(
190
190
  RangeSearchResult* result,
191
191
  const SearchParameters* params) const {
192
192
  FAISS_THROW_IF_NOT(is_trained);
193
- const float* xt = apply_chain(n, x);
194
- ScopeDeleter<float> del(xt == x ? nullptr : xt);
193
+ TransformedVectors tv(x, apply_chain(n, x));
195
194
  index->range_search(
196
- n, xt, radius, result, extract_index_search_params(params));
195
+ n, tv.x, radius, result, extract_index_search_params(params));
197
196
  }
198
197
 
199
198
  void IndexPreTransform::reset() {
@@ -209,7 +208,7 @@ size_t IndexPreTransform::remove_ids(const IDSelector& sel) {
209
208
 
210
209
  void IndexPreTransform::reconstruct(idx_t key, float* recons) const {
211
210
  float* x = chain.empty() ? recons : new float[index->d];
212
- ScopeDeleter<float> del(recons == x ? nullptr : x);
211
+ std::unique_ptr<float[]> del(recons == x ? nullptr : x);
213
212
  // Initial reconstruction
214
213
  index->reconstruct(key, x);
215
214
 
@@ -219,7 +218,7 @@ void IndexPreTransform::reconstruct(idx_t key, float* recons) const {
219
218
 
220
219
  void IndexPreTransform::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
221
220
  float* x = chain.empty() ? recons : new float[ni * index->d];
222
- ScopeDeleter<float> del(recons == x ? nullptr : x);
221
+ std::unique_ptr<float[]> del(recons == x ? nullptr : x);
223
222
  // Initial reconstruction
224
223
  index->reconstruct_n(i0, ni, x);
225
224
 
@@ -238,14 +237,14 @@ void IndexPreTransform::search_and_reconstruct(
238
237
  FAISS_THROW_IF_NOT(k > 0);
239
238
  FAISS_THROW_IF_NOT(is_trained);
240
239
 
241
- const float* xt = apply_chain(n, x);
242
- ScopeDeleter<float> del((xt == x) ? nullptr : xt);
240
+ TransformedVectors trans(x, apply_chain(n, x));
243
241
 
244
242
  float* recons_temp = chain.empty() ? recons : new float[n * k * index->d];
245
- ScopeDeleter<float> del2((recons_temp == recons) ? nullptr : recons_temp);
243
+ std::unique_ptr<float[]> del2(
244
+ (recons_temp == recons) ? nullptr : recons_temp);
246
245
  index->search_and_reconstruct(
247
246
  n,
248
- xt,
247
+ trans.x,
249
248
  k,
250
249
  distances,
251
250
  labels,
@@ -262,13 +261,8 @@ size_t IndexPreTransform::sa_code_size() const {
262
261
 
263
262
  void IndexPreTransform::sa_encode(idx_t n, const float* x, uint8_t* bytes)
264
263
  const {
265
- if (chain.empty()) {
266
- index->sa_encode(n, x, bytes);
267
- } else {
268
- const float* xt = apply_chain(n, x);
269
- ScopeDeleter<float> del(xt == x ? nullptr : xt);
270
- index->sa_encode(n, xt, bytes);
271
- }
264
+ TransformedVectors tv(x, apply_chain(n, x));
265
+ index->sa_encode(n, tv.x, bytes);
272
266
  }
273
267
 
274
268
  void IndexPreTransform::sa_decode(idx_t n, const uint8_t* bytes, float* x)
@@ -23,7 +23,7 @@ struct SearchParametersPreTransform : SearchParameters {
23
23
  /** Index that applies a LinearTransform transform on vectors before
24
24
  * handing them over to a sub-index */
25
25
  struct IndexPreTransform : Index {
26
- std::vector<VectorTransform*> chain; ///! chain of tranforms
26
+ std::vector<VectorTransform*> chain; ///! chain of transforms
27
27
  Index* index; ///! the sub-index
28
28
 
29
29
  bool own_fields; ///! whether pointers are deleted in destructor
@@ -62,18 +62,18 @@ void IndexRefine::reset() {
62
62
 
63
63
  namespace {
64
64
 
65
- typedef faiss::idx_t idx_t;
65
+ using idx_t = faiss::idx_t;
66
66
 
67
67
  template <class C>
68
68
  static void reorder_2_heaps(
69
69
  idx_t n,
70
70
  idx_t k,
71
- idx_t* labels,
72
- float* distances,
71
+ idx_t* __restrict labels,
72
+ float* __restrict distances,
73
73
  idx_t k_base,
74
- const idx_t* base_labels,
75
- const float* base_distances) {
76
- #pragma omp parallel for
74
+ const idx_t* __restrict base_labels,
75
+ const float* __restrict base_distances) {
76
+ #pragma omp parallel for if (n > 1)
77
77
  for (idx_t i = 0; i < n; i++) {
78
78
  idx_t* idxo = labels + i * k;
79
79
  float* diso = distances + i * k;
@@ -96,25 +96,40 @@ void IndexRefine::search(
96
96
  idx_t k,
97
97
  float* distances,
98
98
  idx_t* labels,
99
- const SearchParameters* params) const {
100
- FAISS_THROW_IF_NOT_MSG(
101
- !params, "search params not supported for this index");
99
+ const SearchParameters* params_in) const {
100
+ const IndexRefineSearchParameters* params = nullptr;
101
+ if (params_in) {
102
+ params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
103
+ FAISS_THROW_IF_NOT_MSG(
104
+ params, "IndexRefine params have incorrect type");
105
+ }
106
+
107
+ idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
108
+ : idx_t(k * k_factor);
109
+ SearchParameters* base_index_params =
110
+ (params != nullptr) ? params->base_index_params : nullptr;
111
+
112
+ FAISS_THROW_IF_NOT(k_base >= k);
113
+
114
+ FAISS_THROW_IF_NOT(base_index);
115
+ FAISS_THROW_IF_NOT(refine_index);
116
+
102
117
  FAISS_THROW_IF_NOT(k > 0);
103
118
  FAISS_THROW_IF_NOT(is_trained);
104
- idx_t k_base = idx_t(k * k_factor);
105
119
  idx_t* base_labels = labels;
106
120
  float* base_distances = distances;
107
- ScopeDeleter<idx_t> del1;
108
- ScopeDeleter<float> del2;
121
+ std::unique_ptr<idx_t[]> del1;
122
+ std::unique_ptr<float[]> del2;
109
123
 
110
124
  if (k != k_base) {
111
125
  base_labels = new idx_t[n * k_base];
112
- del1.set(base_labels);
126
+ del1.reset(base_labels);
113
127
  base_distances = new float[n * k_base];
114
- del2.set(base_distances);
128
+ del2.reset(base_distances);
115
129
  }
116
130
 
117
- base_index->search(n, x, k_base, base_distances, base_labels);
131
+ base_index->search(
132
+ n, x, k_base, base_distances, base_labels, base_index_params);
118
133
 
119
134
  for (int i = 0; i < n * k_base; i++)
120
135
  assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
@@ -225,25 +240,40 @@ void IndexRefineFlat::search(
225
240
  idx_t k,
226
241
  float* distances,
227
242
  idx_t* labels,
228
- const SearchParameters* params) const {
229
- FAISS_THROW_IF_NOT_MSG(
230
- !params, "search params not supported for this index");
243
+ const SearchParameters* params_in) const {
244
+ const IndexRefineSearchParameters* params = nullptr;
245
+ if (params_in) {
246
+ params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
247
+ FAISS_THROW_IF_NOT_MSG(
248
+ params, "IndexRefineFlat params have incorrect type");
249
+ }
250
+
251
+ idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
252
+ : idx_t(k * k_factor);
253
+ SearchParameters* base_index_params =
254
+ (params != nullptr) ? params->base_index_params : nullptr;
255
+
256
+ FAISS_THROW_IF_NOT(k_base >= k);
257
+
258
+ FAISS_THROW_IF_NOT(base_index);
259
+ FAISS_THROW_IF_NOT(refine_index);
260
+
231
261
  FAISS_THROW_IF_NOT(k > 0);
232
262
  FAISS_THROW_IF_NOT(is_trained);
233
- idx_t k_base = idx_t(k * k_factor);
234
263
  idx_t* base_labels = labels;
235
264
  float* base_distances = distances;
236
- ScopeDeleter<idx_t> del1;
237
- ScopeDeleter<float> del2;
265
+ std::unique_ptr<idx_t[]> del1;
266
+ std::unique_ptr<float[]> del2;
238
267
 
239
268
  if (k != k_base) {
240
269
  base_labels = new idx_t[n * k_base];
241
- del1.set(base_labels);
270
+ del1.reset(base_labels);
242
271
  base_distances = new float[n * k_base];
243
- del2.set(base_distances);
272
+ del2.reset(base_distances);
244
273
  }
245
274
 
246
- base_index->search(n, x, k_base, base_distances, base_labels);
275
+ base_index->search(
276
+ n, x, k_base, base_distances, base_labels, base_index_params);
247
277
 
248
278
  for (int i = 0; i < n * k_base; i++)
249
279
  assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
@@ -11,6 +11,13 @@
11
11
 
12
12
  namespace faiss {
13
13
 
14
+ struct IndexRefineSearchParameters : SearchParameters {
15
+ float k_factor = 1;
16
+ SearchParameters* base_index_params = nullptr; // non-owning
17
+
18
+ virtual ~IndexRefineSearchParameters() = default;
19
+ };
20
+
14
21
  /** Index that queries in a base_index (a fast one) and refines the
15
22
  * results with an exact search, hopefully improving the results.
16
23
  */
@@ -12,17 +12,34 @@
12
12
 
13
13
  namespace faiss {
14
14
 
15
+ namespace {
16
+
17
+ // IndexBinary needs to update the code_size when d is set...
18
+
19
+ void sync_d(Index* index) {}
20
+
21
+ void sync_d(IndexBinary* index) {
22
+ FAISS_THROW_IF_NOT(index->d % 8 == 0);
23
+ index->code_size = index->d / 8;
24
+ }
25
+
26
+ } // anonymous namespace
27
+
15
28
  template <typename IndexT>
16
29
  IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(bool threaded)
17
30
  : ThreadedIndex<IndexT>(threaded) {}
18
31
 
19
32
  template <typename IndexT>
20
33
  IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(idx_t d, bool threaded)
21
- : ThreadedIndex<IndexT>(d, threaded) {}
34
+ : ThreadedIndex<IndexT>(d, threaded) {
35
+ sync_d(this);
36
+ }
22
37
 
23
38
  template <typename IndexT>
24
39
  IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(int d, bool threaded)
25
- : ThreadedIndex<IndexT>(d, threaded) {}
40
+ : ThreadedIndex<IndexT>(d, threaded) {
41
+ sync_d(this);
42
+ }
26
43
 
27
44
  template <typename IndexT>
28
45
  void IndexReplicasTemplate<IndexT>::onAfterAddIndex(IndexT* index) {
@@ -168,6 +185,8 @@ void IndexReplicasTemplate<IndexT>::syncWithSubIndexes() {
168
185
  }
169
186
 
170
187
  auto firstIndex = this->at(0);
188
+ this->d = firstIndex->d;
189
+ sync_d(this);
171
190
  this->metric_type = firstIndex->metric_type;
172
191
  this->is_trained = firstIndex->is_trained;
173
192
  this->ntotal = firstIndex->ntotal;
@@ -181,30 +200,8 @@ void IndexReplicasTemplate<IndexT>::syncWithSubIndexes() {
181
200
  }
182
201
  }
183
202
 
184
- // No metric_type for IndexBinary
185
- template <>
186
- void IndexReplicasTemplate<IndexBinary>::syncWithSubIndexes() {
187
- if (!this->count()) {
188
- this->is_trained = false;
189
- this->ntotal = 0;
190
-
191
- return;
192
- }
193
-
194
- auto firstIndex = this->at(0);
195
- this->is_trained = firstIndex->is_trained;
196
- this->ntotal = firstIndex->ntotal;
197
-
198
- for (int i = 1; i < this->count(); ++i) {
199
- auto index = this->at(i);
200
- FAISS_THROW_IF_NOT(this->d == index->d);
201
- FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
202
- FAISS_THROW_IF_NOT(this->ntotal == index->ntotal);
203
- }
204
- }
205
-
206
203
  // explicit instantiations
207
- template struct IndexReplicasTemplate<Index>;
208
- template struct IndexReplicasTemplate<IndexBinary>;
204
+ template class IndexReplicasTemplate<Index>;
205
+ template class IndexReplicasTemplate<IndexBinary>;
209
206
 
210
207
  } // namespace faiss
@@ -32,7 +32,9 @@ IndexScalarQuantizer::IndexScalarQuantizer(
32
32
  MetricType metric)
33
33
  : IndexFlatCodes(0, d, metric), sq(d, qtype) {
34
34
  is_trained = qtype == ScalarQuantizer::QT_fp16 ||
35
- qtype == ScalarQuantizer::QT_8bit_direct;
35
+ qtype == ScalarQuantizer::QT_8bit_direct ||
36
+ qtype == ScalarQuantizer::QT_bf16 ||
37
+ qtype == ScalarQuantizer::QT_8bit_direct_signed;
36
38
  code_size = sq.code_size;
37
39
  }
38
40
 
@@ -60,10 +62,9 @@ void IndexScalarQuantizer::search(
60
62
 
61
63
  #pragma omp parallel
62
64
  {
63
- InvertedListScanner* scanner =
64
- sq.select_InvertedListScanner(metric_type, nullptr, true, sel);
65
+ std::unique_ptr<InvertedListScanner> scanner(
66
+ sq.select_InvertedListScanner(metric_type, nullptr, true, sel));
65
67
 
66
- ScopeDeleter1<InvertedListScanner> del(scanner);
67
68
  scanner->list_no = 0; // directly the list number
68
69
 
69
70
  #pragma omp for
@@ -122,21 +123,28 @@ IndexIVFScalarQuantizer::IndexIVFScalarQuantizer(
122
123
  size_t nlist,
123
124
  ScalarQuantizer::QuantizerType qtype,
124
125
  MetricType metric,
125
- bool encode_residual)
126
- : IndexIVF(quantizer, d, nlist, 0, metric),
127
- sq(d, qtype),
128
- by_residual(encode_residual) {
126
+ bool by_residual)
127
+ : IndexIVF(quantizer, d, nlist, 0, metric), sq(d, qtype) {
129
128
  code_size = sq.code_size;
129
+ this->by_residual = by_residual;
130
130
  // was not known at construction time
131
131
  invlists->code_size = code_size;
132
132
  is_trained = false;
133
133
  }
134
134
 
135
- IndexIVFScalarQuantizer::IndexIVFScalarQuantizer()
136
- : IndexIVF(), by_residual(true) {}
135
+ IndexIVFScalarQuantizer::IndexIVFScalarQuantizer() : IndexIVF() {
136
+ by_residual = true;
137
+ }
137
138
 
138
- void IndexIVFScalarQuantizer::train_residual(idx_t n, const float* x) {
139
- sq.train_residual(n, x, quantizer, by_residual, verbose);
139
+ void IndexIVFScalarQuantizer::train_encoder(
140
+ idx_t n,
141
+ const float* x,
142
+ const idx_t* assign) {
143
+ sq.train(n, x);
144
+ }
145
+
146
+ idx_t IndexIVFScalarQuantizer::train_encoder_num_vectors() const {
147
+ return 100000;
140
148
  }
141
149
 
142
150
  void IndexIVFScalarQuantizer::encode_vectors(
@@ -201,15 +209,15 @@ void IndexIVFScalarQuantizer::add_core(
201
209
  idx_t n,
202
210
  const float* x,
203
211
  const idx_t* xids,
204
- const idx_t* coarse_idx) {
212
+ const idx_t* coarse_idx,
213
+ void* inverted_list_context) {
205
214
  FAISS_THROW_IF_NOT(is_trained);
206
215
 
207
- size_t nadd = 0;
208
216
  std::unique_ptr<ScalarQuantizer::SQuantizer> squant(sq.select_quantizer());
209
217
 
210
218
  DirectMapAdd dm_add(direct_map, n, xids);
211
219
 
212
- #pragma omp parallel reduction(+ : nadd)
220
+ #pragma omp parallel
213
221
  {
214
222
  std::vector<float> residual(d);
215
223
  std::vector<uint8_t> one_code(code_size);
@@ -231,10 +239,10 @@ void IndexIVFScalarQuantizer::add_core(
231
239
  memset(one_code.data(), 0, code_size);
232
240
  squant->encode_vector(xi, one_code.data());
233
241
 
234
- size_t ofs = invlists->add_entry(list_no, id, one_code.data());
242
+ size_t ofs = invlists->add_entry(
243
+ list_no, id, one_code.data(), inverted_list_context);
235
244
 
236
245
  dm_add.add(i, list_no, ofs);
237
- nadd++;
238
246
 
239
247
  } else if (rank == 0 && list_no == -1) {
240
248
  dm_add.add(i, -1, 0);
@@ -65,7 +65,6 @@ struct IndexScalarQuantizer : IndexFlatCodes {
65
65
 
66
66
  struct IndexIVFScalarQuantizer : IndexIVF {
67
67
  ScalarQuantizer sq;
68
- bool by_residual;
69
68
 
70
69
  IndexIVFScalarQuantizer(
71
70
  Index* quantizer,
@@ -73,11 +72,13 @@ struct IndexIVFScalarQuantizer : IndexIVF {
73
72
  size_t nlist,
74
73
  ScalarQuantizer::QuantizerType qtype,
75
74
  MetricType metric = METRIC_L2,
76
- bool encode_residual = true);
75
+ bool by_residual = true);
77
76
 
78
77
  IndexIVFScalarQuantizer();
79
78
 
80
- void train_residual(idx_t n, const float* x) override;
79
+ void train_encoder(idx_t n, const float* x, const idx_t* assign) override;
80
+
81
+ idx_t train_encoder_num_vectors() const override;
81
82
 
82
83
  void encode_vectors(
83
84
  idx_t n,
@@ -90,7 +91,8 @@ struct IndexIVFScalarQuantizer : IndexIVF {
90
91
  idx_t n,
91
92
  const float* x,
92
93
  const idx_t* xids,
93
- const idx_t* precomputed_idx) override;
94
+ const idx_t* precomputed_idx,
95
+ void* inverted_list_context = nullptr) override;
94
96
 
95
97
  InvertedListScanner* get_InvertedListScanner(
96
98
  bool store_pairs,
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/IndexShards.h>
11
9
 
12
10
  #include <cinttypes>
@@ -22,6 +20,15 @@ namespace faiss {
22
20
  // subroutines
23
21
  namespace {
24
22
 
23
+ // IndexBinary needs to update the code_size when d is set...
24
+
25
+ void sync_d(Index* index) {}
26
+
27
+ void sync_d(IndexBinary* index) {
28
+ FAISS_THROW_IF_NOT(index->d % 8 == 0);
29
+ index->code_size = index->d / 8;
30
+ }
31
+
25
32
  // add translation to all valid labels
26
33
  void translate_labels(int64_t n, idx_t* labels, int64_t translation) {
27
34
  if (translation == 0)
@@ -40,20 +47,26 @@ IndexShardsTemplate<IndexT>::IndexShardsTemplate(
40
47
  idx_t d,
41
48
  bool threaded,
42
49
  bool successive_ids)
43
- : ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {}
50
+ : ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {
51
+ sync_d(this);
52
+ }
44
53
 
45
54
  template <typename IndexT>
46
55
  IndexShardsTemplate<IndexT>::IndexShardsTemplate(
47
56
  int d,
48
57
  bool threaded,
49
58
  bool successive_ids)
50
- : ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {}
59
+ : ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {
60
+ sync_d(this);
61
+ }
51
62
 
52
63
  template <typename IndexT>
53
64
  IndexShardsTemplate<IndexT>::IndexShardsTemplate(
54
65
  bool threaded,
55
66
  bool successive_ids)
56
- : ThreadedIndex<IndexT>(threaded), successive_ids(successive_ids) {}
67
+ : ThreadedIndex<IndexT>(threaded), successive_ids(successive_ids) {
68
+ sync_d(this);
69
+ }
57
70
 
58
71
  template <typename IndexT>
59
72
  void IndexShardsTemplate<IndexT>::onAfterAddIndex(IndexT* index /* unused */) {
@@ -78,6 +91,8 @@ void IndexShardsTemplate<IndexT>::syncWithSubIndexes() {
78
91
  }
79
92
 
80
93
  auto firstIndex = this->at(0);
94
+ this->d = firstIndex->d;
95
+ sync_d(this);
81
96
  this->metric_type = firstIndex->metric_type;
82
97
  this->is_trained = firstIndex->is_trained;
83
98
  this->ntotal = firstIndex->ntotal;
@@ -92,29 +107,6 @@ void IndexShardsTemplate<IndexT>::syncWithSubIndexes() {
92
107
  }
93
108
  }
94
109
 
95
- // No metric_type for IndexBinary
96
- template <>
97
- void IndexShardsTemplate<IndexBinary>::syncWithSubIndexes() {
98
- if (!this->count()) {
99
- this->is_trained = false;
100
- this->ntotal = 0;
101
-
102
- return;
103
- }
104
-
105
- auto firstIndex = this->at(0);
106
- this->is_trained = firstIndex->is_trained;
107
- this->ntotal = firstIndex->ntotal;
108
-
109
- for (int i = 1; i < this->count(); ++i) {
110
- auto index = this->at(i);
111
- FAISS_THROW_IF_NOT(this->d == index->d);
112
- FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
113
-
114
- this->ntotal += index->ntotal;
115
- }
116
- }
117
-
118
110
  template <typename IndexT>
119
111
  void IndexShardsTemplate<IndexT>::train(idx_t n, const component_t* x) {
120
112
  auto fn = [n, x](int no, IndexT* index) {
@@ -155,7 +147,7 @@ void IndexShardsTemplate<IndexT>::add_with_ids(
155
147
  "request them to be shifted");
156
148
  FAISS_THROW_IF_NOT_MSG(
157
149
  this->ntotal == 0,
158
- "when adding to IndexShards with sucessive_ids, "
150
+ "when adding to IndexShards with successive_ids, "
159
151
  "only add() in a single pass is supported");
160
152
  }
161
153
 
@@ -111,7 +111,7 @@ void IndexShardsIVF::add_with_ids(
111
111
  "request them to be shifted");
112
112
  FAISS_THROW_IF_NOT_MSG(
113
113
  this->ntotal == 0,
114
- "when adding to IndexShards with sucessive_ids, "
114
+ "when adding to IndexShards with successive_ids, "
115
115
  "only add() in a single pass is supported");
116
116
  }
117
117
 
@@ -137,7 +137,6 @@ void IndexShardsIVF::add_with_ids(
137
137
  auto fn = [n, ids, x, nshard, d, Iq](int no, Index* index) {
138
138
  idx_t i0 = (idx_t)no * n / nshard;
139
139
  idx_t i1 = ((idx_t)no + 1) * n / nshard;
140
- const float* x0 = x + i0 * d;
141
140
  auto index_ivf = dynamic_cast<IndexIVF*>(index);
142
141
 
143
142
  if (index->verbose) {