faiss 0.2.4 → 0.2.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (177)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -0,0 +1,247 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/IndexIDMap.h>
11
+
12
+ #include <stdint.h>
13
+ #include <cinttypes>
14
+ #include <cstdio>
15
+ #include <limits>
16
+
17
+ #include <faiss/impl/AuxIndexStructures.h>
18
+ #include <faiss/impl/FaissAssert.h>
19
+ #include <faiss/impl/IDSelector.h>
20
+ #include <faiss/utils/Heap.h>
21
+ #include <faiss/utils/WorkerThread.h>
22
+
23
+ namespace faiss {
24
+
25
+ /*****************************************************
26
+ * IndexIDMap implementation
27
+ *******************************************************/
28
+
29
+ template <typename IndexT>
30
+ IndexIDMapTemplate<IndexT>::IndexIDMapTemplate(IndexT* index)
31
+ : index(index), own_fields(false) {
32
+ FAISS_THROW_IF_NOT_MSG(index->ntotal == 0, "index must be empty on input");
33
+ this->is_trained = index->is_trained;
34
+ this->metric_type = index->metric_type;
35
+ this->verbose = index->verbose;
36
+ this->d = index->d;
37
+ }
38
+
39
+ template <typename IndexT>
40
+ void IndexIDMapTemplate<IndexT>::add(
41
+ idx_t,
42
+ const typename IndexT::component_t*) {
43
+ FAISS_THROW_MSG(
44
+ "add does not make sense with IndexIDMap, "
45
+ "use add_with_ids");
46
+ }
47
+
48
+ template <typename IndexT>
49
+ void IndexIDMapTemplate<IndexT>::train(
50
+ idx_t n,
51
+ const typename IndexT::component_t* x) {
52
+ index->train(n, x);
53
+ this->is_trained = index->is_trained;
54
+ }
55
+
56
+ template <typename IndexT>
57
+ void IndexIDMapTemplate<IndexT>::reset() {
58
+ index->reset();
59
+ id_map.clear();
60
+ this->ntotal = 0;
61
+ }
62
+
63
+ template <typename IndexT>
64
+ void IndexIDMapTemplate<IndexT>::add_with_ids(
65
+ idx_t n,
66
+ const typename IndexT::component_t* x,
67
+ const typename IndexT::idx_t* xids) {
68
+ index->add(n, x);
69
+ for (idx_t i = 0; i < n; i++)
70
+ id_map.push_back(xids[i]);
71
+ this->ntotal = index->ntotal;
72
+ }
73
+
74
+ template <typename IndexT>
75
+ void IndexIDMapTemplate<IndexT>::search(
76
+ idx_t n,
77
+ const typename IndexT::component_t* x,
78
+ idx_t k,
79
+ typename IndexT::distance_t* distances,
80
+ typename IndexT::idx_t* labels,
81
+ const SearchParameters* params) const {
82
+ FAISS_THROW_IF_NOT_MSG(
83
+ !params, "search params not supported for this index");
84
+ index->search(n, x, k, distances, labels);
85
+ idx_t* li = labels;
86
+ #pragma omp parallel for
87
+ for (idx_t i = 0; i < n * k; i++) {
88
+ li[i] = li[i] < 0 ? li[i] : id_map[li[i]];
89
+ }
90
+ }
91
+
92
+ template <typename IndexT>
93
+ void IndexIDMapTemplate<IndexT>::range_search(
94
+ typename IndexT::idx_t n,
95
+ const typename IndexT::component_t* x,
96
+ typename IndexT::distance_t radius,
97
+ RangeSearchResult* result,
98
+ const SearchParameters* params) const {
99
+ FAISS_THROW_IF_NOT_MSG(
100
+ !params, "search params not supported for this index");
101
+ index->range_search(n, x, radius, result);
102
+ #pragma omp parallel for
103
+ for (idx_t i = 0; i < result->lims[result->nq]; i++) {
104
+ result->labels[i] = result->labels[i] < 0 ? result->labels[i]
105
+ : id_map[result->labels[i]];
106
+ }
107
+ }
108
+
109
+ namespace {
110
+
111
+ struct IDTranslatedSelector : IDSelector {
112
+ const std::vector<int64_t>& id_map;
113
+ const IDSelector& sel;
114
+ IDTranslatedSelector(
115
+ const std::vector<int64_t>& id_map,
116
+ const IDSelector& sel)
117
+ : id_map(id_map), sel(sel) {}
118
+ bool is_member(idx_t id) const override {
119
+ return sel.is_member(id_map[id]);
120
+ }
121
+ };
122
+
123
+ } // namespace
124
+
125
+ template <typename IndexT>
126
+ size_t IndexIDMapTemplate<IndexT>::remove_ids(const IDSelector& sel) {
127
+ // remove in sub-index first
128
+ IDTranslatedSelector sel2(id_map, sel);
129
+ size_t nremove = index->remove_ids(sel2);
130
+
131
+ int64_t j = 0;
132
+ for (idx_t i = 0; i < this->ntotal; i++) {
133
+ if (sel.is_member(id_map[i])) {
134
+ // remove
135
+ } else {
136
+ id_map[j] = id_map[i];
137
+ j++;
138
+ }
139
+ }
140
+ FAISS_ASSERT(j == index->ntotal);
141
+ this->ntotal = j;
142
+ id_map.resize(this->ntotal);
143
+ return nremove;
144
+ }
145
+
146
+ template <typename IndexT>
147
+ void IndexIDMapTemplate<IndexT>::check_compatible_for_merge(
148
+ const IndexT& otherIndex) const {
149
+ auto other = dynamic_cast<const IndexIDMapTemplate<IndexT>*>(&otherIndex);
150
+ FAISS_THROW_IF_NOT(other);
151
+ index->check_compatible_for_merge(*other->index);
152
+ }
153
+
154
+ template <typename IndexT>
155
+ void IndexIDMapTemplate<IndexT>::merge_from(IndexT& otherIndex, idx_t add_id) {
156
+ check_compatible_for_merge(otherIndex);
157
+ auto other = static_cast<IndexIDMapTemplate<IndexT>*>(&otherIndex);
158
+ index->merge_from(*other->index);
159
+ for (size_t i = 0; i < other->id_map.size(); i++) {
160
+ id_map.push_back(other->id_map[i] + add_id);
161
+ }
162
+ other->id_map.resize(0);
163
+ this->ntotal = index->ntotal;
164
+ other->ntotal = 0;
165
+ }
166
+
167
+ template <typename IndexT>
168
+ IndexIDMapTemplate<IndexT>::~IndexIDMapTemplate() {
169
+ if (own_fields)
170
+ delete index;
171
+ }
172
+
173
+ /*****************************************************
174
+ * IndexIDMap2 implementation
175
+ *******************************************************/
176
+
177
+ template <typename IndexT>
178
+ IndexIDMap2Template<IndexT>::IndexIDMap2Template(IndexT* index)
179
+ : IndexIDMapTemplate<IndexT>(index) {}
180
+
181
+ template <typename IndexT>
182
+ void IndexIDMap2Template<IndexT>::add_with_ids(
183
+ idx_t n,
184
+ const typename IndexT::component_t* x,
185
+ const typename IndexT::idx_t* xids) {
186
+ size_t prev_ntotal = this->ntotal;
187
+ IndexIDMapTemplate<IndexT>::add_with_ids(n, x, xids);
188
+ for (size_t i = prev_ntotal; i < this->ntotal; i++) {
189
+ rev_map[this->id_map[i]] = i;
190
+ }
191
+ }
192
+
193
+ template <typename IndexT>
194
+ void IndexIDMap2Template<IndexT>::check_consistency() const {
195
+ FAISS_THROW_IF_NOT(rev_map.size() == this->id_map.size());
196
+ FAISS_THROW_IF_NOT(this->id_map.size() == this->ntotal);
197
+ for (size_t i = 0; i < this->ntotal; i++) {
198
+ idx_t ii = rev_map.at(this->id_map[i]);
199
+ FAISS_THROW_IF_NOT(ii == i);
200
+ }
201
+ }
202
+
203
+ template <typename IndexT>
204
+ void IndexIDMap2Template<IndexT>::merge_from(IndexT& otherIndex, idx_t add_id) {
205
+ size_t prev_ntotal = this->ntotal;
206
+ IndexIDMapTemplate<IndexT>::merge_from(otherIndex, add_id);
207
+ for (size_t i = prev_ntotal; i < this->ntotal; i++) {
208
+ rev_map[this->id_map[i]] = i;
209
+ }
210
+ static_cast<IndexIDMap2Template<IndexT>&>(otherIndex).rev_map.clear();
211
+ }
212
+
213
+ template <typename IndexT>
214
+ void IndexIDMap2Template<IndexT>::construct_rev_map() {
215
+ rev_map.clear();
216
+ for (size_t i = 0; i < this->ntotal; i++) {
217
+ rev_map[this->id_map[i]] = i;
218
+ }
219
+ }
220
+
221
+ template <typename IndexT>
222
+ size_t IndexIDMap2Template<IndexT>::remove_ids(const IDSelector& sel) {
223
+ // This is quite inefficient
224
+ size_t nremove = IndexIDMapTemplate<IndexT>::remove_ids(sel);
225
+ construct_rev_map();
226
+ return nremove;
227
+ }
228
+
229
+ template <typename IndexT>
230
+ void IndexIDMap2Template<IndexT>::reconstruct(
231
+ idx_t key,
232
+ typename IndexT::component_t* recons) const {
233
+ try {
234
+ this->index->reconstruct(rev_map.at(key), recons);
235
+ } catch (const std::out_of_range& e) {
236
+ FAISS_THROW_FMT("key %" PRId64 " not found", key);
237
+ }
238
+ }
239
+
240
+ // explicit template instantiations
241
+
242
+ template struct IndexIDMapTemplate<Index>;
243
+ template struct IndexIDMapTemplate<IndexBinary>;
244
+ template struct IndexIDMap2Template<Index>;
245
+ template struct IndexIDMap2Template<IndexBinary>;
246
+
247
+ } // namespace faiss
@@ -0,0 +1,107 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/Index.h>
11
+ #include <faiss/IndexBinary.h>
12
+
13
+ #include <unordered_map>
14
+ #include <vector>
15
+
16
+ namespace faiss {
17
+
18
+ /** Index that translates search results to ids */
19
+ template <typename IndexT>
20
+ struct IndexIDMapTemplate : IndexT {
21
+ using idx_t = typename IndexT::idx_t;
22
+ using component_t = typename IndexT::component_t;
23
+ using distance_t = typename IndexT::distance_t;
24
+
25
+ IndexT* index; ///! the sub-index
26
+ bool own_fields; ///! whether pointers are deleted in destructo
27
+ std::vector<idx_t> id_map;
28
+
29
+ explicit IndexIDMapTemplate(IndexT* index);
30
+
31
+ /// @param xids if non-null, ids to store for the vectors (size n)
32
+ void add_with_ids(idx_t n, const component_t* x, const idx_t* xids)
33
+ override;
34
+
35
+ /// this will fail. Use add_with_ids
36
+ void add(idx_t n, const component_t* x) override;
37
+
38
+ void search(
39
+ idx_t n,
40
+ const component_t* x,
41
+ idx_t k,
42
+ distance_t* distances,
43
+ idx_t* labels,
44
+ const SearchParameters* params = nullptr) const override;
45
+
46
+ void train(idx_t n, const component_t* x) override;
47
+
48
+ void reset() override;
49
+
50
+ /// remove ids adapted to IndexFlat
51
+ size_t remove_ids(const IDSelector& sel) override;
52
+
53
+ void range_search(
54
+ idx_t n,
55
+ const component_t* x,
56
+ distance_t radius,
57
+ RangeSearchResult* result,
58
+ const SearchParameters* params = nullptr) const override;
59
+
60
+ void merge_from(IndexT& otherIndex, idx_t add_id = 0) override;
61
+ void check_compatible_for_merge(const IndexT& otherIndex) const override;
62
+
63
+ ~IndexIDMapTemplate() override;
64
+ IndexIDMapTemplate() {
65
+ own_fields = false;
66
+ index = nullptr;
67
+ }
68
+ };
69
+
70
+ using IndexIDMap = IndexIDMapTemplate<Index>;
71
+ using IndexBinaryIDMap = IndexIDMapTemplate<IndexBinary>;
72
+
73
+ /** same as IndexIDMap but also provides an efficient reconstruction
74
+ * implementation via a 2-way index */
75
+ template <typename IndexT>
76
+ struct IndexIDMap2Template : IndexIDMapTemplate<IndexT> {
77
+ using idx_t = typename IndexT::idx_t;
78
+ using component_t = typename IndexT::component_t;
79
+ using distance_t = typename IndexT::distance_t;
80
+
81
+ std::unordered_map<idx_t, idx_t> rev_map;
82
+
83
+ explicit IndexIDMap2Template(IndexT* index);
84
+
85
+ /// make the rev_map from scratch
86
+ void construct_rev_map();
87
+
88
+ void add_with_ids(idx_t n, const component_t* x, const idx_t* xids)
89
+ override;
90
+
91
+ size_t remove_ids(const IDSelector& sel) override;
92
+
93
+ void reconstruct(idx_t key, component_t* recons) const override;
94
+
95
+ /// check that the rev_map and the id_map are in sync
96
+ void check_consistency() const;
97
+
98
+ void merge_from(IndexT& otherIndex, idx_t add_id = 0) override;
99
+
100
+ ~IndexIDMap2Template() override {}
101
+ IndexIDMap2Template() {}
102
+ };
103
+
104
+ using IndexIDMap2 = IndexIDMap2Template<Index>;
105
+ using IndexBinaryIDMap2 = IndexIDMap2Template<IndexBinary>;
106
+
107
+ } // namespace faiss
@@ -23,6 +23,7 @@
23
23
  #include <faiss/IndexFlat.h>
24
24
  #include <faiss/impl/AuxIndexStructures.h>
25
25
  #include <faiss/impl/FaissAssert.h>
26
+ #include <faiss/impl/IDSelector.h>
26
27
 
27
28
  namespace faiss {
28
29
 
@@ -303,14 +304,20 @@ void IndexIVF::search(
303
304
  const float* x,
304
305
  idx_t k,
305
306
  float* distances,
306
- idx_t* labels) const {
307
+ idx_t* labels,
308
+ const SearchParameters* params_in) const {
307
309
  FAISS_THROW_IF_NOT(k > 0);
308
-
309
- const size_t nprobe = std::min(nlist, this->nprobe);
310
+ const IVFSearchParameters* params = nullptr;
311
+ if (params_in) {
312
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
313
+ FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
314
+ }
315
+ const size_t nprobe =
316
+ std::min(nlist, params ? params->nprobe : this->nprobe);
310
317
  FAISS_THROW_IF_NOT(nprobe > 0);
311
318
 
312
319
  // search function for a subset of queries
313
- auto sub_search_func = [this, k, nprobe](
320
+ auto sub_search_func = [this, k, nprobe, params](
314
321
  idx_t n,
315
322
  const float* x,
316
323
  float* distances,
@@ -320,7 +327,13 @@ void IndexIVF::search(
320
327
  std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
321
328
 
322
329
  double t0 = getmillisecs();
323
- quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
330
+ quantizer->search(
331
+ n,
332
+ x,
333
+ nprobe,
334
+ coarse_dis.get(),
335
+ idx.get(),
336
+ params ? params->quantizer_params : nullptr);
324
337
 
325
338
  double t1 = getmillisecs();
326
339
  invlists->prefetch_lists(idx.get(), n * nprobe);
@@ -334,7 +347,7 @@ void IndexIVF::search(
334
347
  distances,
335
348
  labels,
336
349
  false,
337
- nullptr,
350
+ params,
338
351
  ivf_stats);
339
352
  double t2 = getmillisecs();
340
353
  ivf_stats->quantization_time += t1 - t0;
@@ -400,6 +413,19 @@ void IndexIVF::search_preassigned(
400
413
  FAISS_THROW_IF_NOT(nprobe > 0);
401
414
 
402
415
  idx_t max_codes = params ? params->max_codes : this->max_codes;
416
+ IDSelector* sel = params ? params->sel : nullptr;
417
+ const IDSelectorRange* selr = dynamic_cast<const IDSelectorRange*>(sel);
418
+ if (selr) {
419
+ if (selr->assume_sorted) {
420
+ sel = nullptr; // use special IDSelectorRange processing
421
+ } else {
422
+ selr = nullptr; // use generic processing
423
+ }
424
+ }
425
+
426
+ FAISS_THROW_IF_NOT_MSG(
427
+ !(sel && store_pairs),
428
+ "selector and store_pairs cannot be combined");
403
429
 
404
430
  size_t nlistv = 0, ndis = 0, nheap = 0;
405
431
 
@@ -421,7 +447,8 @@ void IndexIVF::search_preassigned(
421
447
 
422
448
  #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
423
449
  {
424
- InvertedListScanner* scanner = get_InvertedListScanner(store_pairs);
450
+ InvertedListScanner* scanner =
451
+ get_InvertedListScanner(store_pairs, sel);
425
452
  ScopeDeleter1<InvertedListScanner> del(scanner);
426
453
 
427
454
  /*****************************************************
@@ -492,6 +519,7 @@ void IndexIVF::search_preassigned(
492
519
 
493
520
  try {
494
521
  InvertedLists::ScopedCodes scodes(invlists, key);
522
+ const uint8_t* codes = scodes.get();
495
523
 
496
524
  std::unique_ptr<InvertedLists::ScopedIds> sids;
497
525
  const Index::idx_t* ids = nullptr;
@@ -501,8 +529,20 @@ void IndexIVF::search_preassigned(
501
529
  ids = sids->get();
502
530
  }
503
531
 
532
+ if (selr) { // IDSelectorRange
533
+ // restrict search to a section of the inverted list
534
+ size_t jmin, jmax;
535
+ selr->find_sorted_ids_bounds(list_size, ids, &jmin, &jmax);
536
+ list_size = jmax - jmin;
537
+ if (list_size == 0) {
538
+ return (size_t)0;
539
+ }
540
+ codes += jmin * code_size;
541
+ ids += jmin;
542
+ }
543
+
504
544
  nheap += scanner->scan_codes(
505
- list_size, scodes.get(), ids, simi, idxi, k);
545
+ list_size, codes, ids, simi, idxi, k);
506
546
 
507
547
  } catch (const std::exception& e) {
508
548
  std::lock_guard<std::mutex> lock(exception_mutex);
@@ -651,13 +691,23 @@ void IndexIVF::range_search(
651
691
  idx_t nx,
652
692
  const float* x,
653
693
  float radius,
654
- RangeSearchResult* result) const {
655
- const size_t nprobe = std::min(nlist, this->nprobe);
694
+ RangeSearchResult* result,
695
+ const SearchParameters* params_in) const {
696
+ const IVFSearchParameters* params = nullptr;
697
+ const SearchParameters* quantizer_params = nullptr;
698
+ if (params_in) {
699
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
700
+ FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
701
+ quantizer_params = params->quantizer_params;
702
+ }
703
+ const size_t nprobe =
704
+ std::min(nlist, params ? params->nprobe : this->nprobe);
656
705
  std::unique_ptr<idx_t[]> keys(new idx_t[nx * nprobe]);
657
706
  std::unique_ptr<float[]> coarse_dis(new float[nx * nprobe]);
658
707
 
659
708
  double t0 = getmillisecs();
660
- quantizer->search(nx, x, nprobe, coarse_dis.get(), keys.get());
709
+ quantizer->search(
710
+ nx, x, nprobe, coarse_dis.get(), keys.get(), quantizer_params);
661
711
  indexIVF_stats.quantization_time += getmillisecs() - t0;
662
712
 
663
713
  t0 = getmillisecs();
@@ -671,7 +721,7 @@ void IndexIVF::range_search(
671
721
  coarse_dis.get(),
672
722
  result,
673
723
  false,
674
- nullptr,
724
+ params,
675
725
  &indexIVF_stats);
676
726
 
677
727
  indexIVF_stats.search_time += getmillisecs() - t0;
@@ -689,7 +739,10 @@ void IndexIVF::range_search_preassigned(
689
739
  IndexIVFStats* stats) const {
690
740
  idx_t nprobe = params ? params->nprobe : this->nprobe;
691
741
  nprobe = std::min((idx_t)nlist, nprobe);
742
+ FAISS_THROW_IF_NOT(nprobe > 0);
743
+
692
744
  idx_t max_codes = params ? params->max_codes : this->max_codes;
745
+ IDSelector* sel = params ? params->sel : nullptr;
693
746
 
694
747
  size_t nlistv = 0, ndis = 0;
695
748
 
@@ -711,7 +764,7 @@ void IndexIVF::range_search_preassigned(
711
764
  {
712
765
  RangeSearchPartialResult pres(result);
713
766
  std::unique_ptr<InvertedListScanner> scanner(
714
- get_InvertedListScanner(store_pairs));
767
+ get_InvertedListScanner(store_pairs, sel));
715
768
  FAISS_THROW_IF_NOT(scanner.get());
716
769
  all_pres[omp_get_thread_num()] = &pres;
717
770
 
@@ -774,7 +827,6 @@ void IndexIVF::range_search_preassigned(
774
827
  }
775
828
  }
776
829
  } else if (parallel_mode == 2) {
777
- std::vector<RangeQueryResult*> all_qres(nx);
778
830
  RangeQueryResult* qres = nullptr;
779
831
 
780
832
  #pragma omp for schedule(dynamic)
@@ -782,7 +834,6 @@ void IndexIVF::range_search_preassigned(
782
834
  idx_t i = iik / (idx_t)nprobe;
783
835
  idx_t ik = iik % (idx_t)nprobe;
784
836
  if (qres == nullptr || qres->qno != i) {
785
- FAISS_ASSERT(!qres || i > qres->qno);
786
837
  qres = &pres.new_result(i);
787
838
  scanner->set_query(x + i * d);
788
839
  }
@@ -818,7 +869,8 @@ void IndexIVF::range_search_preassigned(
818
869
  }
819
870
 
820
871
  InvertedListScanner* IndexIVF::get_InvertedListScanner(
821
- bool /*store_pairs*/) const {
872
+ bool /*store_pairs*/,
873
+ const IDSelector* /* sel */) const {
822
874
  return nullptr;
823
875
  }
824
876
 
@@ -846,6 +898,21 @@ void IndexIVF::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
846
898
  }
847
899
  }
848
900
 
901
+ bool IndexIVF::check_ids_sorted() const {
902
+ size_t nflip = 0;
903
+
904
+ for (size_t i = 0; i < nlist; i++) {
905
+ size_t list_size = invlists->list_size(i);
906
+ InvertedLists::ScopedIds ids(invlists, i);
907
+ for (size_t j = 0; j + 1 < list_size; j++) {
908
+ if (ids[j + 1] < ids[j]) {
909
+ nflip++;
910
+ }
911
+ }
912
+ }
913
+ return nflip == 0;
914
+ }
915
+
849
916
  /* standalone codec interface */
850
917
  size_t IndexIVF::sa_code_size() const {
851
918
  size_t coarse_size = coarse_code_size();
@@ -865,10 +932,15 @@ void IndexIVF::search_and_reconstruct(
865
932
  idx_t k,
866
933
  float* distances,
867
934
  idx_t* labels,
868
- float* recons) const {
869
- FAISS_THROW_IF_NOT(k > 0);
870
-
871
- const size_t nprobe = std::min(nlist, this->nprobe);
935
+ float* recons,
936
+ const SearchParameters* params_in) const {
937
+ const IVFSearchParameters* params = nullptr;
938
+ if (params_in) {
939
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
940
+ FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
941
+ }
942
+ const size_t nprobe =
943
+ std::min(nlist, params ? params->nprobe : this->nprobe);
872
944
  FAISS_THROW_IF_NOT(nprobe > 0);
873
945
 
874
946
  idx_t* idx = new idx_t[n * nprobe];
@@ -890,7 +962,8 @@ void IndexIVF::search_and_reconstruct(
890
962
  coarse_dis,
891
963
  distances,
892
964
  labels,
893
- true /* store_pairs */);
965
+ true /* store_pairs */,
966
+ params);
894
967
  for (idx_t i = 0; i < n; ++i) {
895
968
  for (idx_t j = 0; j < k; ++j) {
896
969
  idx_t ij = i * k + j;
@@ -976,26 +1049,41 @@ void IndexIVF::train_residual(idx_t /*n*/, const float* /*x*/) {
976
1049
  // does nothing by default
977
1050
  }
978
1051
 
979
- void IndexIVF::check_compatible_for_merge(const IndexIVF& other) const {
1052
+ bool check_compatible_for_merge_expensive_check = true;
1053
+
1054
+ void IndexIVF::check_compatible_for_merge(const Index& otherIndex) const {
980
1055
  // minimal sanity checks
981
- FAISS_THROW_IF_NOT(other.d == d);
982
- FAISS_THROW_IF_NOT(other.nlist == nlist);
983
- FAISS_THROW_IF_NOT(other.code_size == code_size);
1056
+ const IndexIVF* other = dynamic_cast<const IndexIVF*>(&otherIndex);
1057
+ FAISS_THROW_IF_NOT(other);
1058
+ FAISS_THROW_IF_NOT(other->d == d);
1059
+ FAISS_THROW_IF_NOT(other->nlist == nlist);
1060
+ FAISS_THROW_IF_NOT(quantizer->ntotal == other->quantizer->ntotal);
1061
+ FAISS_THROW_IF_NOT(other->code_size == code_size);
984
1062
  FAISS_THROW_IF_NOT_MSG(
985
- typeid(*this) == typeid(other),
1063
+ typeid(*this) == typeid(*other),
986
1064
  "can only merge indexes of the same type");
987
1065
  FAISS_THROW_IF_NOT_MSG(
988
- this->direct_map.no() && other.direct_map.no(),
1066
+ this->direct_map.no() && other->direct_map.no(),
989
1067
  "merge direct_map not implemented");
990
- }
991
1068
 
992
- void IndexIVF::merge_from(IndexIVF& other, idx_t add_id) {
993
- check_compatible_for_merge(other);
1069
+ if (check_compatible_for_merge_expensive_check) {
1070
+ std::vector<float> v(d), v2(d);
1071
+ for (size_t i = 0; i < nlist; i++) {
1072
+ quantizer->reconstruct(i, v.data());
1073
+ other->quantizer->reconstruct(i, v2.data());
1074
+ FAISS_THROW_IF_NOT_MSG(
1075
+ v == v2, "coarse quantizers should be the same");
1076
+ }
1077
+ }
1078
+ }
994
1079
 
995
- invlists->merge_from(other.invlists, add_id);
1080
+ void IndexIVF::merge_from(Index& otherIndex, idx_t add_id) {
1081
+ check_compatible_for_merge(otherIndex);
1082
+ IndexIVF* other = static_cast<IndexIVF*>(&otherIndex);
1083
+ invlists->merge_from(other->invlists, add_id);
996
1084
 
997
- ntotal += other.ntotal;
998
- other.ntotal = 0;
1085
+ ntotal += other->ntotal;
1086
+ other->ntotal = 0;
999
1087
  }
1000
1088
 
1001
1089
  void IndexIVF::replace_invlists(InvertedLists* il, bool own) {