faiss 0.2.3 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -0,0 +1,213 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <memory>
11
+
12
+ #include <faiss/IndexIVF.h>
13
+ #include <faiss/utils/AlignedTable.h>
14
+
15
+ namespace faiss {
16
+
17
+ /** Fast scan version of IVFPQ and IVFAQ. Works for 4-bit PQ/AQ for now.
18
+ *
19
+ * The codes in the inverted lists are not stored sequentially but
20
+ * grouped in blocks of size bbs. This makes it possible to very quickly
21
+ * compute distances with SIMD instructions.
22
+ *
23
+ * Implementations (implem):
24
+ * 0: auto-select implementation (default)
25
+ * 1: orig's search, re-implemented
26
+ * 2: orig's search, re-ordered by invlist
27
+ * 10: optimizer int16 search, collect results in heap, no qbs
28
+ * 11: idem, collect results in reservoir
29
+ * 12: optimizer int16 search, collect results in heap, uses qbs
30
+ * 13: idem, collect results in reservoir
31
+ */
32
+
33
+ struct IndexIVFFastScan : IndexIVF {
34
+ // size of the kernel
35
+ int bbs; // set at build time
36
+
37
+ size_t M;
38
+ size_t nbits;
39
+ size_t ksub;
40
+
41
+ // M rounded up to a multiple of 2
42
+ size_t M2;
43
+
44
+ // search-time implementation
45
+ int implem = 0;
46
+ // skip some parts of the computation (for timing)
47
+ int skip = 0;
48
+ bool by_residual = false;
49
+
50
+ // batching factors at search time (0 = default)
51
+ int qbs = 0;
52
+ size_t qbs2 = 0;
53
+
54
+ IndexIVFFastScan(
55
+ Index* quantizer,
56
+ size_t d,
57
+ size_t nlist,
58
+ size_t code_size,
59
+ MetricType metric = METRIC_L2);
60
+
61
+ IndexIVFFastScan();
62
+
63
+ void init_fastscan(
64
+ size_t M,
65
+ size_t nbits,
66
+ size_t nlist,
67
+ MetricType metric,
68
+ int bbs);
69
+
70
+ ~IndexIVFFastScan() override;
71
+
72
+ /// orig's inverted lists (for debugging)
73
+ InvertedLists* orig_invlists = nullptr;
74
+
75
+ void add_with_ids(idx_t n, const float* x, const idx_t* xids) override;
76
+
77
+ // prepare look-up tables
78
+
79
+ virtual bool lookup_table_is_3d() const = 0;
80
+
81
+ virtual void compute_LUT(
82
+ size_t n,
83
+ const float* x,
84
+ const idx_t* coarse_ids,
85
+ const float* coarse_dis,
86
+ AlignedTable<float>& dis_tables,
87
+ AlignedTable<float>& biases) const = 0;
88
+
89
+ void compute_LUT_uint8(
90
+ size_t n,
91
+ const float* x,
92
+ const idx_t* coarse_ids,
93
+ const float* coarse_dis,
94
+ AlignedTable<uint8_t>& dis_tables,
95
+ AlignedTable<uint16_t>& biases,
96
+ float* normalizers) const;
97
+
98
+ void search(
99
+ idx_t n,
100
+ const float* x,
101
+ idx_t k,
102
+ float* distances,
103
+ idx_t* labels,
104
+ const SearchParameters* params = nullptr) const override;
105
+
106
+ /// will just fail
107
+ void range_search(
108
+ idx_t n,
109
+ const float* x,
110
+ float radius,
111
+ RangeSearchResult* result,
112
+ const SearchParameters* params = nullptr) const override;
113
+
114
+ // internal search funcs
115
+
116
+ template <bool is_max, class Scaler>
117
+ void search_dispatch_implem(
118
+ idx_t n,
119
+ const float* x,
120
+ idx_t k,
121
+ float* distances,
122
+ idx_t* labels,
123
+ const Scaler& scaler) const;
124
+
125
+ template <class C, class Scaler>
126
+ void search_implem_1(
127
+ idx_t n,
128
+ const float* x,
129
+ idx_t k,
130
+ float* distances,
131
+ idx_t* labels,
132
+ const Scaler& scaler) const;
133
+
134
+ template <class C, class Scaler>
135
+ void search_implem_2(
136
+ idx_t n,
137
+ const float* x,
138
+ idx_t k,
139
+ float* distances,
140
+ idx_t* labels,
141
+ const Scaler& scaler) const;
142
+
143
+ // implem 10 and 12 are not multithreaded internally, so
144
+ // export search stats
145
+ template <class C, class Scaler>
146
+ void search_implem_10(
147
+ idx_t n,
148
+ const float* x,
149
+ idx_t k,
150
+ float* distances,
151
+ idx_t* labels,
152
+ int impl,
153
+ size_t* ndis_out,
154
+ size_t* nlist_out,
155
+ const Scaler& scaler) const;
156
+
157
+ template <class C, class Scaler>
158
+ void search_implem_12(
159
+ idx_t n,
160
+ const float* x,
161
+ idx_t k,
162
+ float* distances,
163
+ idx_t* labels,
164
+ int impl,
165
+ size_t* ndis_out,
166
+ size_t* nlist_out,
167
+ const Scaler& scaler) const;
168
+
169
+ // implem 14 is mukltithreaded internally across nprobes and queries
170
+ template <class C, class Scaler>
171
+ void search_implem_14(
172
+ idx_t n,
173
+ const float* x,
174
+ idx_t k,
175
+ float* distances,
176
+ idx_t* labels,
177
+ int impl,
178
+ const Scaler& scaler) const;
179
+
180
+ // reconstruct vectors from packed invlists
181
+ void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
182
+ const override;
183
+
184
+ // reconstruct orig invlists (for debugging)
185
+ void reconstruct_orig_invlists();
186
+ };
187
+
188
+ struct IVFFastScanStats {
189
+ uint64_t times[10];
190
+ uint64_t t_compute_distance_tables, t_round;
191
+ uint64_t t_copy_pack, t_scan, t_to_flat;
192
+ uint64_t reservoir_times[4];
193
+ double t_aq_encode;
194
+ double t_aq_norm_encode;
195
+
196
+ double Mcy_at(int i) {
197
+ return times[i] / (1000 * 1000.0);
198
+ }
199
+
200
+ double Mcy_reservoir_at(int i) {
201
+ return reservoir_times[i] / (1000 * 1000.0);
202
+ }
203
+ IVFFastScanStats() {
204
+ reset();
205
+ }
206
+ void reset() {
207
+ memset(this, 0, sizeof(*this));
208
+ }
209
+ };
210
+
211
+ FAISS_API extern IVFFastScanStats IVFFastScan_stats;
212
+
213
+ } // namespace faiss
@@ -17,6 +17,8 @@
17
17
  #include <faiss/IndexFlat.h>
18
18
 
19
19
  #include <faiss/impl/AuxIndexStructures.h>
20
+ #include <faiss/impl/IDSelector.h>
21
+
20
22
  #include <faiss/impl/FaissAssert.h>
21
23
  #include <faiss/utils/distances.h>
22
24
  #include <faiss/utils/utils.h>
@@ -40,9 +42,7 @@ void IndexIVFFlat::add_core(
40
42
  idx_t n,
41
43
  const float* x,
42
44
  const int64_t* xids,
43
- const int64_t* coarse_idx)
44
-
45
- {
45
+ const int64_t* coarse_idx) {
46
46
  FAISS_THROW_IF_NOT(is_trained);
47
47
  FAISS_THROW_IF_NOT(coarse_idx);
48
48
  assert(invlists);
@@ -118,20 +118,18 @@ void IndexIVFFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
118
118
 
119
119
  namespace {
120
120
 
121
- template <MetricType metric, class C>
121
+ template <MetricType metric, class C, bool use_sel>
122
122
  struct IVFFlatScanner : InvertedListScanner {
123
123
  size_t d;
124
- bool store_pairs;
125
124
 
126
- IVFFlatScanner(size_t d, bool store_pairs)
127
- : d(d), store_pairs(store_pairs) {}
125
+ IVFFlatScanner(size_t d, bool store_pairs, const IDSelector* sel)
126
+ : InvertedListScanner(store_pairs, sel), d(d) {}
128
127
 
129
128
  const float* xi;
130
129
  void set_query(const float* query) override {
131
130
  this->xi = query;
132
131
  }
133
132
 
134
- idx_t list_no;
135
133
  void set_list(idx_t list_no, float /* coarse_dis */) override {
136
134
  this->list_no = list_no;
137
135
  }
@@ -155,6 +153,9 @@ struct IVFFlatScanner : InvertedListScanner {
155
153
  size_t nup = 0;
156
154
  for (size_t j = 0; j < list_size; j++) {
157
155
  const float* yj = list_vecs + d * j;
156
+ if (use_sel && !sel->is_member(ids[j])) {
157
+ continue;
158
+ }
158
159
  float dis = metric == METRIC_INNER_PRODUCT
159
160
  ? fvec_inner_product(xi, yj, d)
160
161
  : fvec_L2sqr(xi, yj, d);
@@ -176,6 +177,9 @@ struct IVFFlatScanner : InvertedListScanner {
176
177
  const float* list_vecs = (const float*)codes;
177
178
  for (size_t j = 0; j < list_size; j++) {
178
179
  const float* yj = list_vecs + d * j;
180
+ if (use_sel && !sel->is_member(ids[j])) {
181
+ continue;
182
+ }
179
183
  float dis = metric == METRIC_INNER_PRODUCT
180
184
  ? fvec_inner_product(xi, yj, d)
181
185
  : fvec_L2sqr(xi, yj, d);
@@ -187,20 +191,34 @@ struct IVFFlatScanner : InvertedListScanner {
187
191
  }
188
192
  };
189
193
 
194
+ template <bool use_sel>
195
+ InvertedListScanner* get_InvertedListScanner1(
196
+ const IndexIVFFlat* ivf,
197
+ bool store_pairs,
198
+ const IDSelector* sel) {
199
+ if (ivf->metric_type == METRIC_INNER_PRODUCT) {
200
+ return new IVFFlatScanner<
201
+ METRIC_INNER_PRODUCT,
202
+ CMin<float, int64_t>,
203
+ use_sel>(ivf->d, store_pairs, sel);
204
+ } else if (ivf->metric_type == METRIC_L2) {
205
+ return new IVFFlatScanner<METRIC_L2, CMax<float, int64_t>, use_sel>(
206
+ ivf->d, store_pairs, sel);
207
+ } else {
208
+ FAISS_THROW_MSG("metric type not supported");
209
+ }
210
+ }
211
+
190
212
  } // anonymous namespace
191
213
 
192
214
  InvertedListScanner* IndexIVFFlat::get_InvertedListScanner(
193
- bool store_pairs) const {
194
- if (metric_type == METRIC_INNER_PRODUCT) {
195
- return new IVFFlatScanner<METRIC_INNER_PRODUCT, CMin<float, int64_t>>(
196
- d, store_pairs);
197
- } else if (metric_type == METRIC_L2) {
198
- return new IVFFlatScanner<METRIC_L2, CMax<float, int64_t>>(
199
- d, store_pairs);
215
+ bool store_pairs,
216
+ const IDSelector* sel) const {
217
+ if (sel) {
218
+ return get_InvertedListScanner1<true>(this, store_pairs, sel);
200
219
  } else {
201
- FAISS_THROW_MSG("metric type not supported");
220
+ return get_InvertedListScanner1<false>(this, store_pairs, sel);
202
221
  }
203
- return nullptr;
204
222
  }
205
223
 
206
224
  void IndexIVFFlat::reconstruct_from_offset(
@@ -223,18 +241,17 @@ IndexIVFFlatDedup::IndexIVFFlatDedup(
223
241
 
224
242
  void IndexIVFFlatDedup::train(idx_t n, const float* x) {
225
243
  std::unordered_map<uint64_t, idx_t> map;
226
- float* x2 = new float[n * d];
227
- ScopeDeleter<float> del(x2);
244
+ std::unique_ptr<float[]> x2(new float[n * d]);
228
245
 
229
246
  int64_t n2 = 0;
230
247
  for (int64_t i = 0; i < n; i++) {
231
248
  uint64_t hash = hash_bytes((uint8_t*)(x + i * d), code_size);
232
249
  if (map.count(hash) &&
233
- !memcmp(x2 + map[hash] * d, x + i * d, code_size)) {
250
+ !memcmp(x2.get() + map[hash] * d, x + i * d, code_size)) {
234
251
  // is duplicate, skip
235
252
  } else {
236
253
  map[hash] = n2;
237
- memcpy(x2 + n2 * d, x + i * d, code_size);
254
+ memcpy(x2.get() + n2 * d, x + i * d, code_size);
238
255
  n2++;
239
256
  }
240
257
  }
@@ -245,7 +262,7 @@ void IndexIVFFlatDedup::train(idx_t n, const float* x) {
245
262
  n2,
246
263
  n);
247
264
  }
248
- IndexIVFFlat::train(n2, x2);
265
+ IndexIVFFlat::train(n2, x2.get());
249
266
  }
250
267
 
251
268
  void IndexIVFFlatDedup::add_with_ids(
@@ -256,9 +273,8 @@ void IndexIVFFlatDedup::add_with_ids(
256
273
  assert(invlists);
257
274
  FAISS_THROW_IF_NOT_MSG(
258
275
  direct_map.no(), "IVFFlatDedup not implemented with direct_map");
259
- int64_t* idx = new int64_t[na];
260
- ScopeDeleter<int64_t> del(idx);
261
- quantizer->assign(na, x, idx);
276
+ std::unique_ptr<int64_t[]> idx(new int64_t[na]);
277
+ quantizer->assign(na, x, idx.get());
262
278
 
263
279
  int64_t n_add = 0, n_dup = 0;
264
280
 
@@ -450,7 +466,8 @@ void IndexIVFFlatDedup::range_search(
450
466
  idx_t,
451
467
  const float*,
452
468
  float,
453
- RangeSearchResult*) const {
469
+ RangeSearchResult*,
470
+ const SearchParameters*) const {
454
471
  FAISS_THROW_MSG("not implemented");
455
472
  }
456
473
 
@@ -42,7 +42,8 @@ struct IndexIVFFlat : IndexIVF {
42
42
  bool include_listnos = false) const override;
43
43
 
44
44
  InvertedListScanner* get_InvertedListScanner(
45
- bool store_pairs) const override;
45
+ bool store_pairs,
46
+ const IDSelector* sel) const override;
46
47
 
47
48
  void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
48
49
  const override;
@@ -89,7 +90,8 @@ struct IndexIVFFlatDedup : IndexIVFFlat {
89
90
  idx_t n,
90
91
  const float* x,
91
92
  float radius,
92
- RangeSearchResult* result) const override;
93
+ RangeSearchResult* result,
94
+ const SearchParameters* params = nullptr) const override;
93
95
 
94
96
  /// not implemented
95
97
  void update_vectors(int nv, const idx_t* idx, const float* v) override;