faiss 0.3.1 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.h +1 -1
  5. data/vendor/faiss/faiss/Clustering.cpp +35 -4
  6. data/vendor/faiss/faiss/Clustering.h +10 -1
  7. data/vendor/faiss/faiss/IVFlib.cpp +4 -1
  8. data/vendor/faiss/faiss/Index.h +21 -6
  9. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  10. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -1
  11. data/vendor/faiss/faiss/IndexFastScan.cpp +22 -4
  12. data/vendor/faiss/faiss/IndexFlat.cpp +11 -7
  13. data/vendor/faiss/faiss/IndexFlatCodes.cpp +159 -5
  14. data/vendor/faiss/faiss/IndexFlatCodes.h +20 -3
  15. data/vendor/faiss/faiss/IndexHNSW.cpp +143 -90
  16. data/vendor/faiss/faiss/IndexHNSW.h +52 -3
  17. data/vendor/faiss/faiss/IndexIVF.cpp +3 -3
  18. data/vendor/faiss/faiss/IndexIVF.h +9 -1
  19. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +15 -0
  20. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -0
  21. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +130 -57
  22. data/vendor/faiss/faiss/IndexIVFFastScan.h +14 -7
  23. data/vendor/faiss/faiss/IndexIVFPQ.cpp +1 -3
  24. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +21 -2
  25. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  26. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  27. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -29
  28. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  29. data/vendor/faiss/faiss/IndexNSG.h +1 -1
  30. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  31. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  32. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  33. data/vendor/faiss/faiss/IndexRefine.cpp +5 -5
  34. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +3 -1
  35. data/vendor/faiss/faiss/MetricType.h +7 -2
  36. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  37. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  38. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  39. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  40. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +36 -4
  41. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +6 -0
  42. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  43. data/vendor/faiss/faiss/gpu/GpuIndex.h +2 -8
  44. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  45. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +6 -0
  46. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +2 -0
  47. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +25 -0
  48. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  49. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +6 -0
  50. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  51. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +65 -0
  52. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  53. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  54. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  55. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  56. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +25 -0
  57. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +9 -1
  58. data/vendor/faiss/faiss/impl/DistanceComputer.h +46 -0
  59. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  60. data/vendor/faiss/faiss/impl/HNSW.cpp +358 -190
  61. data/vendor/faiss/faiss/impl/HNSW.h +43 -22
  62. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +8 -8
  63. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  64. data/vendor/faiss/faiss/impl/NNDescent.cpp +13 -8
  65. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  66. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +1 -0
  67. data/vendor/faiss/faiss/impl/ProductQuantizer.h +5 -1
  68. data/vendor/faiss/faiss/impl/ResultHandler.h +151 -32
  69. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +719 -102
  70. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -0
  71. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +5 -0
  72. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  73. data/vendor/faiss/faiss/impl/index_read.cpp +29 -15
  74. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  75. data/vendor/faiss/faiss/impl/index_write.cpp +28 -10
  76. data/vendor/faiss/faiss/impl/io.cpp +13 -5
  77. data/vendor/faiss/faiss/impl/io.h +4 -4
  78. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  79. data/vendor/faiss/faiss/impl/platform_macros.h +22 -0
  80. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +11 -0
  81. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +1 -1
  82. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +448 -1
  83. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +5 -5
  84. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  85. data/vendor/faiss/faiss/impl/simd_result_handlers.h +143 -59
  86. data/vendor/faiss/faiss/index_factory.cpp +31 -13
  87. data/vendor/faiss/faiss/index_io.h +12 -5
  88. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  89. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  90. data/vendor/faiss/faiss/invlists/DirectMap.cpp +9 -1
  91. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +55 -17
  92. data/vendor/faiss/faiss/invlists/InvertedLists.h +18 -9
  93. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +21 -6
  94. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  95. data/vendor/faiss/faiss/python/python_callbacks.cpp +3 -3
  96. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  97. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  98. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  99. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  100. data/vendor/faiss/faiss/utils/distances.cpp +58 -88
  101. data/vendor/faiss/faiss/utils/distances.h +5 -5
  102. data/vendor/faiss/faiss/utils/distances_simd.cpp +997 -9
  103. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  104. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  105. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  106. data/vendor/faiss/faiss/utils/hamming.cpp +1 -1
  107. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +4 -1
  108. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +2 -1
  109. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  110. data/vendor/faiss/faiss/utils/random.h +25 -0
  111. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  112. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  113. data/vendor/faiss/faiss/utils/simdlib_neon.h +5 -2
  114. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  115. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  116. data/vendor/faiss/faiss/utils/utils.cpp +10 -3
  117. data/vendor/faiss/faiss/utils/utils.h +3 -0
  118. metadata +16 -4
  119. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: e41b15bbcda6c4d2a250df5b98d86e9baf51b34b90fc2fccb6f0a37f486ef417
4
- data.tar.gz: 768074275062ed45f1752e3a5c9d55a9695a6aa453e925aa0a6e607ce3215bab
3
+ metadata.gz: bdce4ec4f4169dff5f08ccbed2de2750dfd33738fe60d747645f7aaa43187505
4
+ data.tar.gz: a8ab702eead45525bb4aae8b28b9c20bc0d0d8c774a79ef942a9c8d7a9cabc2f
5
5
  SHA512:
6
- metadata.gz: cecc466dd24e03206219b63e750e48b554355c1c5dfc8e911879988a6f31eb628617133f5b584b3de29efcbe65d087cf5b4e219371cee959e8248c989a4dbffc
7
- data.tar.gz: 3e0c6be53825949f9c51a0195d85cbed87bc198dd06852c88c537b13e6bcc8e7fa65a3e3c88667eefef44e95278fe2c73ece89d5f92bd24f8c0d27b543488b56
6
+ metadata.gz: 7e8291961c8a8550e745c55eef5011ca23fc6f5ce7452eeb6da45ebfd020f7c07df70a0a5d7c281e2449214d5ec26102f9194f1aa49d0b9be21304dad3a98368
7
+ data.tar.gz: 80b475d06b237902b88025dc2602a7e7c8ad15ec757cd43d63d143423eb7a1bd759b8c30715b9ec30c2ae3cfecd2eea502e9814524219d396c71067f0959b62e
data/CHANGELOG.md CHANGED
@@ -1,3 +1,7 @@
1
+ ## 0.3.2 (2024-10-05)
2
+
3
+ - Updated Faiss to 1.9.0
4
+
1
5
  ## 0.3.1 (2024-03-13)
2
6
 
3
7
  - Updated Faiss to 1.8.0
data/lib/faiss/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Faiss
2
- VERSION = "0.3.1"
2
+ VERSION = "0.3.2"
3
3
  end
@@ -86,7 +86,7 @@ struct OperatingPoint {
86
86
  double perf; ///< performance measure (output of a Criterion)
87
87
  double t; ///< corresponding execution time (ms)
88
88
  std::string key; ///< key that identifies this op pt
89
- int64_t cno; ///< integer identifer
89
+ int64_t cno; ///< integer identifier
90
90
  };
91
91
 
92
92
  struct OperatingPoints {
@@ -11,6 +11,7 @@
11
11
  #include <faiss/VectorTransform.h>
12
12
  #include <faiss/impl/AuxIndexStructures.h>
13
13
 
14
+ #include <chrono>
14
15
  #include <cinttypes>
15
16
  #include <cmath>
16
17
  #include <cstdio>
@@ -74,6 +75,14 @@ void Clustering::train(
74
75
 
75
76
  namespace {
76
77
 
78
+ uint64_t get_actual_rng_seed(const int seed) {
79
+ return (seed >= 0)
80
+ ? seed
81
+ : static_cast<uint64_t>(std::chrono::high_resolution_clock::now()
82
+ .time_since_epoch()
83
+ .count());
84
+ }
85
+
77
86
  idx_t subsample_training_set(
78
87
  const Clustering& clus,
79
88
  idx_t nx,
@@ -87,11 +96,30 @@ idx_t subsample_training_set(
87
96
  clus.k * clus.max_points_per_centroid,
88
97
  nx);
89
98
  }
90
- std::vector<int> perm(nx);
91
- rand_perm(perm.data(), nx, clus.seed);
99
+
100
+ const uint64_t actual_seed = get_actual_rng_seed(clus.seed);
101
+
102
+ std::vector<int> perm;
103
+ if (clus.use_faster_subsampling) {
104
+ // use subsampling with splitmix64 rng
105
+ SplitMix64RandomGenerator rng(actual_seed);
106
+
107
+ const idx_t new_nx = clus.k * clus.max_points_per_centroid;
108
+ perm.resize(new_nx);
109
+ for (idx_t i = 0; i < new_nx; i++) {
110
+ perm[i] = rng.rand_int(nx);
111
+ }
112
+ } else {
113
+ // use subsampling with a default std rng
114
+ perm.resize(nx);
115
+ rand_perm(perm.data(), nx, actual_seed);
116
+ }
117
+
92
118
  nx = clus.k * clus.max_points_per_centroid;
93
119
  uint8_t* x_new = new uint8_t[nx * line_size];
94
120
  *x_out = x_new;
121
+
122
+ // might be worth omp-ing as well
95
123
  for (idx_t i = 0; i < nx; i++) {
96
124
  memcpy(x_new + i * line_size, x + perm[i] * line_size, line_size);
97
125
  }
@@ -280,7 +308,7 @@ void Clustering::train_encoded(
280
308
 
281
309
  double t0 = getmillisecs();
282
310
 
283
- if (!codec) {
311
+ if (!codec && check_input_data_for_NaNs) {
284
312
  // Check for NaNs in input data. Normally it is the user's
285
313
  // responsibility, but it may spare us some hard-to-debug
286
314
  // reports.
@@ -383,6 +411,9 @@ void Clustering::train_encoded(
383
411
  }
384
412
  t0 = getmillisecs();
385
413
 
414
+ // initialize seed
415
+ const uint64_t actual_seed = get_actual_rng_seed(seed);
416
+
386
417
  // temporary buffer to decode vectors during the optimization
387
418
  std::vector<float> decode_buffer(codec ? d * decode_block_size : 0);
388
419
 
@@ -395,7 +426,7 @@ void Clustering::train_encoded(
395
426
  centroids.resize(d * k);
396
427
  std::vector<int> perm(nx);
397
428
 
398
- rand_perm(perm.data(), nx, seed + 1 + redo * 15486557L);
429
+ rand_perm(perm.data(), nx, actual_seed + 1 + redo * 15486557L);
399
430
 
400
431
  if (!codec) {
401
432
  for (int i = n_input_centroids; i < k; i++) {
@@ -43,11 +43,20 @@ struct ClusteringParameters {
43
43
  int min_points_per_centroid = 39;
44
44
  /// to limit size of dataset, otherwise the training set is subsampled
45
45
  int max_points_per_centroid = 256;
46
- /// seed for the random number generator
46
+ /// seed for the random number generator.
47
+ /// negative values lead to seeding an internal rng with
48
+ /// std::high_resolution_clock.
47
49
  int seed = 1234;
48
50
 
49
51
  /// when the training set is encoded, batch size of the codec decoder
50
52
  size_t decode_block_size = 32768;
53
+
54
+ /// whether to check for NaNs in an input data
55
+ bool check_input_data_for_NaNs = true;
56
+
57
+ /// Whether to use splitmix64-based random number generator for subsampling,
58
+ /// which is faster, but may pick duplicate points.
59
+ bool use_faster_subsampling = false;
51
60
  };
52
61
 
53
62
  struct ClusteringIterationStats {
@@ -352,7 +352,10 @@ void search_with_parameters(
352
352
  const IndexIVF* index_ivf = dynamic_cast<const IndexIVF*>(index);
353
353
  FAISS_THROW_IF_NOT(index_ivf);
354
354
 
355
- index_ivf->quantizer->search(n, x, params->nprobe, Dq.data(), Iq.data());
355
+ SearchParameters* quantizer_params =
356
+ (params) ? params->quantizer_params : nullptr;
357
+ index_ivf->quantizer->search(
358
+ n, x, params->nprobe, Dq.data(), Iq.data(), quantizer_params);
356
359
 
357
360
  if (nb_dis_ptr) {
358
361
  *nb_dis_ptr = count_ndis(index_ivf, n * params->nprobe, Iq.data());
@@ -17,9 +17,21 @@
17
17
  #include <typeinfo>
18
18
 
19
19
  #define FAISS_VERSION_MAJOR 1
20
- #define FAISS_VERSION_MINOR 8
20
+ #define FAISS_VERSION_MINOR 9
21
21
  #define FAISS_VERSION_PATCH 0
22
22
 
23
+ // Macro to combine the version components into a single string
24
+ #ifndef FAISS_STRINGIFY
25
+ #define FAISS_STRINGIFY(ARG) #ARG
26
+ #endif
27
+ #ifndef FAISS_TOSTRING
28
+ #define FAISS_TOSTRING(ARG) FAISS_STRINGIFY(ARG)
29
+ #endif
30
+ #define VERSION_STRING \
31
+ FAISS_TOSTRING(FAISS_VERSION_MAJOR) \
32
+ "." FAISS_TOSTRING(FAISS_VERSION_MINOR) "." FAISS_TOSTRING( \
33
+ FAISS_VERSION_PATCH)
34
+
23
35
  /**
24
36
  * @namespace faiss
25
37
  *
@@ -38,8 +50,8 @@
38
50
 
39
51
  namespace faiss {
40
52
 
41
- /// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h and
42
- /// impl/DistanceComputer.h
53
+ /// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h
54
+ /// and impl/DistanceComputer.h
43
55
  struct IDSelector;
44
56
  struct RangeSearchResult;
45
57
  struct DistanceComputer;
@@ -56,7 +68,8 @@ struct SearchParameters {
56
68
  virtual ~SearchParameters() {}
57
69
  };
58
70
 
59
- /** Abstract structure for an index, supports adding vectors and searching them.
71
+ /** Abstract structure for an index, supports adding vectors and searching
72
+ * them.
60
73
  *
61
74
  * All vectors provided at add or search time are 32-bit float arrays,
62
75
  * although the internal representation may vary.
@@ -154,7 +167,8 @@ struct Index {
154
167
 
155
168
  /** return the indexes of the k vectors closest to the query x.
156
169
  *
157
- * This function is identical as search but only return labels of neighbors.
170
+ * This function is identical as search but only return labels of
171
+ * neighbors.
158
172
  * @param n number of vectors
159
173
  * @param x input vectors to search, size n * d
160
174
  * @param labels output labels of the NNs, size n*k
@@ -179,7 +193,8 @@ struct Index {
179
193
  */
180
194
  virtual void reconstruct(idx_t key, float* recons) const;
181
195
 
182
- /** Reconstruct several stored vectors (or an approximation if lossy coding)
196
+ /** Reconstruct several stored vectors (or an approximation if lossy
197
+ * coding)
183
198
  *
184
199
  * this function may not be defined for some indexes
185
200
  * @param n number of vectors to reconstruct
@@ -21,7 +21,7 @@ namespace faiss {
21
21
  struct IndexBinaryHNSW : IndexBinary {
22
22
  typedef HNSW::storage_idx_t storage_idx_t;
23
23
 
24
- // the link strcuture
24
+ // the link structure
25
25
  HNSW hnsw;
26
26
 
27
27
  // the sequential storage
@@ -456,7 +456,7 @@ void search_knn_hamming_heap(
456
456
  }
457
457
 
458
458
  } // parallel for
459
- } // parallel
459
+ } // parallel
460
460
 
461
461
  indexIVF_stats.nq += n;
462
462
  indexIVF_stats.nlist += nlistv;
@@ -189,6 +189,7 @@ void estimators_from_tables_generic(
189
189
  dt += index.ksub;
190
190
  }
191
191
  }
192
+
192
193
  if (C::cmp(heap_dis[0], dis)) {
193
194
  heap_pop<C>(k, heap_dis, heap_ids);
194
195
  heap_push<C>(k, heap_dis, heap_ids, dis, j);
@@ -203,17 +204,18 @@ ResultHandlerCompare<C, false>* make_knn_handler(
203
204
  idx_t k,
204
205
  size_t ntotal,
205
206
  float* distances,
206
- idx_t* labels) {
207
+ idx_t* labels,
208
+ const IDSelector* sel = nullptr) {
207
209
  using HeapHC = HeapHandler<C, false>;
208
210
  using ReservoirHC = ReservoirHandler<C, false>;
209
211
  using SingleResultHC = SingleResultHandler<C, false>;
210
212
 
211
213
  if (k == 1) {
212
- return new SingleResultHC(n, ntotal, distances, labels);
214
+ return new SingleResultHC(n, ntotal, distances, labels, sel);
213
215
  } else if (impl % 2 == 0) {
214
- return new HeapHC(n, ntotal, k, distances, labels);
216
+ return new HeapHC(n, ntotal, k, distances, labels, sel);
215
217
  } else /* if (impl % 2 == 1) */ {
216
- return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels);
218
+ return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel);
217
219
  }
218
220
  }
219
221
 
@@ -547,6 +549,22 @@ void IndexFastScan::search_implem_14(
547
549
  }
548
550
  }
549
551
 
552
+ template void IndexFastScan::search_dispatch_implem<true>(
553
+ idx_t n,
554
+ const float* x,
555
+ idx_t k,
556
+ float* distances,
557
+ idx_t* labels,
558
+ const NormTableScaler* scaler) const;
559
+
560
+ template void IndexFastScan::search_dispatch_implem<false>(
561
+ idx_t n,
562
+ const float* x,
563
+ idx_t k,
564
+ float* distances,
565
+ idx_t* labels,
566
+ const NormTableScaler* scaler) const;
567
+
550
568
  void IndexFastScan::reconstruct(idx_t key, float* recons) const {
551
569
  std::vector<uint8_t> code(code_size, 0);
552
570
  BitstringWriter bsw(code.data(), code_size);
@@ -41,15 +41,19 @@ void IndexFlat::search(
41
41
  } else if (metric_type == METRIC_L2) {
42
42
  float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
43
43
  knn_L2sqr(x, get_xb(), d, n, ntotal, &res, nullptr, sel);
44
- } else if (is_similarity_metric(metric_type)) {
45
- float_minheap_array_t res = {size_t(n), size_t(k), labels, distances};
46
- knn_extra_metrics(
47
- x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
48
44
  } else {
49
- FAISS_THROW_IF_NOT(!sel);
50
- float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
45
+ FAISS_THROW_IF_NOT(!sel); // TODO implement with selector
51
46
  knn_extra_metrics(
52
- x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
47
+ x,
48
+ get_xb(),
49
+ d,
50
+ n,
51
+ ntotal,
52
+ metric_type,
53
+ metric_arg,
54
+ k,
55
+ distances,
56
+ labels);
53
57
  }
54
58
  }
55
59
 
@@ -12,6 +12,8 @@
12
12
  #include <faiss/impl/DistanceComputer.h>
13
13
  #include <faiss/impl/FaissAssert.h>
14
14
  #include <faiss/impl/IDSelector.h>
15
+ #include <faiss/impl/ResultHandler.h>
16
+ #include <faiss/utils/extra_distances.h>
15
17
 
16
18
  namespace faiss {
17
19
 
@@ -70,11 +72,6 @@ void IndexFlatCodes::reconstruct(idx_t key, float* recons) const {
70
72
  reconstruct_n(key, 1, recons);
71
73
  }
72
74
 
73
- FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
74
- const {
75
- FAISS_THROW_MSG("not implemented");
76
- }
77
-
78
75
  void IndexFlatCodes::check_compatible_for_merge(const Index& otherIndex) const {
79
76
  // minimal sanity checks
80
77
  const IndexFlatCodes* other =
@@ -114,4 +111,161 @@ void IndexFlatCodes::permute_entries(const idx_t* perm) {
114
111
  std::swap(codes, new_codes);
115
112
  }
116
113
 
114
+ namespace {
115
+
116
+ template <class VD>
117
+ struct GenericFlatCodesDistanceComputer : FlatCodesDistanceComputer {
118
+ const IndexFlatCodes& codec;
119
+ const VD vd;
120
+ // temp buffers
121
+ std::vector<uint8_t> code_buffer;
122
+ std::vector<float> vec_buffer;
123
+ const float* query = nullptr;
124
+
125
+ GenericFlatCodesDistanceComputer(const IndexFlatCodes* codec, const VD& vd)
126
+ : FlatCodesDistanceComputer(codec->codes.data(), codec->code_size),
127
+ codec(*codec),
128
+ vd(vd),
129
+ code_buffer(codec->code_size * 4),
130
+ vec_buffer(codec->d * 4) {}
131
+
132
+ void set_query(const float* x) override {
133
+ query = x;
134
+ }
135
+
136
+ float operator()(idx_t i) override {
137
+ codec.sa_decode(1, codes + i * code_size, vec_buffer.data());
138
+ return vd(query, vec_buffer.data());
139
+ }
140
+
141
+ float distance_to_code(const uint8_t* code) override {
142
+ codec.sa_decode(1, code, vec_buffer.data());
143
+ return vd(query, vec_buffer.data());
144
+ }
145
+
146
+ float symmetric_dis(idx_t i, idx_t j) override {
147
+ codec.sa_decode(1, codes + i * code_size, vec_buffer.data());
148
+ codec.sa_decode(1, codes + j * code_size, vec_buffer.data() + vd.d);
149
+ return vd(vec_buffer.data(), vec_buffer.data() + vd.d);
150
+ }
151
+
152
+ void distances_batch_4(
153
+ const idx_t idx0,
154
+ const idx_t idx1,
155
+ const idx_t idx2,
156
+ const idx_t idx3,
157
+ float& dis0,
158
+ float& dis1,
159
+ float& dis2,
160
+ float& dis3) override {
161
+ uint8_t* cp = code_buffer.data();
162
+ for (idx_t i : {idx0, idx1, idx2, idx3}) {
163
+ memcpy(cp, codes + i * code_size, code_size);
164
+ cp += code_size;
165
+ }
166
+ // potential benefit is if batch decoding is more efficient than 1 by 1
167
+ // decoding
168
+ codec.sa_decode(4, code_buffer.data(), vec_buffer.data());
169
+ dis0 = vd(query, vec_buffer.data());
170
+ dis1 = vd(query, vec_buffer.data() + vd.d);
171
+ dis2 = vd(query, vec_buffer.data() + 2 * vd.d);
172
+ dis3 = vd(query, vec_buffer.data() + 3 * vd.d);
173
+ }
174
+ };
175
+
176
+ struct Run_get_distance_computer {
177
+ using T = FlatCodesDistanceComputer*;
178
+
179
+ template <class VD>
180
+ FlatCodesDistanceComputer* f(const VD& vd, const IndexFlatCodes* codec) {
181
+ return new GenericFlatCodesDistanceComputer<VD>(codec, vd);
182
+ }
183
+ };
184
+
185
+ template <class BlockResultHandler>
186
+ struct Run_search_with_decompress {
187
+ using T = void;
188
+
189
+ template <class VectorDistance>
190
+ void f(VectorDistance& vd,
191
+ const IndexFlatCodes* index_ptr,
192
+ const float* xq,
193
+ BlockResultHandler& res) {
194
+ // Note that there seems to be a clang (?) bug that "sometimes" passes
195
+ // the const Index & parameters by value, so to be on the safe side,
196
+ // it's better to use pointers.
197
+ const IndexFlatCodes& index = *index_ptr;
198
+ size_t ntotal = index.ntotal;
199
+ using SingleResultHandler =
200
+ typename BlockResultHandler::SingleResultHandler;
201
+ using DC = GenericFlatCodesDistanceComputer<VectorDistance>;
202
+ #pragma omp parallel // if (res.nq > 100)
203
+ {
204
+ std::unique_ptr<DC> dc(new DC(&index, vd));
205
+ SingleResultHandler resi(res);
206
+ #pragma omp for
207
+ for (int64_t q = 0; q < res.nq; q++) {
208
+ resi.begin(q);
209
+ dc->set_query(xq + vd.d * q);
210
+ for (size_t i = 0; i < ntotal; i++) {
211
+ if (res.is_in_selection(i)) {
212
+ float dis = (*dc)(i);
213
+ resi.add_result(dis, i);
214
+ }
215
+ }
216
+ resi.end();
217
+ }
218
+ }
219
+ }
220
+ };
221
+
222
+ struct Run_search_with_decompress_res {
223
+ using T = void;
224
+
225
+ template <class ResultHandler>
226
+ void f(ResultHandler& res, const IndexFlatCodes* index, const float* xq) {
227
+ Run_search_with_decompress<ResultHandler> r;
228
+ dispatch_VectorDistance(
229
+ index->d,
230
+ index->metric_type,
231
+ index->metric_arg,
232
+ r,
233
+ index,
234
+ xq,
235
+ res);
236
+ }
237
+ };
238
+
239
+ } // anonymous namespace
240
+
241
+ FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
242
+ const {
243
+ Run_get_distance_computer r;
244
+ return dispatch_VectorDistance(d, metric_type, metric_arg, r, this);
245
+ }
246
+
247
+ void IndexFlatCodes::search(
248
+ idx_t n,
249
+ const float* x,
250
+ idx_t k,
251
+ float* distances,
252
+ idx_t* labels,
253
+ const SearchParameters* params) const {
254
+ Run_search_with_decompress_res r;
255
+ const IDSelector* sel = params ? params->sel : nullptr;
256
+ dispatch_knn_ResultHandler(
257
+ n, distances, labels, k, metric_type, sel, r, this, x);
258
+ }
259
+
260
+ void IndexFlatCodes::range_search(
261
+ idx_t n,
262
+ const float* x,
263
+ float radius,
264
+ RangeSearchResult* result,
265
+ const SearchParameters* params) const {
266
+ const IDSelector* sel = params ? params->sel : nullptr;
267
+ Run_search_with_decompress_res r;
268
+ dispatch_range_ResultHandler(result, radius, metric_type, sel, r, this, x);
269
+ }
270
+
117
271
  } // namespace faiss
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #pragma once
11
9
 
12
10
  #include <faiss/Index.h>
@@ -45,13 +43,32 @@ struct IndexFlatCodes : Index {
45
43
  * different from the usual ones: the new ids are shifted */
46
44
  size_t remove_ids(const IDSelector& sel) override;
47
45
 
48
- /** a FlatCodesDistanceComputer offers a distance_to_code method */
46
+ /** a FlatCodesDistanceComputer offers a distance_to_code method
47
+ *
48
+ * The default implementation explicitly decodes the vector with sa_decode.
49
+ */
49
50
  virtual FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const;
50
51
 
51
52
  DistanceComputer* get_distance_computer() const override {
52
53
  return get_FlatCodesDistanceComputer();
53
54
  }
54
55
 
56
+ /** Search implemented by decoding */
57
+ void search(
58
+ idx_t n,
59
+ const float* x,
60
+ idx_t k,
61
+ float* distances,
62
+ idx_t* labels,
63
+ const SearchParameters* params = nullptr) const override;
64
+
65
+ void range_search(
66
+ idx_t n,
67
+ const float* x,
68
+ float radius,
69
+ RangeSearchResult* result,
70
+ const SearchParameters* params = nullptr) const override;
71
+
55
72
  // returns a new instance of a CodePacker
56
73
  CodePacker* get_CodePacker() const;
57
74