faiss 0.2.6 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/lib/faiss.rb +2 -2
  6. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  7. data/vendor/faiss/faiss/AutoTune.h +0 -1
  8. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  9. data/vendor/faiss/faiss/Clustering.h +0 -2
  10. data/vendor/faiss/faiss/IVFlib.h +0 -2
  11. data/vendor/faiss/faiss/Index.h +1 -2
  12. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  13. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  14. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  16. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  17. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  18. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  19. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  20. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  21. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  22. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  23. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  24. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  25. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  26. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  27. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  28. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  29. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  30. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  31. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  32. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  33. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  34. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  35. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  36. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  38. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  39. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  41. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  43. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  44. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  45. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  46. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  47. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  48. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  49. data/vendor/faiss/faiss/IndexShards.h +2 -3
  50. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  51. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  52. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  53. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  54. data/vendor/faiss/faiss/MetricType.h +14 -0
  55. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  56. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  57. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  58. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  59. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  60. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  61. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  62. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  69. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  70. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  71. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  72. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  73. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  74. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  75. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  76. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  77. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  78. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  81. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  82. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  83. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  84. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  85. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  86. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  87. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  91. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  92. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  93. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  95. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  96. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  97. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  98. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  99. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  100. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  101. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  102. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  103. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  104. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  105. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  106. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  107. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  108. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  109. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  110. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  111. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  113. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  114. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  116. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  117. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  119. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  122. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  123. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  125. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  128. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  129. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  130. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  131. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  132. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  133. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  134. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  135. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  137. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  138. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  139. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  140. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  141. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  142. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  143. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  144. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  145. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  146. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  147. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  148. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  149. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  150. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  151. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  152. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  153. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  155. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  156. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  157. data/vendor/faiss/faiss/utils/distances.h +11 -0
  158. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  159. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  164. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  165. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  166. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  167. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  168. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  169. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  170. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  171. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  172. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  173. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  174. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  176. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  179. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  181. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  183. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  184. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  185. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  186. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  187. data/vendor/faiss/faiss/utils/utils.h +2 -9
  188. metadata +29 -3
  189. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -1,3 +1,10 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
1
8
  #pragma once
2
9
 
3
10
  #include <cstdint>
@@ -22,98 +22,17 @@ namespace faiss {
22
22
  // subroutines
23
23
  namespace {
24
24
 
25
- typedef Index::idx_t idx_t;
26
-
27
25
  // add translation to all valid labels
28
- void translate_labels(long n, idx_t* labels, long translation) {
26
+ void translate_labels(int64_t n, idx_t* labels, int64_t translation) {
29
27
  if (translation == 0)
30
28
  return;
31
- for (long i = 0; i < n; i++) {
29
+ for (int64_t i = 0; i < n; i++) {
32
30
  if (labels[i] < 0)
33
31
  continue;
34
32
  labels[i] += translation;
35
33
  }
36
34
  }
37
35
 
38
- /** merge result tables from several shards.
39
- * @param all_distances size nshard * n * k
40
- * @param all_labels idem
41
- * @param translartions label translations to apply, size nshard
42
- */
43
-
44
- template <class IndexClass, class C>
45
- void merge_tables(
46
- long n,
47
- long k,
48
- long nshard,
49
- typename IndexClass::distance_t* distances,
50
- idx_t* labels,
51
- const std::vector<typename IndexClass::distance_t>& all_distances,
52
- const std::vector<idx_t>& all_labels,
53
- const std::vector<long>& translations) {
54
- if (k == 0) {
55
- return;
56
- }
57
- using distance_t = typename IndexClass::distance_t;
58
-
59
- long stride = n * k;
60
- #pragma omp parallel
61
- {
62
- std::vector<int> buf(2 * nshard);
63
- int* pointer = buf.data();
64
- int* shard_ids = pointer + nshard;
65
- std::vector<distance_t> buf2(nshard);
66
- distance_t* heap_vals = buf2.data();
67
- #pragma omp for
68
- for (long i = 0; i < n; i++) {
69
- // the heap maps values to the shard where they are
70
- // produced.
71
- const distance_t* D_in = all_distances.data() + i * k;
72
- const idx_t* I_in = all_labels.data() + i * k;
73
- int heap_size = 0;
74
-
75
- for (long s = 0; s < nshard; s++) {
76
- pointer[s] = 0;
77
- if (I_in[stride * s] >= 0) {
78
- heap_push<C>(
79
- ++heap_size,
80
- heap_vals,
81
- shard_ids,
82
- D_in[stride * s],
83
- s);
84
- }
85
- }
86
-
87
- distance_t* D = distances + i * k;
88
- idx_t* I = labels + i * k;
89
-
90
- for (int j = 0; j < k; j++) {
91
- if (heap_size == 0) {
92
- I[j] = -1;
93
- D[j] = C::neutral();
94
- } else {
95
- // pop best element
96
- int s = shard_ids[0];
97
- int& p = pointer[s];
98
- D[j] = heap_vals[0];
99
- I[j] = I_in[stride * s + p] + translations[s];
100
-
101
- heap_pop<C>(heap_size--, heap_vals, shard_ids);
102
- p++;
103
- if (p < k && I_in[stride * s + p] >= 0) {
104
- heap_push<C>(
105
- ++heap_size,
106
- heap_vals,
107
- shard_ids,
108
- D_in[stride * s + p],
109
- s);
110
- }
111
- }
112
- }
113
- }
114
- }
115
- }
116
-
117
36
  } // anonymous namespace
118
37
 
119
38
  template <typename IndexT>
@@ -247,11 +166,9 @@ void IndexShardsTemplate<IndexT>::add_with_ids(
247
166
 
248
167
  if (!ids && !successive_ids) {
249
168
  aids.resize(n);
250
-
251
169
  for (idx_t i = 0; i < n; i++) {
252
170
  aids[i] = this->ntotal + i;
253
171
  }
254
-
255
172
  ids = aids.data();
256
173
  }
257
174
 
@@ -294,12 +211,23 @@ void IndexShardsTemplate<IndexT>::search(
294
211
  !params, "search params not supported for this index");
295
212
  FAISS_THROW_IF_NOT(k > 0);
296
213
 
297
- long nshard = this->count();
214
+ int64_t nshard = this->count();
298
215
 
299
216
  std::vector<distance_t> all_distances(nshard * k * n);
300
217
  std::vector<idx_t> all_labels(nshard * k * n);
218
+ std::vector<int64_t> translations(nshard, 0);
219
+
220
+ // Because we just called runOnIndex above, it is safe to access the
221
+ // sub-index ntotal here
222
+ if (successive_ids) {
223
+ translations[0] = 0;
301
224
 
302
- auto fn = [n, k, x, &all_distances, &all_labels](
225
+ for (int s = 0; s + 1 < nshard; s++) {
226
+ translations[s + 1] = translations[s] + this->at(s)->ntotal;
227
+ }
228
+ }
229
+
230
+ auto fn = [n, k, x, &all_distances, &all_labels, &translations](
303
231
  int no, const IndexT* index) {
304
232
  if (index->verbose) {
305
233
  printf("begin query shard %d on %" PRId64 " points\n", no, n);
@@ -312,6 +240,9 @@ void IndexShardsTemplate<IndexT>::search(
312
240
  all_distances.data() + no * k * n,
313
241
  all_labels.data() + no * k * n);
314
242
 
243
+ translate_labels(
244
+ n * k, all_labels.data() + no * k * n, translations[no]);
245
+
315
246
  if (index->verbose) {
316
247
  printf("end query shard %d\n", no);
317
248
  }
@@ -319,38 +250,24 @@ void IndexShardsTemplate<IndexT>::search(
319
250
 
320
251
  this->runOnIndex(fn);
321
252
 
322
- std::vector<long> translations(nshard, 0);
323
-
324
- // Because we just called runOnIndex above, it is safe to access the
325
- // sub-index ntotal here
326
- if (successive_ids) {
327
- translations[0] = 0;
328
-
329
- for (int s = 0; s + 1 < nshard; s++) {
330
- translations[s + 1] = translations[s] + this->at(s)->ntotal;
331
- }
332
- }
333
-
334
253
  if (this->metric_type == METRIC_L2) {
335
- merge_tables<IndexT, CMin<distance_t, int>>(
254
+ merge_knn_results<idx_t, CMin<distance_t, int>>(
336
255
  n,
337
256
  k,
338
257
  nshard,
258
+ all_distances.data(),
259
+ all_labels.data(),
339
260
  distances,
340
- labels,
341
- all_distances,
342
- all_labels,
343
- translations);
261
+ labels);
344
262
  } else {
345
- merge_tables<IndexT, CMax<distance_t, int>>(
263
+ merge_knn_results<idx_t, CMax<distance_t, int>>(
346
264
  n,
347
265
  k,
348
266
  nshard,
267
+ all_distances.data(),
268
+ all_labels.data(),
349
269
  distances,
350
- labels,
351
- all_distances,
352
- all_labels,
353
- translations);
270
+ labels);
354
271
  }
355
272
  }
356
273
 
@@ -18,7 +18,6 @@ namespace faiss {
18
18
  */
19
19
  template <typename IndexT>
20
20
  struct IndexShardsTemplate : public ThreadedIndex<IndexT> {
21
- using idx_t = typename IndexT::idx_t;
22
21
  using component_t = typename IndexT::component_t;
23
22
  using distance_t = typename IndexT::distance_t;
24
23
 
@@ -72,7 +71,7 @@ struct IndexShardsTemplate : public ThreadedIndex<IndexT> {
72
71
  * Cases (successive_ids, xids):
73
72
  * - true, non-NULL ERROR: it makes no sense to pass in ids and
74
73
  * request them to be shifted
75
- * - true, NULL OK, but should be called only once (calls add()
74
+ * - true, NULL OK: but should be called only once (calls add()
76
75
  * on sub-indexes).
77
76
  * - false, non-NULL OK: will call add_with_ids with passed in xids
78
77
  * distributed evenly over shards
@@ -96,7 +95,7 @@ struct IndexShardsTemplate : public ThreadedIndex<IndexT> {
96
95
 
97
96
  /// Synchronize the top-level index (IndexShards) with data in the
98
97
  /// sub-indices
99
- void syncWithSubIndexes();
98
+ virtual void syncWithSubIndexes();
100
99
 
101
100
  protected:
102
101
  /// Called just after an index is added
@@ -0,0 +1,246 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/IndexShardsIVF.h>
11
+
12
+ #include <cinttypes>
13
+ #include <cstdio>
14
+ #include <functional>
15
+
16
+ #include <faiss/impl/FaissAssert.h>
17
+ #include <faiss/utils/Heap.h>
18
+ #include <faiss/utils/WorkerThread.h>
19
+ #include <faiss/utils/utils.h>
20
+
21
+ namespace faiss {
22
+
23
+ // subroutines
24
+ namespace {
25
+
26
+ // add translation to all valid labels
27
+ void translate_labels(int64_t n, idx_t* labels, int64_t translation) {
28
+ if (translation == 0) {
29
+ return;
30
+ }
31
+ for (int64_t i = 0; i < n; i++) {
32
+ if (labels[i] < 0) {
33
+ continue;
34
+ }
35
+ labels[i] += translation;
36
+ }
37
+ }
38
+
39
+ } // anonymous namespace
40
+
41
+ /************************************************************
42
+ * IndexShardsIVF
43
+ ************************************************************/
44
+
45
+ IndexShardsIVF::IndexShardsIVF(
46
+ Index* quantizer,
47
+ size_t nlist,
48
+ bool threaded,
49
+ bool successive_ids)
50
+ : IndexShardsTemplate<Index>(quantizer->d, threaded, successive_ids),
51
+ Level1Quantizer(quantizer, nlist) {
52
+ is_trained = quantizer->is_trained && quantizer->ntotal == nlist;
53
+ }
54
+
55
+ void IndexShardsIVF::addIndex(Index* index) {
56
+ auto index_ivf = dynamic_cast<IndexIVFInterface*>(index);
57
+ FAISS_THROW_IF_NOT_MSG(index_ivf, "can only add IndexIVFs");
58
+ FAISS_THROW_IF_NOT(index_ivf->nlist == nlist);
59
+ IndexShardsTemplate<Index>::addIndex(index);
60
+ }
61
+
62
+ void IndexShardsIVF::train(idx_t n, const component_t* x) {
63
+ if (verbose) {
64
+ printf("Training level-1 quantizer\n");
65
+ }
66
+ train_q1(n, x, verbose, metric_type);
67
+
68
+ // set the sub-quantizer codebooks
69
+ std::vector<float> centroids(nlist * d);
70
+ quantizer->reconstruct_n(0, nlist, centroids.data());
71
+
72
+ // probably not worth running in parallel
73
+ for (size_t i = 0; i < indices_.size(); i++) {
74
+ Index* index = indices_[i].first;
75
+ auto index_ivf = dynamic_cast<IndexIVFInterface*>(index);
76
+ Index* quantizer = index_ivf->quantizer;
77
+ if (!quantizer->is_trained) {
78
+ quantizer->train(nlist, centroids.data());
79
+ }
80
+ quantizer->add(nlist, centroids.data());
81
+ // finish training
82
+ index->train(n, x);
83
+ }
84
+
85
+ is_trained = true;
86
+ }
87
+
88
+ void IndexShardsIVF::add_with_ids(
89
+ idx_t n,
90
+ const component_t* x,
91
+ const idx_t* xids) {
92
+ // IndexIVF exposes add_core that we can use to factorize the
93
+ bool all_index_ivf = true;
94
+ for (size_t i = 0; i < indices_.size(); i++) {
95
+ Index* index = indices_[i].first;
96
+ all_index_ivf = all_index_ivf && dynamic_cast<IndexIVF*>(index);
97
+ }
98
+ if (!all_index_ivf) {
99
+ IndexShardsTemplate<Index>::add_with_ids(n, x, xids);
100
+ return;
101
+ }
102
+ FAISS_THROW_IF_NOT_MSG(
103
+ !(successive_ids && xids),
104
+ "It makes no sense to pass in ids and "
105
+ "request them to be shifted");
106
+
107
+ if (successive_ids) {
108
+ FAISS_THROW_IF_NOT_MSG(
109
+ !xids,
110
+ "It makes no sense to pass in ids and "
111
+ "request them to be shifted");
112
+ FAISS_THROW_IF_NOT_MSG(
113
+ this->ntotal == 0,
114
+ "when adding to IndexShards with sucessive_ids, "
115
+ "only add() in a single pass is supported");
116
+ }
117
+
118
+ // perform coarse quantization
119
+ std::vector<idx_t> Iq(n);
120
+ std::vector<float> Dq(n);
121
+ quantizer->search(n, x, 1, Dq.data(), Iq.data());
122
+
123
+ // possibly shift ids
124
+ idx_t nshard = this->count();
125
+ const idx_t* ids = xids;
126
+ std::vector<idx_t> aids;
127
+ if (!ids && !successive_ids) {
128
+ aids.resize(n);
129
+
130
+ for (idx_t i = 0; i < n; i++) {
131
+ aids[i] = this->ntotal + i;
132
+ }
133
+ ids = aids.data();
134
+ }
135
+ idx_t d = this->d;
136
+
137
+ auto fn = [n, ids, x, nshard, d, Iq](int no, Index* index) {
138
+ idx_t i0 = (idx_t)no * n / nshard;
139
+ idx_t i1 = ((idx_t)no + 1) * n / nshard;
140
+ const float* x0 = x + i0 * d;
141
+ auto index_ivf = dynamic_cast<IndexIVF*>(index);
142
+
143
+ if (index->verbose) {
144
+ printf("begin add shard %d on %" PRId64 " points\n", no, n);
145
+ }
146
+
147
+ index_ivf->add_core(
148
+ i1 - i0, x + i0 * d, ids ? ids + i0 : nullptr, Iq.data() + i0);
149
+
150
+ if (index->verbose) {
151
+ printf("end add shard %d on %" PRId64 " points\n", no, i1 - i0);
152
+ }
153
+ };
154
+
155
+ this->runOnIndex(fn);
156
+ syncWithSubIndexes();
157
+ }
158
+
159
+ void IndexShardsIVF::search(
160
+ idx_t n,
161
+ const component_t* x,
162
+ idx_t k,
163
+ distance_t* distances,
164
+ idx_t* labels,
165
+ const SearchParameters* params_in) const {
166
+ FAISS_THROW_IF_NOT(k > 0);
167
+ FAISS_THROW_IF_NOT(count() > 0);
168
+ const IVFSearchParameters* params = nullptr;
169
+ if (params_in) {
170
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
171
+ FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
172
+ }
173
+
174
+ auto index0 = dynamic_cast<const IndexIVFInterface*>(at(0));
175
+ idx_t nprobe = params ? params->nprobe : index0->nprobe;
176
+
177
+ // coarse quantization (TODO: support tiling with search_precomputed)
178
+ std::vector<distance_t> Dq(n * nprobe);
179
+ std::vector<idx_t> Iq(n * nprobe);
180
+
181
+ quantizer->search(n, x, nprobe, Dq.data(), Iq.data());
182
+
183
+ int64_t nshard = this->count();
184
+
185
+ std::vector<distance_t> all_distances(nshard * k * n);
186
+ std::vector<idx_t> all_labels(nshard * k * n);
187
+ std::vector<int64_t> translations(nshard, 0);
188
+
189
+ if (successive_ids) {
190
+ translations[0] = 0;
191
+ for (int s = 0; s + 1 < nshard; s++) {
192
+ translations[s + 1] = translations[s] + this->at(s)->ntotal;
193
+ }
194
+ }
195
+
196
+ auto fn = [&](int no, const Index* indexIn) {
197
+ if (indexIn->verbose) {
198
+ printf("begin query shard %d on %" PRId64 " points\n", no, n);
199
+ }
200
+
201
+ auto index = dynamic_cast<const IndexIVFInterface*>(indexIn);
202
+
203
+ FAISS_THROW_IF_NOT_MSG(index->nprobe == nprobe, "inconsistent nprobe");
204
+
205
+ index->search_preassigned(
206
+ n,
207
+ x,
208
+ k,
209
+ Iq.data(),
210
+ Dq.data(),
211
+ all_distances.data() + no * k * n,
212
+ all_labels.data() + no * k * n,
213
+ false);
214
+
215
+ translate_labels(
216
+ n * k, all_labels.data() + no * k * n, translations[no]);
217
+
218
+ if (indexIn->verbose) {
219
+ printf("end query shard %d\n", no);
220
+ }
221
+ };
222
+
223
+ this->runOnIndex(fn);
224
+
225
+ if (this->metric_type == METRIC_L2) {
226
+ merge_knn_results<idx_t, CMin<distance_t, int>>(
227
+ n,
228
+ k,
229
+ nshard,
230
+ all_distances.data(),
231
+ all_labels.data(),
232
+ distances,
233
+ labels);
234
+ } else {
235
+ merge_knn_results<idx_t, CMax<distance_t, int>>(
236
+ n,
237
+ k,
238
+ nshard,
239
+ all_distances.data(),
240
+ all_labels.data(),
241
+ distances,
242
+ labels);
243
+ }
244
+ }
245
+
246
+ } // namespace faiss
@@ -0,0 +1,42 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/IndexIVF.h>
11
+ #include <faiss/IndexShards.h>
12
+
13
+ namespace faiss {
14
+
15
+ /**
16
+ * IndexShards with a common coarse quantizer. All the indexes added should be
17
+ * IndexIVFInterface indexes so that the search_precomputed can be called.
18
+ */
19
+ struct IndexShardsIVF : public IndexShards, Level1Quantizer {
20
+ explicit IndexShardsIVF(
21
+ Index* quantizer,
22
+ size_t nlist,
23
+ bool threaded = false,
24
+ bool successive_ids = true);
25
+
26
+ void addIndex(Index* index) override;
27
+
28
+ void add_with_ids(idx_t n, const component_t* x, const idx_t* xids)
29
+ override;
30
+
31
+ void train(idx_t n, const component_t* x) override;
32
+
33
+ void search(
34
+ idx_t n,
35
+ const component_t* x,
36
+ idx_t k,
37
+ distance_t* distances,
38
+ idx_t* labels,
39
+ const SearchParameters* params = nullptr) const override;
40
+ };
41
+
42
+ } // namespace faiss
@@ -19,6 +19,8 @@
19
19
  #include <faiss/impl/IDSelector.h>
20
20
  #include <faiss/utils/Heap.h>
21
21
  #include <faiss/utils/WorkerThread.h>
22
+ #include <faiss/utils/random.h>
23
+ #include <faiss/utils/utils.h>
22
24
 
23
25
  namespace faiss {
24
26
 
@@ -154,4 +156,88 @@ IndexSplitVectors::~IndexSplitVectors() {
154
156
  }
155
157
  }
156
158
 
159
+ /********************************************************
160
+ * IndexRandom implementation
161
+ */
162
+
163
+ IndexRandom::IndexRandom(
164
+ idx_t d,
165
+ idx_t ntotal,
166
+ int64_t seed,
167
+ MetricType metric_type)
168
+ : Index(d, metric_type), seed(seed) {
169
+ this->ntotal = ntotal;
170
+ is_trained = true;
171
+ }
172
+
173
+ void IndexRandom::add(idx_t n, const float*) {
174
+ ntotal += n;
175
+ }
176
+
177
+ void IndexRandom::search(
178
+ idx_t n,
179
+ const float* x,
180
+ idx_t k,
181
+ float* distances,
182
+ idx_t* labels,
183
+ const SearchParameters* params) const {
184
+ FAISS_THROW_IF_NOT_MSG(
185
+ !params, "search params not supported for this index");
186
+ FAISS_THROW_IF_NOT(k <= ntotal);
187
+ #pragma omp parallel for if (n > 1000)
188
+ for (idx_t i = 0; i < n; i++) {
189
+ RandomGenerator rng(
190
+ seed + ivec_checksum(d, (const int32_t*)(x + i * d)));
191
+ idx_t* I = labels + i * k;
192
+ float* D = distances + i * k;
193
+ // assumes k << ntotal
194
+ if (k < 100 * ntotal) {
195
+ std::unordered_set<idx_t> map;
196
+ for (int j = 0; j < k; j++) {
197
+ idx_t ii;
198
+ for (;;) {
199
+ // yes I know it's not strictly uniform...
200
+ ii = rng.rand_int64() % ntotal;
201
+ if (map.count(ii) == 0) {
202
+ break;
203
+ }
204
+ }
205
+ I[j] = ii;
206
+ map.insert(ii);
207
+ }
208
+ } else {
209
+ std::vector<idx_t> perm(ntotal);
210
+ for (idx_t j = 0; j < ntotal; j++) {
211
+ perm[j] = j;
212
+ }
213
+ for (int j = 0; j < k; j++) {
214
+ std::swap(perm[j], perm[rng.rand_int(ntotal)]);
215
+ I[j] = perm[j];
216
+ }
217
+ }
218
+ float dprev = 0;
219
+ for (int j = 0; j < k; j++) {
220
+ float step = rng.rand_float();
221
+ if (is_similarity_metric(metric_type)) {
222
+ step = -step;
223
+ }
224
+ dprev += step;
225
+ D[j] = dprev;
226
+ }
227
+ }
228
+ }
229
+
230
+ void IndexRandom::reconstruct(idx_t key, float* recons) const {
231
+ RandomGenerator rng(seed + 123332 + key);
232
+ for (size_t i = 0; i < d; i++) {
233
+ recons[i] = rng.rand_float();
234
+ }
235
+ }
236
+
237
+ void IndexRandom::reset() {
238
+ ntotal = 0;
239
+ }
240
+
241
+ IndexRandom::~IndexRandom() {}
242
+
157
243
  } // namespace faiss
@@ -49,6 +49,35 @@ struct IndexSplitVectors : Index {
49
49
  ~IndexSplitVectors() override;
50
50
  };
51
51
 
52
+ /** index that returns random results.
53
+ * used mainly for time benchmarks
54
+ */
55
+ struct IndexRandom : Index {
56
+ int64_t seed;
57
+
58
+ explicit IndexRandom(
59
+ idx_t d,
60
+ idx_t ntotal = 0,
61
+ int64_t seed = 1234,
62
+ MetricType mt = METRIC_L2);
63
+
64
+ void add(idx_t n, const float* x) override;
65
+
66
+ void search(
67
+ idx_t n,
68
+ const float* x,
69
+ idx_t k,
70
+ float* distances,
71
+ idx_t* labels,
72
+ const SearchParameters* params = nullptr) const override;
73
+
74
+ void reconstruct(idx_t key, float* recons) const override;
75
+
76
+ void reset() override;
77
+
78
+ ~IndexRandom() override;
79
+ };
80
+
52
81
  } // namespace faiss
53
82
 
54
83
  #endif
@@ -10,6 +10,8 @@
10
10
  #ifndef FAISS_METRIC_TYPE_H
11
11
  #define FAISS_METRIC_TYPE_H
12
12
 
13
+ #include <faiss/impl/platform_macros.h>
14
+
13
15
  namespace faiss {
14
16
 
15
17
  /// The metric space for vector comparison for Faiss indices and algorithms.
@@ -29,8 +31,20 @@ enum MetricType {
29
31
  METRIC_Canberra = 20,
30
32
  METRIC_BrayCurtis,
31
33
  METRIC_JensenShannon,
34
+ METRIC_Jaccard, ///< defined as: sum_i(min(a_i, b_i)) / sum_i(max(a_i, b_i))
35
+ ///< where a_i, b_i > 0
32
36
  };
33
37
 
38
+ /// all vector indices are this type
39
+ using idx_t = int64_t;
40
+
41
+ /// this function is used to distinguish between min and max indexes since
42
+ /// we need to support similarity and dis-similarity metrics in a flexible way
43
+ constexpr bool is_similarity_metric(MetricType metric_type) {
44
+ return ((metric_type == METRIC_INNER_PRODUCT) ||
45
+ (metric_type == METRIC_Jaccard));
46
+ }
47
+
34
48
  } // namespace faiss
35
49
 
36
50
  #endif