faiss 0.2.4 → 0.2.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -0,0 +1,145 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/Index.h>
11
+ #include <faiss/utils/AlignedTable.h>
12
+
13
+ namespace faiss {
14
+
15
+ /** Fast scan version of IndexPQ and IndexAQ. Works for 4-bit PQ and AQ for now.
16
+ *
17
+ * The codes are not stored sequentially but grouped in blocks of size bbs.
18
+ * This makes it possible to compute distances quickly with SIMD instructions.
19
+ * The trailing codes (padding codes that are added to complete the last code)
20
+ * are garbage.
21
+ *
22
+ * Implementations:
23
+ * 12: blocked loop with internal loop on Q with qbs
24
+ * 13: same with reservoir accumulator to store results
25
+ * 14: no qbs with heap accumulator
26
+ * 15: no qbs with reservoir accumulator
27
+ */
28
+
29
+ struct IndexFastScan : Index {
30
+ // implementation to select
31
+ int implem = 0;
32
+ // skip some parts of the computation (for timing)
33
+ int skip = 0;
34
+
35
+ // size of the kernel
36
+ int bbs; // set at build time
37
+ int qbs = 0; // query block size 0 = use default
38
+
39
+ // vector quantizer
40
+ size_t M;
41
+ size_t nbits;
42
+ size_t ksub;
43
+ size_t code_size;
44
+
45
+ // packed version of the codes
46
+ size_t ntotal2;
47
+ size_t M2;
48
+
49
+ AlignedTable<uint8_t> codes;
50
+
51
+ // this is for testing purposes only
52
+ // (set when initialized by IndexPQ or IndexAQ)
53
+ const uint8_t* orig_codes = nullptr;
54
+
55
+ void init_fastscan(
56
+ int d,
57
+ size_t M,
58
+ size_t nbits,
59
+ MetricType metric,
60
+ int bbs);
61
+
62
+ IndexFastScan();
63
+
64
+ void reset() override;
65
+
66
+ void search(
67
+ idx_t n,
68
+ const float* x,
69
+ idx_t k,
70
+ float* distances,
71
+ idx_t* labels,
72
+ const SearchParameters* params = nullptr) const override;
73
+
74
+ void add(idx_t n, const float* x) override;
75
+
76
+ virtual void compute_codes(uint8_t* codes, idx_t n, const float* x)
77
+ const = 0;
78
+
79
+ virtual void compute_float_LUT(float* lut, idx_t n, const float* x)
80
+ const = 0;
81
+
82
+ // called by search function
83
+ void compute_quantized_LUT(
84
+ idx_t n,
85
+ const float* x,
86
+ uint8_t* lut,
87
+ float* normalizers) const;
88
+
89
+ template <bool is_max, class Scaler>
90
+ void search_dispatch_implem(
91
+ idx_t n,
92
+ const float* x,
93
+ idx_t k,
94
+ float* distances,
95
+ idx_t* labels,
96
+ const Scaler& scaler) const;
97
+
98
+ template <class Cfloat, class Scaler>
99
+ void search_implem_234(
100
+ idx_t n,
101
+ const float* x,
102
+ idx_t k,
103
+ float* distances,
104
+ idx_t* labels,
105
+ const Scaler& scaler) const;
106
+
107
+ template <class C, class Scaler>
108
+ void search_implem_12(
109
+ idx_t n,
110
+ const float* x,
111
+ idx_t k,
112
+ float* distances,
113
+ idx_t* labels,
114
+ int impl,
115
+ const Scaler& scaler) const;
116
+
117
+ template <class C, class Scaler>
118
+ void search_implem_14(
119
+ idx_t n,
120
+ const float* x,
121
+ idx_t k,
122
+ float* distances,
123
+ idx_t* labels,
124
+ int impl,
125
+ const Scaler& scaler) const;
126
+
127
+ void reconstruct(idx_t key, float* recons) const override;
128
+ size_t remove_ids(const IDSelector& sel) override;
129
+ void merge_from(Index& otherIndex, idx_t add_id = 0) override;
130
+ void check_compatible_for_merge(const Index& otherIndex) const override;
131
+ };
132
+
133
+ struct FastScanStats {
134
+ uint64_t t0, t1, t2, t3;
135
+ FastScanStats() {
136
+ reset();
137
+ }
138
+ void reset() {
139
+ memset(this, 0, sizeof(*this));
140
+ }
141
+ };
142
+
143
+ FAISS_API extern FastScanStats FastScan_stats;
144
+
145
+ } // namespace faiss
@@ -27,18 +27,20 @@ void IndexFlat::search(
27
27
  const float* x,
28
28
  idx_t k,
29
29
  float* distances,
30
- idx_t* labels) const {
30
+ idx_t* labels,
31
+ const SearchParameters* params) const {
32
+ IDSelector* sel = params ? params->sel : nullptr;
31
33
  FAISS_THROW_IF_NOT(k > 0);
32
34
 
33
35
  // we see the distances and labels as heaps
34
-
35
36
  if (metric_type == METRIC_INNER_PRODUCT) {
36
37
  float_minheap_array_t res = {size_t(n), size_t(k), labels, distances};
37
- knn_inner_product(x, get_xb(), d, n, ntotal, &res);
38
+ knn_inner_product(x, get_xb(), d, n, ntotal, &res, sel);
38
39
  } else if (metric_type == METRIC_L2) {
39
40
  float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
40
- knn_L2sqr(x, get_xb(), d, n, ntotal, &res);
41
+ knn_L2sqr(x, get_xb(), d, n, ntotal, &res, nullptr, sel);
41
42
  } else {
43
+ FAISS_THROW_IF_NOT(!sel);
42
44
  float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
43
45
  knn_extra_metrics(
44
46
  x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
@@ -49,14 +51,17 @@ void IndexFlat::range_search(
49
51
  idx_t n,
50
52
  const float* x,
51
53
  float radius,
52
- RangeSearchResult* result) const {
54
+ RangeSearchResult* result,
55
+ const SearchParameters* params) const {
56
+ IDSelector* sel = params ? params->sel : nullptr;
57
+
53
58
  switch (metric_type) {
54
59
  case METRIC_INNER_PRODUCT:
55
60
  range_search_inner_product(
56
- x, get_xb(), d, n, ntotal, radius, result);
61
+ x, get_xb(), d, n, ntotal, radius, result, sel);
57
62
  break;
58
63
  case METRIC_L2:
59
- range_search_L2sqr(x, get_xb(), d, n, ntotal, radius, result);
64
+ range_search_L2sqr(x, get_xb(), d, n, ntotal, radius, result, sel);
60
65
  break;
61
66
  default:
62
67
  FAISS_THROW_MSG("metric type not supported");
@@ -83,16 +88,16 @@ void IndexFlat::compute_distance_subset(
83
88
 
84
89
  namespace {
85
90
 
86
- struct FlatL2Dis : DistanceComputer {
91
+ struct FlatL2Dis : FlatCodesDistanceComputer {
87
92
  size_t d;
88
93
  Index::idx_t nb;
89
94
  const float* q;
90
95
  const float* b;
91
96
  size_t ndis;
92
97
 
93
- float operator()(idx_t i) override {
98
+ float distance_to_code(const uint8_t* code) final {
94
99
  ndis++;
95
- return fvec_L2sqr(q, b + i * d, d);
100
+ return fvec_L2sqr(q, (float*)code, d);
96
101
  }
97
102
 
98
103
  float symmetric_dis(idx_t i, idx_t j) override {
@@ -100,7 +105,10 @@ struct FlatL2Dis : DistanceComputer {
100
105
  }
101
106
 
102
107
  explicit FlatL2Dis(const IndexFlat& storage, const float* q = nullptr)
103
- : d(storage.d),
108
+ : FlatCodesDistanceComputer(
109
+ storage.codes.data(),
110
+ storage.code_size),
111
+ d(storage.d),
104
112
  nb(storage.ntotal),
105
113
  q(q),
106
114
  b(storage.get_xb()),
@@ -111,24 +119,27 @@ struct FlatL2Dis : DistanceComputer {
111
119
  }
112
120
  };
113
121
 
114
- struct FlatIPDis : DistanceComputer {
122
+ struct FlatIPDis : FlatCodesDistanceComputer {
115
123
  size_t d;
116
124
  Index::idx_t nb;
117
125
  const float* q;
118
126
  const float* b;
119
127
  size_t ndis;
120
128
 
121
- float operator()(idx_t i) override {
122
- ndis++;
123
- return fvec_inner_product(q, b + i * d, d);
124
- }
125
-
126
129
  float symmetric_dis(idx_t i, idx_t j) override {
127
130
  return fvec_inner_product(b + j * d, b + i * d, d);
128
131
  }
129
132
 
133
+ float distance_to_code(const uint8_t* code) final {
134
+ ndis++;
135
+ return fvec_inner_product(q, (float*)code, d);
136
+ }
137
+
130
138
  explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr)
131
- : d(storage.d),
139
+ : FlatCodesDistanceComputer(
140
+ storage.codes.data(),
141
+ storage.code_size),
142
+ d(storage.d),
132
143
  nb(storage.ntotal),
133
144
  q(q),
134
145
  b(storage.get_xb()),
@@ -141,7 +152,7 @@ struct FlatIPDis : DistanceComputer {
141
152
 
142
153
  } // namespace
143
154
 
144
- DistanceComputer* IndexFlat::get_distance_computer() const {
155
+ FlatCodesDistanceComputer* IndexFlat::get_FlatCodesDistanceComputer() const {
145
156
  if (metric_type == METRIC_L2) {
146
157
  return new FlatL2Dis(*this);
147
158
  } else if (metric_type == METRIC_INNER_PRODUCT) {
@@ -202,9 +213,11 @@ void IndexFlat1D::search(
202
213
  const float* x,
203
214
  idx_t k,
204
215
  float* distances,
205
- idx_t* labels) const {
216
+ idx_t* labels,
217
+ const SearchParameters* params) const {
218
+ FAISS_THROW_IF_NOT_MSG(
219
+ !params, "search params not supported for this index");
206
220
  FAISS_THROW_IF_NOT(k > 0);
207
-
208
221
  FAISS_THROW_IF_NOT_MSG(
209
222
  perm.size() == ntotal, "Call update_permutation before search");
210
223
  const float* xb = get_xb();
@@ -25,13 +25,15 @@ struct IndexFlat : IndexFlatCodes {
25
25
  const float* x,
26
26
  idx_t k,
27
27
  float* distances,
28
- idx_t* labels) const override;
28
+ idx_t* labels,
29
+ const SearchParameters* params = nullptr) const override;
29
30
 
30
31
  void range_search(
31
32
  idx_t n,
32
33
  const float* x,
33
34
  float radius,
34
- RangeSearchResult* result) const override;
35
+ RangeSearchResult* result,
36
+ const SearchParameters* params = nullptr) const override;
35
37
 
36
38
  void reconstruct(idx_t key, float* recons) const override;
37
39
 
@@ -60,7 +62,7 @@ struct IndexFlat : IndexFlatCodes {
60
62
 
61
63
  IndexFlat() {}
62
64
 
63
- DistanceComputer* get_distance_computer() const override;
65
+ FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
64
66
 
65
67
  /* The stanadlone codec interface (just memcopies in this case) */
66
68
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
@@ -100,7 +102,8 @@ struct IndexFlat1D : IndexFlatL2 {
100
102
  const float* x,
101
103
  idx_t k,
102
104
  float* distances,
103
- idx_t* labels) const override;
105
+ idx_t* labels,
106
+ const SearchParameters* params = nullptr) const override;
104
107
  };
105
108
 
106
109
  } // namespace faiss
@@ -8,7 +8,9 @@
8
8
  #include <faiss/IndexFlatCodes.h>
9
9
 
10
10
  #include <faiss/impl/AuxIndexStructures.h>
11
+ #include <faiss/impl/DistanceComputer.h>
11
12
  #include <faiss/impl/FaissAssert.h>
13
+ #include <faiss/impl/IDSelector.h>
12
14
 
13
15
  namespace faiss {
14
16
 
@@ -19,8 +21,11 @@ IndexFlatCodes::IndexFlatCodes() : code_size(0) {}
19
21
 
20
22
  void IndexFlatCodes::add(idx_t n, const float* x) {
21
23
  FAISS_THROW_IF_NOT(is_trained);
24
+ if (n == 0) {
25
+ return;
26
+ }
22
27
  codes.resize((ntotal + n) * code_size);
23
- sa_encode(n, x, &codes[ntotal * code_size]);
28
+ sa_encode(n, x, codes.data() + (ntotal * code_size));
24
29
  ntotal += n;
25
30
  }
26
31
 
@@ -64,4 +69,33 @@ void IndexFlatCodes::reconstruct(idx_t key, float* recons) const {
64
69
  reconstruct_n(key, 1, recons);
65
70
  }
66
71
 
72
+ FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
73
+ const {
74
+ FAISS_THROW_MSG("not implemented");
75
+ }
76
+
77
+ void IndexFlatCodes::check_compatible_for_merge(const Index& otherIndex) const {
78
+ // minimal sanity checks
79
+ const IndexFlatCodes* other =
80
+ dynamic_cast<const IndexFlatCodes*>(&otherIndex);
81
+ FAISS_THROW_IF_NOT(other);
82
+ FAISS_THROW_IF_NOT(other->d == d);
83
+ FAISS_THROW_IF_NOT(other->code_size == code_size);
84
+ FAISS_THROW_IF_NOT_MSG(
85
+ typeid(*this) == typeid(*other),
86
+ "can only merge indexes of the same type");
87
+ }
88
+
89
+ void IndexFlatCodes::merge_from(Index& otherIndex, idx_t add_id) {
90
+ FAISS_THROW_IF_NOT_MSG(add_id == 0, "cannot set ids in FlatCodes index");
91
+ check_compatible_for_merge(otherIndex);
92
+ IndexFlatCodes* other = static_cast<IndexFlatCodes*>(&otherIndex);
93
+ codes.resize((ntotal + other->ntotal) * code_size);
94
+ memcpy(codes.data() + (ntotal * code_size),
95
+ other->codes.data(),
96
+ other->ntotal * code_size);
97
+ ntotal += other->ntotal;
98
+ other->reset();
99
+ }
100
+
67
101
  } // namespace faiss
@@ -10,6 +10,7 @@
10
10
  #pragma once
11
11
 
12
12
  #include <faiss/Index.h>
13
+ #include <faiss/impl/DistanceComputer.h>
13
14
  #include <vector>
14
15
 
15
16
  namespace faiss {
@@ -42,6 +43,17 @@ struct IndexFlatCodes : Index {
42
43
  * indexing structure, the semantics of this operation are
43
44
  * different from the usual ones: the new ids are shifted */
44
45
  size_t remove_ids(const IDSelector& sel) override;
46
+
47
+ /** a FlatCodesDistanceComputer offers a distance_to_code method */
48
+ virtual FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const;
49
+
50
+ DistanceComputer* get_distance_computer() const override {
51
+ return get_FlatCodesDistanceComputer();
52
+ }
53
+
54
+ void check_compatible_for_merge(const Index& otherIndex) const override;
55
+
56
+ virtual void merge_from(Index& otherIndex, idx_t add_id = 0) override;
45
57
  };
46
58
 
47
59
  } // namespace faiss
@@ -202,7 +202,10 @@ void hnsw_add_vertices(
202
202
  verbose && omp_get_thread_num() == 0 ? 0 : -1;
203
203
  size_t counter = 0;
204
204
 
205
- #pragma omp for schedule(dynamic)
205
+ // here we should do schedule(dynamic) but this segfaults for
206
+ // some versions of LLVM. The performance impact should not be
207
+ // too large when (i1 - i0) / num_threads >> 1
208
+ #pragma omp for schedule(static)
206
209
  for (int i = i0; i < i1; i++) {
207
210
  storage_idx_t pt_id = order[i];
208
211
  dis->set_query(x + (pt_id - n0) * d);
@@ -219,7 +222,6 @@ void hnsw_add_vertices(
219
222
  printf(" %d / %d\r", i - i0, i1 - i0);
220
223
  fflush(stdout);
221
224
  }
222
-
223
225
  if (counter % check_period == 0) {
224
226
  if (InterruptCallback::is_interrupted()) {
225
227
  interrupt = true;
@@ -284,18 +286,24 @@ void IndexHNSW::search(
284
286
  const float* x,
285
287
  idx_t k,
286
288
  float* distances,
287
- idx_t* labels) const
288
-
289
- {
289
+ idx_t* labels,
290
+ const SearchParameters* params_in) const {
290
291
  FAISS_THROW_IF_NOT(k > 0);
291
-
292
292
  FAISS_THROW_IF_NOT_MSG(
293
293
  storage,
294
294
  "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
295
+ const SearchParametersHNSW* params = nullptr;
296
+
297
+ int efSearch = hnsw.efSearch;
298
+ if (params_in) {
299
+ params = dynamic_cast<const SearchParametersHNSW*>(params_in);
300
+ FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
301
+ efSearch = params->efSearch;
302
+ }
295
303
  size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
296
304
 
297
- idx_t check_period = InterruptCallback::get_period_hint(
298
- hnsw.max_level * d * hnsw.efSearch);
305
+ idx_t check_period =
306
+ InterruptCallback::get_period_hint(hnsw.max_level * d * efSearch);
299
307
 
300
308
  for (idx_t i0 = 0; i0 < n; i0 += check_period) {
301
309
  idx_t i1 = std::min(i0 + check_period, n);
@@ -314,7 +322,7 @@ void IndexHNSW::search(
314
322
  dis->set_query(x + i * d);
315
323
 
316
324
  maxheap_heapify(k, simi, idxi);
317
- HNSWStats stats = hnsw.search(*dis, k, idxi, simi, vt);
325
+ HNSWStats stats = hnsw.search(*dis, k, idxi, simi, vt, params);
318
326
  n1 += stats.n1;
319
327
  n2 += stats.n2;
320
328
  n3 += stats.n3;
@@ -423,16 +431,15 @@ void IndexHNSW::search_level_0(
423
431
  FAISS_THROW_IF_NOT(nprobe > 0);
424
432
 
425
433
  storage_idx_t ntotal = hnsw.levels.size();
426
- size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
427
434
 
428
435
  #pragma omp parallel
429
436
  {
430
- DistanceComputer* qdis = storage_distance_computer(storage);
431
- ScopeDeleter1<DistanceComputer> del(qdis);
432
-
437
+ std::unique_ptr<DistanceComputer> qdis(
438
+ storage_distance_computer(storage));
439
+ HNSWStats search_stats;
433
440
  VisitedTable vt(ntotal);
434
441
 
435
- #pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder)
442
+ #pragma omp for
436
443
  for (idx_t i = 0; i < n; i++) {
437
444
  idx_t* idxi = labels + i * k;
438
445
  float* simi = distances + i * k;
@@ -440,69 +447,24 @@ void IndexHNSW::search_level_0(
440
447
  qdis->set_query(x + i * d);
441
448
  maxheap_heapify(k, simi, idxi);
442
449
 
443
- if (search_type == 1) {
444
- int nres = 0;
445
-
446
- for (int j = 0; j < nprobe; j++) {
447
- storage_idx_t cj = nearest[i * nprobe + j];
448
-
449
- if (cj < 0)
450
- break;
451
-
452
- if (vt.get(cj))
453
- continue;
454
-
455
- int candidates_size = std::max(hnsw.efSearch, int(k));
456
- MinimaxHeap candidates(candidates_size);
457
-
458
- candidates.push(cj, nearest_d[i * nprobe + j]);
459
-
460
- HNSWStats search_stats;
461
- nres = hnsw.search_from_candidates(
462
- *qdis,
463
- k,
464
- idxi,
465
- simi,
466
- candidates,
467
- vt,
468
- search_stats,
469
- 0,
470
- nres);
471
- n1 += search_stats.n1;
472
- n2 += search_stats.n2;
473
- n3 += search_stats.n3;
474
- ndis += search_stats.ndis;
475
- nreorder += search_stats.nreorder;
476
- }
477
- } else if (search_type == 2) {
478
- int candidates_size = std::max(hnsw.efSearch, int(k));
479
- candidates_size = std::max(candidates_size, nprobe);
480
-
481
- MinimaxHeap candidates(candidates_size);
482
- for (int j = 0; j < nprobe; j++) {
483
- storage_idx_t cj = nearest[i * nprobe + j];
484
-
485
- if (cj < 0)
486
- break;
487
- candidates.push(cj, nearest_d[i * nprobe + j]);
488
- }
450
+ hnsw.search_level_0(
451
+ *qdis.get(),
452
+ k,
453
+ idxi,
454
+ simi,
455
+ nprobe,
456
+ nearest + i * nprobe,
457
+ nearest_d + i * nprobe,
458
+ search_type,
459
+ search_stats,
460
+ vt);
489
461
 
490
- HNSWStats search_stats;
491
- hnsw.search_from_candidates(
492
- *qdis, k, idxi, simi, candidates, vt, search_stats, 0);
493
- n1 += search_stats.n1;
494
- n2 += search_stats.n2;
495
- n3 += search_stats.n3;
496
- ndis += search_stats.ndis;
497
- nreorder += search_stats.nreorder;
498
- }
499
462
  vt.advance();
500
-
501
463
  maxheap_reorder(k, simi, idxi);
502
464
  }
465
+ #pragma omp critical
466
+ { hnsw_stats.combine(search_stats); }
503
467
  }
504
-
505
- hnsw_stats.combine({n1, n2, n3, ndis, nreorder});
506
468
  }
507
469
 
508
470
  void IndexHNSW::init_level_0_from_knngraph(
@@ -1035,8 +997,11 @@ void IndexHNSW2Level::search(
1035
997
  const float* x,
1036
998
  idx_t k,
1037
999
  float* distances,
1038
- idx_t* labels) const {
1000
+ idx_t* labels,
1001
+ const SearchParameters* params) const {
1039
1002
  FAISS_THROW_IF_NOT(k > 0);
1003
+ FAISS_THROW_IF_NOT_MSG(
1004
+ !params, "search params not supported for this index");
1040
1005
 
1041
1006
  if (dynamic_cast<const Index2Layer*>(storage)) {
1042
1007
  IndexHNSW::search(n, x, k, distances, labels);
@@ -1095,74 +1060,37 @@ void IndexHNSW2Level::search(
1095
1060
  }
1096
1061
 
1097
1062
  candidates.clear();
1098
- // copy the upper_beam elements to candidates list
1099
-
1100
- int search_policy = 2;
1101
-
1102
- if (search_policy == 1) {
1103
- for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
1104
- if (idxi[j] < 0)
1105
- break;
1106
- candidates.push(idxi[j], simi[j]);
1107
- // search_from_candidates adds them back
1108
- idxi[j] = -1;
1109
- simi[j] = HUGE_VAL;
1110
- }
1111
1063
 
1112
- // reorder from sorted to heap
1113
- maxheap_heapify(k, simi, idxi, simi, idxi, k);
1114
-
1115
- HNSWStats search_stats;
1116
- hnsw.search_from_candidates(
1117
- *dis,
1118
- k,
1119
- idxi,
1120
- simi,
1121
- candidates,
1122
- vt,
1123
- search_stats,
1124
- 0,
1125
- k);
1126
- n1 += search_stats.n1;
1127
- n2 += search_stats.n2;
1128
- n3 += search_stats.n3;
1129
- ndis += search_stats.ndis;
1130
- nreorder += search_stats.nreorder;
1131
-
1132
- vt.advance();
1133
-
1134
- } else if (search_policy == 2) {
1135
- for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
1136
- if (idxi[j] < 0)
1137
- break;
1138
- candidates.push(idxi[j], simi[j]);
1139
- }
1140
-
1141
- // reorder from sorted to heap
1142
- maxheap_heapify(k, simi, idxi, simi, idxi, k);
1143
-
1144
- HNSWStats search_stats;
1145
- search_from_candidates_2(
1146
- hnsw,
1147
- *dis,
1148
- k,
1149
- idxi,
1150
- simi,
1151
- candidates,
1152
- vt,
1153
- search_stats,
1154
- 0,
1155
- k);
1156
- n1 += search_stats.n1;
1157
- n2 += search_stats.n2;
1158
- n3 += search_stats.n3;
1159
- ndis += search_stats.ndis;
1160
- nreorder += search_stats.nreorder;
1161
-
1162
- vt.advance();
1163
- vt.advance();
1064
+ for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
1065
+ if (idxi[j] < 0)
1066
+ break;
1067
+ candidates.push(idxi[j], simi[j]);
1164
1068
  }
1165
1069
 
1070
+ // reorder from sorted to heap
1071
+ maxheap_heapify(k, simi, idxi, simi, idxi, k);
1072
+
1073
+ HNSWStats search_stats;
1074
+ search_from_candidates_2(
1075
+ hnsw,
1076
+ *dis,
1077
+ k,
1078
+ idxi,
1079
+ simi,
1080
+ candidates,
1081
+ vt,
1082
+ search_stats,
1083
+ 0,
1084
+ k);
1085
+ n1 += search_stats.n1;
1086
+ n2 += search_stats.n2;
1087
+ n3 += search_stats.n3;
1088
+ ndis += search_stats.ndis;
1089
+ nreorder += search_stats.nreorder;
1090
+
1091
+ vt.advance();
1092
+ vt.advance();
1093
+
1166
1094
  maxheap_reorder(k, simi, idxi);
1167
1095
  }
1168
1096
  }
@@ -96,7 +96,8 @@ struct IndexHNSW : Index {
96
96
  const float* x,
97
97
  idx_t k,
98
98
  float* distances,
99
- idx_t* labels) const override;
99
+ idx_t* labels,
100
+ const SearchParameters* params = nullptr) const override;
100
101
 
101
102
  void reconstruct(idx_t key, float* recons) const override;
102
103
 
@@ -180,7 +181,8 @@ struct IndexHNSW2Level : IndexHNSW {
180
181
  const float* x,
181
182
  idx_t k,
182
183
  float* distances,
183
- idx_t* labels) const override;
184
+ idx_t* labels,
185
+ const SearchParameters* params = nullptr) const override;
184
186
  };
185
187
 
186
188
  } // namespace faiss