faiss 0.2.6 → 0.2.7

Sign up to get free protection for your applications and to get access to all the features.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/lib/faiss.rb +2 -2
  6. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  7. data/vendor/faiss/faiss/AutoTune.h +0 -1
  8. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  9. data/vendor/faiss/faiss/Clustering.h +0 -2
  10. data/vendor/faiss/faiss/IVFlib.h +0 -2
  11. data/vendor/faiss/faiss/Index.h +1 -2
  12. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  13. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  14. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  16. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  17. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  18. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  19. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  20. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  21. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  22. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  23. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  24. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  25. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  26. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  27. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  28. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  29. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  30. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  31. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  32. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  33. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  34. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  35. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  36. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  38. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  39. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  41. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  43. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  44. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  45. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  46. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  47. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  48. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  49. data/vendor/faiss/faiss/IndexShards.h +2 -3
  50. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  51. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  52. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  53. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  54. data/vendor/faiss/faiss/MetricType.h +14 -0
  55. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  56. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  57. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  58. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  59. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  60. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  61. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  62. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  69. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  70. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  71. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  72. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  73. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  74. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  75. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  76. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  77. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  78. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  81. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  82. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  83. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  84. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  85. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  86. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  87. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  91. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  92. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  93. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  95. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  96. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  97. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  98. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  99. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  100. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  101. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  102. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  103. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  104. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  105. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  106. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  107. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  108. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  109. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  110. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  111. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  113. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  114. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  116. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  117. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  119. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  122. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  123. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  125. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  128. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  129. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  130. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  131. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  132. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  133. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  134. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  135. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  137. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  138. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  139. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  140. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  141. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  142. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  143. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  144. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  145. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  146. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  147. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  148. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  149. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  150. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  151. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  152. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  153. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  155. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  156. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  157. data/vendor/faiss/faiss/utils/distances.h +11 -0
  158. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  159. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  164. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  165. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  166. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  167. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  168. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  169. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  170. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  171. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  172. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  173. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  174. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  176. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  179. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  181. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  183. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  184. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  185. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  186. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  187. data/vendor/faiss/faiss/utils/utils.h +2 -9
  188. metadata +29 -3
  189. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -1,3 +1,10 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
1
8
  #pragma once
2
9
 
3
10
  #include <cstdint>
@@ -22,98 +22,17 @@ namespace faiss {
22
22
  // subroutines
23
23
  namespace {
24
24
 
25
- typedef Index::idx_t idx_t;
26
-
27
25
  // add translation to all valid labels
28
- void translate_labels(long n, idx_t* labels, long translation) {
26
+ void translate_labels(int64_t n, idx_t* labels, int64_t translation) {
29
27
  if (translation == 0)
30
28
  return;
31
- for (long i = 0; i < n; i++) {
29
+ for (int64_t i = 0; i < n; i++) {
32
30
  if (labels[i] < 0)
33
31
  continue;
34
32
  labels[i] += translation;
35
33
  }
36
34
  }
37
35
 
38
- /** merge result tables from several shards.
39
- * @param all_distances size nshard * n * k
40
- * @param all_labels idem
41
- * @param translartions label translations to apply, size nshard
42
- */
43
-
44
- template <class IndexClass, class C>
45
- void merge_tables(
46
- long n,
47
- long k,
48
- long nshard,
49
- typename IndexClass::distance_t* distances,
50
- idx_t* labels,
51
- const std::vector<typename IndexClass::distance_t>& all_distances,
52
- const std::vector<idx_t>& all_labels,
53
- const std::vector<long>& translations) {
54
- if (k == 0) {
55
- return;
56
- }
57
- using distance_t = typename IndexClass::distance_t;
58
-
59
- long stride = n * k;
60
- #pragma omp parallel
61
- {
62
- std::vector<int> buf(2 * nshard);
63
- int* pointer = buf.data();
64
- int* shard_ids = pointer + nshard;
65
- std::vector<distance_t> buf2(nshard);
66
- distance_t* heap_vals = buf2.data();
67
- #pragma omp for
68
- for (long i = 0; i < n; i++) {
69
- // the heap maps values to the shard where they are
70
- // produced.
71
- const distance_t* D_in = all_distances.data() + i * k;
72
- const idx_t* I_in = all_labels.data() + i * k;
73
- int heap_size = 0;
74
-
75
- for (long s = 0; s < nshard; s++) {
76
- pointer[s] = 0;
77
- if (I_in[stride * s] >= 0) {
78
- heap_push<C>(
79
- ++heap_size,
80
- heap_vals,
81
- shard_ids,
82
- D_in[stride * s],
83
- s);
84
- }
85
- }
86
-
87
- distance_t* D = distances + i * k;
88
- idx_t* I = labels + i * k;
89
-
90
- for (int j = 0; j < k; j++) {
91
- if (heap_size == 0) {
92
- I[j] = -1;
93
- D[j] = C::neutral();
94
- } else {
95
- // pop best element
96
- int s = shard_ids[0];
97
- int& p = pointer[s];
98
- D[j] = heap_vals[0];
99
- I[j] = I_in[stride * s + p] + translations[s];
100
-
101
- heap_pop<C>(heap_size--, heap_vals, shard_ids);
102
- p++;
103
- if (p < k && I_in[stride * s + p] >= 0) {
104
- heap_push<C>(
105
- ++heap_size,
106
- heap_vals,
107
- shard_ids,
108
- D_in[stride * s + p],
109
- s);
110
- }
111
- }
112
- }
113
- }
114
- }
115
- }
116
-
117
36
  } // anonymous namespace
118
37
 
119
38
  template <typename IndexT>
@@ -247,11 +166,9 @@ void IndexShardsTemplate<IndexT>::add_with_ids(
247
166
 
248
167
  if (!ids && !successive_ids) {
249
168
  aids.resize(n);
250
-
251
169
  for (idx_t i = 0; i < n; i++) {
252
170
  aids[i] = this->ntotal + i;
253
171
  }
254
-
255
172
  ids = aids.data();
256
173
  }
257
174
 
@@ -294,12 +211,23 @@ void IndexShardsTemplate<IndexT>::search(
294
211
  !params, "search params not supported for this index");
295
212
  FAISS_THROW_IF_NOT(k > 0);
296
213
 
297
- long nshard = this->count();
214
+ int64_t nshard = this->count();
298
215
 
299
216
  std::vector<distance_t> all_distances(nshard * k * n);
300
217
  std::vector<idx_t> all_labels(nshard * k * n);
218
+ std::vector<int64_t> translations(nshard, 0);
219
+
220
+ // Read each sub-index ntotal to build the label offsets. NOTE(review): this
221
+ // now runs before runOnIndex below; confirm no shard is being modified
222
+ if (successive_ids) {
223
+ translations[0] = 0;
301
224
 
302
- auto fn = [n, k, x, &all_distances, &all_labels](
225
+ for (int s = 0; s + 1 < nshard; s++) {
226
+ translations[s + 1] = translations[s] + this->at(s)->ntotal;
227
+ }
228
+ }
229
+
230
+ auto fn = [n, k, x, &all_distances, &all_labels, &translations](
303
231
  int no, const IndexT* index) {
304
232
  if (index->verbose) {
305
233
  printf("begin query shard %d on %" PRId64 " points\n", no, n);
@@ -312,6 +240,9 @@ void IndexShardsTemplate<IndexT>::search(
312
240
  all_distances.data() + no * k * n,
313
241
  all_labels.data() + no * k * n);
314
242
 
243
+ translate_labels(
244
+ n * k, all_labels.data() + no * k * n, translations[no]);
245
+
315
246
  if (index->verbose) {
316
247
  printf("end query shard %d\n", no);
317
248
  }
@@ -319,38 +250,24 @@ void IndexShardsTemplate<IndexT>::search(
319
250
 
320
251
  this->runOnIndex(fn);
321
252
 
322
- std::vector<long> translations(nshard, 0);
323
-
324
- // Because we just called runOnIndex above, it is safe to access the
325
- // sub-index ntotal here
326
- if (successive_ids) {
327
- translations[0] = 0;
328
-
329
- for (int s = 0; s + 1 < nshard; s++) {
330
- translations[s + 1] = translations[s] + this->at(s)->ntotal;
331
- }
332
- }
333
-
334
253
  if (this->metric_type == METRIC_L2) {
335
- merge_tables<IndexT, CMin<distance_t, int>>(
254
+ merge_knn_results<idx_t, CMin<distance_t, int>>(
336
255
  n,
337
256
  k,
338
257
  nshard,
258
+ all_distances.data(),
259
+ all_labels.data(),
339
260
  distances,
340
- labels,
341
- all_distances,
342
- all_labels,
343
- translations);
261
+ labels);
344
262
  } else {
345
- merge_tables<IndexT, CMax<distance_t, int>>(
263
+ merge_knn_results<idx_t, CMax<distance_t, int>>(
346
264
  n,
347
265
  k,
348
266
  nshard,
267
+ all_distances.data(),
268
+ all_labels.data(),
349
269
  distances,
350
- labels,
351
- all_distances,
352
- all_labels,
353
- translations);
270
+ labels);
354
271
  }
355
272
  }
356
273
 
@@ -18,7 +18,6 @@ namespace faiss {
18
18
  */
19
19
  template <typename IndexT>
20
20
  struct IndexShardsTemplate : public ThreadedIndex<IndexT> {
21
- using idx_t = typename IndexT::idx_t;
22
21
  using component_t = typename IndexT::component_t;
23
22
  using distance_t = typename IndexT::distance_t;
24
23
 
@@ -72,7 +71,7 @@ struct IndexShardsTemplate : public ThreadedIndex<IndexT> {
72
71
  * Cases (successive_ids, xids):
73
72
  * - true, non-NULL ERROR: it makes no sense to pass in ids and
74
73
  * request them to be shifted
75
- * - true, NULL OK, but should be called only once (calls add()
74
+ * - true, NULL OK: but should be called only once (calls add()
76
75
  * on sub-indexes).
77
76
  * - false, non-NULL OK: will call add_with_ids with passed in xids
78
77
  * distributed evenly over shards
@@ -96,7 +95,7 @@ struct IndexShardsTemplate : public ThreadedIndex<IndexT> {
96
95
 
97
96
  /// Synchronize the top-level index (IndexShards) with data in the
98
97
  /// sub-indices
99
- void syncWithSubIndexes();
98
+ virtual void syncWithSubIndexes();
100
99
 
101
100
  protected:
102
101
  /// Called just after an index is added
@@ -0,0 +1,246 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/IndexShardsIVF.h>
11
+
12
+ #include <cinttypes>
13
+ #include <cstdio>
14
+ #include <functional>
15
+
16
+ #include <faiss/impl/FaissAssert.h>
17
+ #include <faiss/utils/Heap.h>
18
+ #include <faiss/utils/WorkerThread.h>
19
+ #include <faiss/utils/utils.h>
20
+
21
+ namespace faiss {
22
+
23
+ // subroutines
24
+ namespace {
25
+
26
+ // add translation to all valid labels
27
+ void translate_labels(int64_t n, idx_t* labels, int64_t translation) {
28
+ if (translation == 0) {
29
+ return;
30
+ }
31
+ for (int64_t i = 0; i < n; i++) {
32
+ if (labels[i] < 0) {
33
+ continue;
34
+ }
35
+ labels[i] += translation;
36
+ }
37
+ }
38
+
39
+ } // anonymous namespace
40
+
41
+ /************************************************************
42
+ * IndexShardsIVF
43
+ ************************************************************/
44
+
45
+ IndexShardsIVF::IndexShardsIVF(
46
+ Index* quantizer,
47
+ size_t nlist,
48
+ bool threaded,
49
+ bool successive_ids)
50
+ : IndexShardsTemplate<Index>(quantizer->d, threaded, successive_ids),
51
+ Level1Quantizer(quantizer, nlist) {
52
+ is_trained = quantizer->is_trained && quantizer->ntotal == nlist;
53
+ }
54
+
55
+ void IndexShardsIVF::addIndex(Index* index) {
56
+ auto index_ivf = dynamic_cast<IndexIVFInterface*>(index);
57
+ FAISS_THROW_IF_NOT_MSG(index_ivf, "can only add IndexIVFs");
58
+ FAISS_THROW_IF_NOT(index_ivf->nlist == nlist);
59
+ IndexShardsTemplate<Index>::addIndex(index);
60
+ }
61
+
62
+ void IndexShardsIVF::train(idx_t n, const component_t* x) {
63
+ if (verbose) {
64
+ printf("Training level-1 quantizer\n");
65
+ }
66
+ train_q1(n, x, verbose, metric_type);
67
+
68
+ // set the sub-quantizer codebooks
69
+ std::vector<float> centroids(nlist * d);
70
+ quantizer->reconstruct_n(0, nlist, centroids.data());
71
+
72
+ // probably not worth running in parallel
73
+ for (size_t i = 0; i < indices_.size(); i++) {
74
+ Index* index = indices_[i].first;
75
+ auto index_ivf = dynamic_cast<IndexIVFInterface*>(index);
76
+ Index* quantizer = index_ivf->quantizer;
77
+ if (!quantizer->is_trained) {
78
+ quantizer->train(nlist, centroids.data());
79
+ }
80
+ quantizer->add(nlist, centroids.data());
81
+ // finish training
82
+ index->train(n, x);
83
+ }
84
+
85
+ is_trained = true;
86
+ }
87
+
88
+ void IndexShardsIVF::add_with_ids(
89
+ idx_t n,
90
+ const component_t* x,
91
+ const idx_t* xids) {
92
+ // IndexIVF exposes add_core, which lets us factor out the coarse quantization: it is run once here and reused by every shard
93
+ bool all_index_ivf = true;
94
+ for (size_t i = 0; i < indices_.size(); i++) {
95
+ Index* index = indices_[i].first;
96
+ all_index_ivf = all_index_ivf && dynamic_cast<IndexIVF*>(index);
97
+ }
98
+ if (!all_index_ivf) {
99
+ IndexShardsTemplate<Index>::add_with_ids(n, x, xids);
100
+ return;
101
+ }
102
+ FAISS_THROW_IF_NOT_MSG(
103
+ !(successive_ids && xids),
104
+ "It makes no sense to pass in ids and "
105
+ "request them to be shifted");
106
+
107
+ if (successive_ids) {
108
+ FAISS_THROW_IF_NOT_MSG(
109
+ !xids,
110
+ "It makes no sense to pass in ids and "
111
+ "request them to be shifted");
112
+ FAISS_THROW_IF_NOT_MSG(
113
+ this->ntotal == 0,
114
+ "when adding to IndexShards with sucessive_ids, "
115
+ "only add() in a single pass is supported");
116
+ }
117
+
118
+ // perform coarse quantization
119
+ std::vector<idx_t> Iq(n);
120
+ std::vector<float> Dq(n);
121
+ quantizer->search(n, x, 1, Dq.data(), Iq.data());
122
+
123
+ // possibly shift ids
124
+ idx_t nshard = this->count();
125
+ const idx_t* ids = xids;
126
+ std::vector<idx_t> aids;
127
+ if (!ids && !successive_ids) {
128
+ aids.resize(n);
129
+
130
+ for (idx_t i = 0; i < n; i++) {
131
+ aids[i] = this->ntotal + i;
132
+ }
133
+ ids = aids.data();
134
+ }
135
+ idx_t d = this->d;
136
+
137
+ auto fn = [n, ids, x, nshard, d, Iq](int no, Index* index) {
138
+ idx_t i0 = (idx_t)no * n / nshard;
139
+ idx_t i1 = ((idx_t)no + 1) * n / nshard;
140
+ const float* x0 = x + i0 * d;
141
+ auto index_ivf = dynamic_cast<IndexIVF*>(index);
142
+
143
+ if (index->verbose) {
144
+ printf("begin add shard %d on %" PRId64 " points\n", no, n);
145
+ }
146
+
147
+ index_ivf->add_core(
148
+ i1 - i0, x + i0 * d, ids ? ids + i0 : nullptr, Iq.data() + i0);
149
+
150
+ if (index->verbose) {
151
+ printf("end add shard %d on %" PRId64 " points\n", no, i1 - i0);
152
+ }
153
+ };
154
+
155
+ this->runOnIndex(fn);
156
+ syncWithSubIndexes();
157
+ }
158
+
159
+ void IndexShardsIVF::search(
160
+ idx_t n,
161
+ const component_t* x,
162
+ idx_t k,
163
+ distance_t* distances,
164
+ idx_t* labels,
165
+ const SearchParameters* params_in) const {
166
+ FAISS_THROW_IF_NOT(k > 0);
167
+ FAISS_THROW_IF_NOT(count() > 0);
168
+ const IVFSearchParameters* params = nullptr;
169
+ if (params_in) {
170
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
171
+ FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
172
+ }
173
+
174
+ auto index0 = dynamic_cast<const IndexIVFInterface*>(at(0));
175
+ idx_t nprobe = params ? params->nprobe : index0->nprobe;
176
+
177
+ // coarse quantization (TODO: support tiling with search_preassigned)
178
+ std::vector<distance_t> Dq(n * nprobe);
179
+ std::vector<idx_t> Iq(n * nprobe);
180
+
181
+ quantizer->search(n, x, nprobe, Dq.data(), Iq.data());
182
+
183
+ int64_t nshard = this->count();
184
+
185
+ std::vector<distance_t> all_distances(nshard * k * n);
186
+ std::vector<idx_t> all_labels(nshard * k * n);
187
+ std::vector<int64_t> translations(nshard, 0);
188
+
189
+ if (successive_ids) {
190
+ translations[0] = 0;
191
+ for (int s = 0; s + 1 < nshard; s++) {
192
+ translations[s + 1] = translations[s] + this->at(s)->ntotal;
193
+ }
194
+ }
195
+
196
+ auto fn = [&](int no, const Index* indexIn) {
197
+ if (indexIn->verbose) {
198
+ printf("begin query shard %d on %" PRId64 " points\n", no, n);
199
+ }
200
+
201
+ auto index = dynamic_cast<const IndexIVFInterface*>(indexIn);
202
+
203
+ FAISS_THROW_IF_NOT_MSG(index->nprobe == nprobe, "inconsistent nprobe");
204
+
205
+ index->search_preassigned(
206
+ n,
207
+ x,
208
+ k,
209
+ Iq.data(),
210
+ Dq.data(),
211
+ all_distances.data() + no * k * n,
212
+ all_labels.data() + no * k * n,
213
+ false);
214
+
215
+ translate_labels(
216
+ n * k, all_labels.data() + no * k * n, translations[no]);
217
+
218
+ if (indexIn->verbose) {
219
+ printf("end query shard %d\n", no);
220
+ }
221
+ };
222
+
223
+ this->runOnIndex(fn);
224
+
225
+ if (this->metric_type == METRIC_L2) {
226
+ merge_knn_results<idx_t, CMin<distance_t, int>>(
227
+ n,
228
+ k,
229
+ nshard,
230
+ all_distances.data(),
231
+ all_labels.data(),
232
+ distances,
233
+ labels);
234
+ } else {
235
+ merge_knn_results<idx_t, CMax<distance_t, int>>(
236
+ n,
237
+ k,
238
+ nshard,
239
+ all_distances.data(),
240
+ all_labels.data(),
241
+ distances,
242
+ labels);
243
+ }
244
+ }
245
+
246
+ } // namespace faiss
@@ -0,0 +1,42 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/IndexIVF.h>
11
+ #include <faiss/IndexShards.h>
12
+
13
+ namespace faiss {
14
+
15
+ /**
16
+ * IndexShards with a common coarse quantizer. All the indexes added should be
17
+ * IndexIVFInterface indexes so that search_preassigned can be called.
18
+ */
19
+ struct IndexShardsIVF : public IndexShards, Level1Quantizer {
20
+ explicit IndexShardsIVF(
21
+ Index* quantizer,
22
+ size_t nlist,
23
+ bool threaded = false,
24
+ bool successive_ids = true);
25
+
26
+ void addIndex(Index* index) override;
27
+
28
+ void add_with_ids(idx_t n, const component_t* x, const idx_t* xids)
29
+ override;
30
+
31
+ void train(idx_t n, const component_t* x) override;
32
+
33
+ void search(
34
+ idx_t n,
35
+ const component_t* x,
36
+ idx_t k,
37
+ distance_t* distances,
38
+ idx_t* labels,
39
+ const SearchParameters* params = nullptr) const override;
40
+ };
41
+
42
+ } // namespace faiss
@@ -19,6 +19,8 @@
19
19
  #include <faiss/impl/IDSelector.h>
20
20
  #include <faiss/utils/Heap.h>
21
21
  #include <faiss/utils/WorkerThread.h>
22
+ #include <faiss/utils/random.h>
23
+ #include <faiss/utils/utils.h>
22
24
 
23
25
  namespace faiss {
24
26
 
@@ -154,4 +156,88 @@ IndexSplitVectors::~IndexSplitVectors() {
154
156
  }
155
157
  }
156
158
 
159
+ /********************************************************
160
+ * IndexRandom implementation
161
+ */
162
+
163
+ IndexRandom::IndexRandom(
164
+ idx_t d,
165
+ idx_t ntotal,
166
+ int64_t seed,
167
+ MetricType metric_type)
168
+ : Index(d, metric_type), seed(seed) {
169
+ this->ntotal = ntotal;
170
+ is_trained = true;
171
+ }
172
+
173
+ void IndexRandom::add(idx_t n, const float*) {
174
+ ntotal += n;
175
+ }
176
+
177
+ void IndexRandom::search(
178
+ idx_t n,
179
+ const float* x,
180
+ idx_t k,
181
+ float* distances,
182
+ idx_t* labels,
183
+ const SearchParameters* params) const {
184
+ FAISS_THROW_IF_NOT_MSG(
185
+ !params, "search params not supported for this index");
186
+ FAISS_THROW_IF_NOT(k <= ntotal);
187
+ #pragma omp parallel for if (n > 1000)
188
+ for (idx_t i = 0; i < n; i++) {
189
+ RandomGenerator rng(
190
+ seed + ivec_checksum(d, (const int32_t*)(x + i * d)));
191
+ idx_t* I = labels + i * k;
192
+ float* D = distances + i * k;
193
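+ // rejection sampling, intended for k << ntotal. NOTE(review): `k < 100 * ntotal` holds whenever 0 < k <= ntotal, so the permutation branch below is effectively dead; `100 * k < ntotal` was probably intended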
194
+ if (k < 100 * ntotal) {
195
+ std::unordered_set<idx_t> map;
196
+ for (int j = 0; j < k; j++) {
197
+ idx_t ii;
198
+ for (;;) {
199
+ // yes I know it's not strictly uniform...
200
+ ii = rng.rand_int64() % ntotal;
201
+ if (map.count(ii) == 0) {
202
+ break;
203
+ }
204
+ }
205
+ I[j] = ii;
206
+ map.insert(ii);
207
+ }
208
+ } else {
209
+ std::vector<idx_t> perm(ntotal);
210
+ for (idx_t j = 0; j < ntotal; j++) {
211
+ perm[j] = j;
212
+ }
213
+ for (int j = 0; j < k; j++) {
214
+ std::swap(perm[j], perm[rng.rand_int(ntotal)]);
215
+ I[j] = perm[j];
216
+ }
217
+ }
218
+ float dprev = 0;
219
+ for (int j = 0; j < k; j++) {
220
+ float step = rng.rand_float();
221
+ if (is_similarity_metric(metric_type)) {
222
+ step = -step;
223
+ }
224
+ dprev += step;
225
+ D[j] = dprev;
226
+ }
227
+ }
228
+ }
229
+
230
+ void IndexRandom::reconstruct(idx_t key, float* recons) const {
231
+ RandomGenerator rng(seed + 123332 + key);
232
+ for (size_t i = 0; i < d; i++) {
233
+ recons[i] = rng.rand_float();
234
+ }
235
+ }
236
+
237
+ void IndexRandom::reset() {
238
+ ntotal = 0;
239
+ }
240
+
241
+ IndexRandom::~IndexRandom() {}
242
+
157
243
  } // namespace faiss
@@ -49,6 +49,35 @@ struct IndexSplitVectors : Index {
49
49
  ~IndexSplitVectors() override;
50
50
  };
51
51
 
52
+ /** index that returns random results.
53
+ * used mainly for time benchmarks
54
+ */
55
+ struct IndexRandom : Index {
56
+ int64_t seed;
57
+
58
+ explicit IndexRandom(
59
+ idx_t d,
60
+ idx_t ntotal = 0,
61
+ int64_t seed = 1234,
62
+ MetricType mt = METRIC_L2);
63
+
64
+ void add(idx_t n, const float* x) override;
65
+
66
+ void search(
67
+ idx_t n,
68
+ const float* x,
69
+ idx_t k,
70
+ float* distances,
71
+ idx_t* labels,
72
+ const SearchParameters* params = nullptr) const override;
73
+
74
+ void reconstruct(idx_t key, float* recons) const override;
75
+
76
+ void reset() override;
77
+
78
+ ~IndexRandom() override;
79
+ };
80
+
52
81
  } // namespace faiss
53
82
 
54
83
  #endif
@@ -10,6 +10,8 @@
10
10
  #ifndef FAISS_METRIC_TYPE_H
11
11
  #define FAISS_METRIC_TYPE_H
12
12
 
13
+ #include <faiss/impl/platform_macros.h>
14
+
13
15
  namespace faiss {
14
16
 
15
17
  /// The metric space for vector comparison for Faiss indices and algorithms.
@@ -29,8 +31,20 @@ enum MetricType {
29
31
  METRIC_Canberra = 20,
30
32
  METRIC_BrayCurtis,
31
33
  METRIC_JensenShannon,
34
+ METRIC_Jaccard, ///< defined as: sum_i(min(a_i, b_i)) / sum_i(max(a_i, b_i))
35
+ ///< where a_i, b_i > 0
32
36
  };
33
37
 
38
+ /// all vector indices are this type
39
+ using idx_t = int64_t;
40
+
41
+ /// this function is used to distinguish between min and max indexes since
42
+ /// we need to support similarity and dis-similarity metrics in a flexible way
43
+ constexpr bool is_similarity_metric(MetricType metric_type) {
44
+ return ((metric_type == METRIC_INNER_PRODUCT) ||
45
+ (metric_type == METRIC_Jaccard));
46
+ }
47
+
34
48
  } // namespace faiss
35
49
 
36
50
  #endif