faiss 0.3.1 → 0.3.2

Sign up to get free protection for your applications and to get access to all the features.
Files changed (119)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.h +1 -1
  5. data/vendor/faiss/faiss/Clustering.cpp +35 -4
  6. data/vendor/faiss/faiss/Clustering.h +10 -1
  7. data/vendor/faiss/faiss/IVFlib.cpp +4 -1
  8. data/vendor/faiss/faiss/Index.h +21 -6
  9. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  10. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -1
  11. data/vendor/faiss/faiss/IndexFastScan.cpp +22 -4
  12. data/vendor/faiss/faiss/IndexFlat.cpp +11 -7
  13. data/vendor/faiss/faiss/IndexFlatCodes.cpp +159 -5
  14. data/vendor/faiss/faiss/IndexFlatCodes.h +20 -3
  15. data/vendor/faiss/faiss/IndexHNSW.cpp +143 -90
  16. data/vendor/faiss/faiss/IndexHNSW.h +52 -3
  17. data/vendor/faiss/faiss/IndexIVF.cpp +3 -3
  18. data/vendor/faiss/faiss/IndexIVF.h +9 -1
  19. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +15 -0
  20. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -0
  21. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +130 -57
  22. data/vendor/faiss/faiss/IndexIVFFastScan.h +14 -7
  23. data/vendor/faiss/faiss/IndexIVFPQ.cpp +1 -3
  24. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +21 -2
  25. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  26. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  27. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -29
  28. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  29. data/vendor/faiss/faiss/IndexNSG.h +1 -1
  30. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  31. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  32. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  33. data/vendor/faiss/faiss/IndexRefine.cpp +5 -5
  34. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +3 -1
  35. data/vendor/faiss/faiss/MetricType.h +7 -2
  36. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  37. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  38. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  39. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  40. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +36 -4
  41. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +6 -0
  42. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  43. data/vendor/faiss/faiss/gpu/GpuIndex.h +2 -8
  44. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  45. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +6 -0
  46. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +2 -0
  47. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +25 -0
  48. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  49. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +6 -0
  50. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  51. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +65 -0
  52. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  53. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  54. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  55. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  56. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +25 -0
  57. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +9 -1
  58. data/vendor/faiss/faiss/impl/DistanceComputer.h +46 -0
  59. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  60. data/vendor/faiss/faiss/impl/HNSW.cpp +358 -190
  61. data/vendor/faiss/faiss/impl/HNSW.h +43 -22
  62. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +8 -8
  63. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  64. data/vendor/faiss/faiss/impl/NNDescent.cpp +13 -8
  65. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  66. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +1 -0
  67. data/vendor/faiss/faiss/impl/ProductQuantizer.h +5 -1
  68. data/vendor/faiss/faiss/impl/ResultHandler.h +151 -32
  69. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +719 -102
  70. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -0
  71. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +5 -0
  72. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  73. data/vendor/faiss/faiss/impl/index_read.cpp +29 -15
  74. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  75. data/vendor/faiss/faiss/impl/index_write.cpp +28 -10
  76. data/vendor/faiss/faiss/impl/io.cpp +13 -5
  77. data/vendor/faiss/faiss/impl/io.h +4 -4
  78. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  79. data/vendor/faiss/faiss/impl/platform_macros.h +22 -0
  80. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +11 -0
  81. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +1 -1
  82. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +448 -1
  83. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +5 -5
  84. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  85. data/vendor/faiss/faiss/impl/simd_result_handlers.h +143 -59
  86. data/vendor/faiss/faiss/index_factory.cpp +31 -13
  87. data/vendor/faiss/faiss/index_io.h +12 -5
  88. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  89. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  90. data/vendor/faiss/faiss/invlists/DirectMap.cpp +9 -1
  91. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +55 -17
  92. data/vendor/faiss/faiss/invlists/InvertedLists.h +18 -9
  93. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +21 -6
  94. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  95. data/vendor/faiss/faiss/python/python_callbacks.cpp +3 -3
  96. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  97. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  98. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  99. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  100. data/vendor/faiss/faiss/utils/distances.cpp +58 -88
  101. data/vendor/faiss/faiss/utils/distances.h +5 -5
  102. data/vendor/faiss/faiss/utils/distances_simd.cpp +997 -9
  103. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  104. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  105. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  106. data/vendor/faiss/faiss/utils/hamming.cpp +1 -1
  107. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +4 -1
  108. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +2 -1
  109. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  110. data/vendor/faiss/faiss/utils/random.h +25 -0
  111. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  112. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  113. data/vendor/faiss/faiss/utils/simdlib_neon.h +5 -2
  114. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  115. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  116. data/vendor/faiss/faiss/utils/utils.cpp +10 -3
  117. data/vendor/faiss/faiss/utils/utils.h +3 -0
  118. metadata +16 -4
  119. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: e41b15bbcda6c4d2a250df5b98d86e9baf51b34b90fc2fccb6f0a37f486ef417
4
- data.tar.gz: 768074275062ed45f1752e3a5c9d55a9695a6aa453e925aa0a6e607ce3215bab
3
+ metadata.gz: bdce4ec4f4169dff5f08ccbed2de2750dfd33738fe60d747645f7aaa43187505
4
+ data.tar.gz: a8ab702eead45525bb4aae8b28b9c20bc0d0d8c774a79ef942a9c8d7a9cabc2f
5
5
  SHA512:
6
- metadata.gz: cecc466dd24e03206219b63e750e48b554355c1c5dfc8e911879988a6f31eb628617133f5b584b3de29efcbe65d087cf5b4e219371cee959e8248c989a4dbffc
7
- data.tar.gz: 3e0c6be53825949f9c51a0195d85cbed87bc198dd06852c88c537b13e6bcc8e7fa65a3e3c88667eefef44e95278fe2c73ece89d5f92bd24f8c0d27b543488b56
6
+ metadata.gz: 7e8291961c8a8550e745c55eef5011ca23fc6f5ce7452eeb6da45ebfd020f7c07df70a0a5d7c281e2449214d5ec26102f9194f1aa49d0b9be21304dad3a98368
7
+ data.tar.gz: 80b475d06b237902b88025dc2602a7e7c8ad15ec757cd43d63d143423eb7a1bd759b8c30715b9ec30c2ae3cfecd2eea502e9814524219d396c71067f0959b62e
data/CHANGELOG.md CHANGED
@@ -1,3 +1,7 @@
1
+ ## 0.3.2 (2024-10-05)
2
+
3
+ - Updated Faiss to 1.9.0
4
+
1
5
  ## 0.3.1 (2024-03-13)
2
6
 
3
7
  - Updated Faiss to 1.8.0
data/lib/faiss/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Faiss
2
- VERSION = "0.3.1"
2
+ VERSION = "0.3.2"
3
3
  end
data/vendor/faiss/faiss/AutoTune.h CHANGED
@@ -86,7 +86,7 @@ struct OperatingPoint {
86
86
  double perf; ///< performance measure (output of a Criterion)
87
87
  double t; ///< corresponding execution time (ms)
88
88
  std::string key; ///< key that identifies this op pt
89
- int64_t cno; ///< integer identifer
89
+ int64_t cno; ///< integer identifier
90
90
  };
91
91
 
92
92
  struct OperatingPoints {
data/vendor/faiss/faiss/Clustering.cpp CHANGED
@@ -11,6 +11,7 @@
11
11
  #include <faiss/VectorTransform.h>
12
12
  #include <faiss/impl/AuxIndexStructures.h>
13
13
 
14
+ #include <chrono>
14
15
  #include <cinttypes>
15
16
  #include <cmath>
16
17
  #include <cstdio>
@@ -74,6 +75,14 @@ void Clustering::train(
74
75
 
75
76
  namespace {
76
77
 
78
+ uint64_t get_actual_rng_seed(const int seed) {
79
+ return (seed >= 0)
80
+ ? seed
81
+ : static_cast<uint64_t>(std::chrono::high_resolution_clock::now()
82
+ .time_since_epoch()
83
+ .count());
84
+ }
85
+
77
86
  idx_t subsample_training_set(
78
87
  const Clustering& clus,
79
88
  idx_t nx,
@@ -87,11 +96,30 @@ idx_t subsample_training_set(
87
96
  clus.k * clus.max_points_per_centroid,
88
97
  nx);
89
98
  }
90
- std::vector<int> perm(nx);
91
- rand_perm(perm.data(), nx, clus.seed);
99
+
100
+ const uint64_t actual_seed = get_actual_rng_seed(clus.seed);
101
+
102
+ std::vector<int> perm;
103
+ if (clus.use_faster_subsampling) {
104
+ // use subsampling with splitmix64 rng
105
+ SplitMix64RandomGenerator rng(actual_seed);
106
+
107
+ const idx_t new_nx = clus.k * clus.max_points_per_centroid;
108
+ perm.resize(new_nx);
109
+ for (idx_t i = 0; i < new_nx; i++) {
110
+ perm[i] = rng.rand_int(nx);
111
+ }
112
+ } else {
113
+ // use subsampling with a default std rng
114
+ perm.resize(nx);
115
+ rand_perm(perm.data(), nx, actual_seed);
116
+ }
117
+
92
118
  nx = clus.k * clus.max_points_per_centroid;
93
119
  uint8_t* x_new = new uint8_t[nx * line_size];
94
120
  *x_out = x_new;
121
+
122
+ // might be worth omp-ing as well
95
123
  for (idx_t i = 0; i < nx; i++) {
96
124
  memcpy(x_new + i * line_size, x + perm[i] * line_size, line_size);
97
125
  }
@@ -280,7 +308,7 @@ void Clustering::train_encoded(
280
308
 
281
309
  double t0 = getmillisecs();
282
310
 
283
- if (!codec) {
311
+ if (!codec && check_input_data_for_NaNs) {
284
312
  // Check for NaNs in input data. Normally it is the user's
285
313
  // responsibility, but it may spare us some hard-to-debug
286
314
  // reports.
@@ -383,6 +411,9 @@ void Clustering::train_encoded(
383
411
  }
384
412
  t0 = getmillisecs();
385
413
 
414
+ // initialize seed
415
+ const uint64_t actual_seed = get_actual_rng_seed(seed);
416
+
386
417
  // temporary buffer to decode vectors during the optimization
387
418
  std::vector<float> decode_buffer(codec ? d * decode_block_size : 0);
388
419
 
@@ -395,7 +426,7 @@ void Clustering::train_encoded(
395
426
  centroids.resize(d * k);
396
427
  std::vector<int> perm(nx);
397
428
 
398
- rand_perm(perm.data(), nx, seed + 1 + redo * 15486557L);
429
+ rand_perm(perm.data(), nx, actual_seed + 1 + redo * 15486557L);
399
430
 
400
431
  if (!codec) {
401
432
  for (int i = n_input_centroids; i < k; i++) {
data/vendor/faiss/faiss/Clustering.h CHANGED
@@ -43,11 +43,20 @@ struct ClusteringParameters {
43
43
  int min_points_per_centroid = 39;
44
44
  /// to limit size of dataset, otherwise the training set is subsampled
45
45
  int max_points_per_centroid = 256;
46
- /// seed for the random number generator
46
+ /// seed for the random number generator.
47
+ /// negative values lead to seeding an internal rng with
48
+ /// std::high_resolution_clock.
47
49
  int seed = 1234;
48
50
 
49
51
  /// when the training set is encoded, batch size of the codec decoder
50
52
  size_t decode_block_size = 32768;
53
+
54
+ /// whether to check for NaNs in an input data
55
+ bool check_input_data_for_NaNs = true;
56
+
57
+ /// Whether to use splitmix64-based random number generator for subsampling,
58
+ /// which is faster, but may pick duplicate points.
59
+ bool use_faster_subsampling = false;
51
60
  };
52
61
 
53
62
  struct ClusteringIterationStats {
data/vendor/faiss/faiss/IVFlib.cpp CHANGED
@@ -352,7 +352,10 @@ void search_with_parameters(
352
352
  const IndexIVF* index_ivf = dynamic_cast<const IndexIVF*>(index);
353
353
  FAISS_THROW_IF_NOT(index_ivf);
354
354
 
355
- index_ivf->quantizer->search(n, x, params->nprobe, Dq.data(), Iq.data());
355
+ SearchParameters* quantizer_params =
356
+ (params) ? params->quantizer_params : nullptr;
357
+ index_ivf->quantizer->search(
358
+ n, x, params->nprobe, Dq.data(), Iq.data(), quantizer_params);
356
359
 
357
360
  if (nb_dis_ptr) {
358
361
  *nb_dis_ptr = count_ndis(index_ivf, n * params->nprobe, Iq.data());
data/vendor/faiss/faiss/Index.h CHANGED
@@ -17,9 +17,21 @@
17
17
  #include <typeinfo>
18
18
 
19
19
  #define FAISS_VERSION_MAJOR 1
20
- #define FAISS_VERSION_MINOR 8
20
+ #define FAISS_VERSION_MINOR 9
21
21
  #define FAISS_VERSION_PATCH 0
22
22
 
23
+ // Macro to combine the version components into a single string
24
+ #ifndef FAISS_STRINGIFY
25
+ #define FAISS_STRINGIFY(ARG) #ARG
26
+ #endif
27
+ #ifndef FAISS_TOSTRING
28
+ #define FAISS_TOSTRING(ARG) FAISS_STRINGIFY(ARG)
29
+ #endif
30
+ #define VERSION_STRING \
31
+ FAISS_TOSTRING(FAISS_VERSION_MAJOR) \
32
+ "." FAISS_TOSTRING(FAISS_VERSION_MINOR) "." FAISS_TOSTRING( \
33
+ FAISS_VERSION_PATCH)
34
+
23
35
  /**
24
36
  * @namespace faiss
25
37
  *
@@ -38,8 +50,8 @@
38
50
 
39
51
  namespace faiss {
40
52
 
41
- /// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h and
42
- /// impl/DistanceComputer.h
53
+ /// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h
54
+ /// and impl/DistanceComputer.h
43
55
  struct IDSelector;
44
56
  struct RangeSearchResult;
45
57
  struct DistanceComputer;
@@ -56,7 +68,8 @@ struct SearchParameters {
56
68
  virtual ~SearchParameters() {}
57
69
  };
58
70
 
59
- /** Abstract structure for an index, supports adding vectors and searching them.
71
+ /** Abstract structure for an index, supports adding vectors and searching
72
+ * them.
60
73
  *
61
74
  * All vectors provided at add or search time are 32-bit float arrays,
62
75
  * although the internal representation may vary.
@@ -154,7 +167,8 @@ struct Index {
154
167
 
155
168
  /** return the indexes of the k vectors closest to the query x.
156
169
  *
157
- * This function is identical as search but only return labels of neighbors.
170
+ * This function is identical as search but only return labels of
171
+ * neighbors.
158
172
  * @param n number of vectors
159
173
  * @param x input vectors to search, size n * d
160
174
  * @param labels output labels of the NNs, size n*k
@@ -179,7 +193,8 @@ struct Index {
179
193
  */
180
194
  virtual void reconstruct(idx_t key, float* recons) const;
181
195
 
182
- /** Reconstruct several stored vectors (or an approximation if lossy coding)
196
+ /** Reconstruct several stored vectors (or an approximation if lossy
197
+ * coding)
183
198
  *
184
199
  * this function may not be defined for some indexes
185
200
  * @param n number of vectors to reconstruct
data/vendor/faiss/faiss/IndexBinaryHNSW.h CHANGED
@@ -21,7 +21,7 @@ namespace faiss {
21
21
  struct IndexBinaryHNSW : IndexBinary {
22
22
  typedef HNSW::storage_idx_t storage_idx_t;
23
23
 
24
- // the link strcuture
24
+ // the link structure
25
25
  HNSW hnsw;
26
26
 
27
27
  // the sequential storage
data/vendor/faiss/faiss/IndexBinaryIVF.cpp CHANGED
@@ -456,7 +456,7 @@ void search_knn_hamming_heap(
456
456
  }
457
457
 
458
458
  } // parallel for
459
- } // parallel
459
+ } // parallel
460
460
 
461
461
  indexIVF_stats.nq += n;
462
462
  indexIVF_stats.nlist += nlistv;
data/vendor/faiss/faiss/IndexFastScan.cpp CHANGED
@@ -189,6 +189,7 @@ void estimators_from_tables_generic(
189
189
  dt += index.ksub;
190
190
  }
191
191
  }
192
+
192
193
  if (C::cmp(heap_dis[0], dis)) {
193
194
  heap_pop<C>(k, heap_dis, heap_ids);
194
195
  heap_push<C>(k, heap_dis, heap_ids, dis, j);
@@ -203,17 +204,18 @@ ResultHandlerCompare<C, false>* make_knn_handler(
203
204
  idx_t k,
204
205
  size_t ntotal,
205
206
  float* distances,
206
- idx_t* labels) {
207
+ idx_t* labels,
208
+ const IDSelector* sel = nullptr) {
207
209
  using HeapHC = HeapHandler<C, false>;
208
210
  using ReservoirHC = ReservoirHandler<C, false>;
209
211
  using SingleResultHC = SingleResultHandler<C, false>;
210
212
 
211
213
  if (k == 1) {
212
- return new SingleResultHC(n, ntotal, distances, labels);
214
+ return new SingleResultHC(n, ntotal, distances, labels, sel);
213
215
  } else if (impl % 2 == 0) {
214
- return new HeapHC(n, ntotal, k, distances, labels);
216
+ return new HeapHC(n, ntotal, k, distances, labels, sel);
215
217
  } else /* if (impl % 2 == 1) */ {
216
- return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels);
218
+ return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel);
217
219
  }
218
220
  }
219
221
 
@@ -547,6 +549,22 @@ void IndexFastScan::search_implem_14(
547
549
  }
548
550
  }
549
551
 
552
+ template void IndexFastScan::search_dispatch_implem<true>(
553
+ idx_t n,
554
+ const float* x,
555
+ idx_t k,
556
+ float* distances,
557
+ idx_t* labels,
558
+ const NormTableScaler* scaler) const;
559
+
560
+ template void IndexFastScan::search_dispatch_implem<false>(
561
+ idx_t n,
562
+ const float* x,
563
+ idx_t k,
564
+ float* distances,
565
+ idx_t* labels,
566
+ const NormTableScaler* scaler) const;
567
+
550
568
  void IndexFastScan::reconstruct(idx_t key, float* recons) const {
551
569
  std::vector<uint8_t> code(code_size, 0);
552
570
  BitstringWriter bsw(code.data(), code_size);
data/vendor/faiss/faiss/IndexFlat.cpp CHANGED
@@ -41,15 +41,19 @@ void IndexFlat::search(
41
41
  } else if (metric_type == METRIC_L2) {
42
42
  float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
43
43
  knn_L2sqr(x, get_xb(), d, n, ntotal, &res, nullptr, sel);
44
- } else if (is_similarity_metric(metric_type)) {
45
- float_minheap_array_t res = {size_t(n), size_t(k), labels, distances};
46
- knn_extra_metrics(
47
- x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
48
44
  } else {
49
- FAISS_THROW_IF_NOT(!sel);
50
- float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
45
+ FAISS_THROW_IF_NOT(!sel); // TODO implement with selector
51
46
  knn_extra_metrics(
52
- x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
47
+ x,
48
+ get_xb(),
49
+ d,
50
+ n,
51
+ ntotal,
52
+ metric_type,
53
+ metric_arg,
54
+ k,
55
+ distances,
56
+ labels);
53
57
  }
54
58
  }
55
59
 
data/vendor/faiss/faiss/IndexFlatCodes.cpp CHANGED
@@ -12,6 +12,8 @@
12
12
  #include <faiss/impl/DistanceComputer.h>
13
13
  #include <faiss/impl/FaissAssert.h>
14
14
  #include <faiss/impl/IDSelector.h>
15
+ #include <faiss/impl/ResultHandler.h>
16
+ #include <faiss/utils/extra_distances.h>
15
17
 
16
18
  namespace faiss {
17
19
 
@@ -70,11 +72,6 @@ void IndexFlatCodes::reconstruct(idx_t key, float* recons) const {
70
72
  reconstruct_n(key, 1, recons);
71
73
  }
72
74
 
73
- FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
74
- const {
75
- FAISS_THROW_MSG("not implemented");
76
- }
77
-
78
75
  void IndexFlatCodes::check_compatible_for_merge(const Index& otherIndex) const {
79
76
  // minimal sanity checks
80
77
  const IndexFlatCodes* other =
@@ -114,4 +111,161 @@ void IndexFlatCodes::permute_entries(const idx_t* perm) {
114
111
  std::swap(codes, new_codes);
115
112
  }
116
113
 
114
+ namespace {
115
+
116
+ template <class VD>
117
+ struct GenericFlatCodesDistanceComputer : FlatCodesDistanceComputer {
118
+ const IndexFlatCodes& codec;
119
+ const VD vd;
120
+ // temp buffers
121
+ std::vector<uint8_t> code_buffer;
122
+ std::vector<float> vec_buffer;
123
+ const float* query = nullptr;
124
+
125
+ GenericFlatCodesDistanceComputer(const IndexFlatCodes* codec, const VD& vd)
126
+ : FlatCodesDistanceComputer(codec->codes.data(), codec->code_size),
127
+ codec(*codec),
128
+ vd(vd),
129
+ code_buffer(codec->code_size * 4),
130
+ vec_buffer(codec->d * 4) {}
131
+
132
+ void set_query(const float* x) override {
133
+ query = x;
134
+ }
135
+
136
+ float operator()(idx_t i) override {
137
+ codec.sa_decode(1, codes + i * code_size, vec_buffer.data());
138
+ return vd(query, vec_buffer.data());
139
+ }
140
+
141
+ float distance_to_code(const uint8_t* code) override {
142
+ codec.sa_decode(1, code, vec_buffer.data());
143
+ return vd(query, vec_buffer.data());
144
+ }
145
+
146
+ float symmetric_dis(idx_t i, idx_t j) override {
147
+ codec.sa_decode(1, codes + i * code_size, vec_buffer.data());
148
+ codec.sa_decode(1, codes + j * code_size, vec_buffer.data() + vd.d);
149
+ return vd(vec_buffer.data(), vec_buffer.data() + vd.d);
150
+ }
151
+
152
+ void distances_batch_4(
153
+ const idx_t idx0,
154
+ const idx_t idx1,
155
+ const idx_t idx2,
156
+ const idx_t idx3,
157
+ float& dis0,
158
+ float& dis1,
159
+ float& dis2,
160
+ float& dis3) override {
161
+ uint8_t* cp = code_buffer.data();
162
+ for (idx_t i : {idx0, idx1, idx2, idx3}) {
163
+ memcpy(cp, codes + i * code_size, code_size);
164
+ cp += code_size;
165
+ }
166
+ // potential benefit is if batch decoding is more efficient than 1 by 1
167
+ // decoding
168
+ codec.sa_decode(4, code_buffer.data(), vec_buffer.data());
169
+ dis0 = vd(query, vec_buffer.data());
170
+ dis1 = vd(query, vec_buffer.data() + vd.d);
171
+ dis2 = vd(query, vec_buffer.data() + 2 * vd.d);
172
+ dis3 = vd(query, vec_buffer.data() + 3 * vd.d);
173
+ }
174
+ };
175
+
176
+ struct Run_get_distance_computer {
177
+ using T = FlatCodesDistanceComputer*;
178
+
179
+ template <class VD>
180
+ FlatCodesDistanceComputer* f(const VD& vd, const IndexFlatCodes* codec) {
181
+ return new GenericFlatCodesDistanceComputer<VD>(codec, vd);
182
+ }
183
+ };
184
+
185
+ template <class BlockResultHandler>
186
+ struct Run_search_with_decompress {
187
+ using T = void;
188
+
189
+ template <class VectorDistance>
190
+ void f(VectorDistance& vd,
191
+ const IndexFlatCodes* index_ptr,
192
+ const float* xq,
193
+ BlockResultHandler& res) {
194
+ // Note that there seems to be a clang (?) bug that "sometimes" passes
195
+ // the const Index & parameters by value, so to be on the safe side,
196
+ // it's better to use pointers.
197
+ const IndexFlatCodes& index = *index_ptr;
198
+ size_t ntotal = index.ntotal;
199
+ using SingleResultHandler =
200
+ typename BlockResultHandler::SingleResultHandler;
201
+ using DC = GenericFlatCodesDistanceComputer<VectorDistance>;
202
+ #pragma omp parallel // if (res.nq > 100)
203
+ {
204
+ std::unique_ptr<DC> dc(new DC(&index, vd));
205
+ SingleResultHandler resi(res);
206
+ #pragma omp for
207
+ for (int64_t q = 0; q < res.nq; q++) {
208
+ resi.begin(q);
209
+ dc->set_query(xq + vd.d * q);
210
+ for (size_t i = 0; i < ntotal; i++) {
211
+ if (res.is_in_selection(i)) {
212
+ float dis = (*dc)(i);
213
+ resi.add_result(dis, i);
214
+ }
215
+ }
216
+ resi.end();
217
+ }
218
+ }
219
+ }
220
+ };
221
+
222
+ struct Run_search_with_decompress_res {
223
+ using T = void;
224
+
225
+ template <class ResultHandler>
226
+ void f(ResultHandler& res, const IndexFlatCodes* index, const float* xq) {
227
+ Run_search_with_decompress<ResultHandler> r;
228
+ dispatch_VectorDistance(
229
+ index->d,
230
+ index->metric_type,
231
+ index->metric_arg,
232
+ r,
233
+ index,
234
+ xq,
235
+ res);
236
+ }
237
+ };
238
+
239
+ } // anonymous namespace
240
+
241
+ FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
242
+ const {
243
+ Run_get_distance_computer r;
244
+ return dispatch_VectorDistance(d, metric_type, metric_arg, r, this);
245
+ }
246
+
247
+ void IndexFlatCodes::search(
248
+ idx_t n,
249
+ const float* x,
250
+ idx_t k,
251
+ float* distances,
252
+ idx_t* labels,
253
+ const SearchParameters* params) const {
254
+ Run_search_with_decompress_res r;
255
+ const IDSelector* sel = params ? params->sel : nullptr;
256
+ dispatch_knn_ResultHandler(
257
+ n, distances, labels, k, metric_type, sel, r, this, x);
258
+ }
259
+
260
+ void IndexFlatCodes::range_search(
261
+ idx_t n,
262
+ const float* x,
263
+ float radius,
264
+ RangeSearchResult* result,
265
+ const SearchParameters* params) const {
266
+ const IDSelector* sel = params ? params->sel : nullptr;
267
+ Run_search_with_decompress_res r;
268
+ dispatch_range_ResultHandler(result, radius, metric_type, sel, r, this, x);
269
+ }
270
+
117
271
  } // namespace faiss
data/vendor/faiss/faiss/IndexFlatCodes.h CHANGED
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #pragma once
11
9
 
12
10
  #include <faiss/Index.h>
@@ -45,13 +43,32 @@ struct IndexFlatCodes : Index {
45
43
  * different from the usual ones: the new ids are shifted */
46
44
  size_t remove_ids(const IDSelector& sel) override;
47
45
 
48
- /** a FlatCodesDistanceComputer offers a distance_to_code method */
46
+ /** a FlatCodesDistanceComputer offers a distance_to_code method
47
+ *
48
+ * The default implementation explicitly decodes the vector with sa_decode.
49
+ */
49
50
  virtual FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const;
50
51
 
51
52
  DistanceComputer* get_distance_computer() const override {
52
53
  return get_FlatCodesDistanceComputer();
53
54
  }
54
55
 
56
+ /** Search implemented by decoding */
57
+ void search(
58
+ idx_t n,
59
+ const float* x,
60
+ idx_t k,
61
+ float* distances,
62
+ idx_t* labels,
63
+ const SearchParameters* params = nullptr) const override;
64
+
65
+ void range_search(
66
+ idx_t n,
67
+ const float* x,
68
+ float radius,
69
+ RangeSearchResult* result,
70
+ const SearchParameters* params = nullptr) const override;
71
+
55
72
  // returns a new instance of a CodePacker
56
73
  CodePacker* get_CodePacker() const;
57
74