faiss 0.2.4 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -0,0 +1,247 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/IndexIDMap.h>
11
+
12
+ #include <stdint.h>
13
+ #include <cinttypes>
14
+ #include <cstdio>
15
+ #include <limits>
16
+
17
+ #include <faiss/impl/AuxIndexStructures.h>
18
+ #include <faiss/impl/FaissAssert.h>
19
+ #include <faiss/impl/IDSelector.h>
20
+ #include <faiss/utils/Heap.h>
21
+ #include <faiss/utils/WorkerThread.h>
22
+
23
+ namespace faiss {
24
+
25
+ /*****************************************************
26
+ * IndexIDMap implementation
27
+ *******************************************************/
28
+
29
+ template <typename IndexT>
30
+ IndexIDMapTemplate<IndexT>::IndexIDMapTemplate(IndexT* index)
31
+ : index(index), own_fields(false) {
32
+ FAISS_THROW_IF_NOT_MSG(index->ntotal == 0, "index must be empty on input");
33
+ this->is_trained = index->is_trained;
34
+ this->metric_type = index->metric_type;
35
+ this->verbose = index->verbose;
36
+ this->d = index->d;
37
+ }
38
+
39
+ template <typename IndexT>
40
+ void IndexIDMapTemplate<IndexT>::add(
41
+ idx_t,
42
+ const typename IndexT::component_t*) {
43
+ FAISS_THROW_MSG(
44
+ "add does not make sense with IndexIDMap, "
45
+ "use add_with_ids");
46
+ }
47
+
48
+ template <typename IndexT>
49
+ void IndexIDMapTemplate<IndexT>::train(
50
+ idx_t n,
51
+ const typename IndexT::component_t* x) {
52
+ index->train(n, x);
53
+ this->is_trained = index->is_trained;
54
+ }
55
+
56
+ template <typename IndexT>
57
+ void IndexIDMapTemplate<IndexT>::reset() {
58
+ index->reset();
59
+ id_map.clear();
60
+ this->ntotal = 0;
61
+ }
62
+
63
+ template <typename IndexT>
64
+ void IndexIDMapTemplate<IndexT>::add_with_ids(
65
+ idx_t n,
66
+ const typename IndexT::component_t* x,
67
+ const typename IndexT::idx_t* xids) {
68
+ index->add(n, x);
69
+ for (idx_t i = 0; i < n; i++)
70
+ id_map.push_back(xids[i]);
71
+ this->ntotal = index->ntotal;
72
+ }
73
+
74
+ template <typename IndexT>
75
+ void IndexIDMapTemplate<IndexT>::search(
76
+ idx_t n,
77
+ const typename IndexT::component_t* x,
78
+ idx_t k,
79
+ typename IndexT::distance_t* distances,
80
+ typename IndexT::idx_t* labels,
81
+ const SearchParameters* params) const {
82
+ FAISS_THROW_IF_NOT_MSG(
83
+ !params, "search params not supported for this index");
84
+ index->search(n, x, k, distances, labels);
85
+ idx_t* li = labels;
86
+ #pragma omp parallel for
87
+ for (idx_t i = 0; i < n * k; i++) {
88
+ li[i] = li[i] < 0 ? li[i] : id_map[li[i]];
89
+ }
90
+ }
91
+
92
+ template <typename IndexT>
93
+ void IndexIDMapTemplate<IndexT>::range_search(
94
+ typename IndexT::idx_t n,
95
+ const typename IndexT::component_t* x,
96
+ typename IndexT::distance_t radius,
97
+ RangeSearchResult* result,
98
+ const SearchParameters* params) const {
99
+ FAISS_THROW_IF_NOT_MSG(
100
+ !params, "search params not supported for this index");
101
+ index->range_search(n, x, radius, result);
102
+ #pragma omp parallel for
103
+ for (idx_t i = 0; i < result->lims[result->nq]; i++) {
104
+ result->labels[i] = result->labels[i] < 0 ? result->labels[i]
105
+ : id_map[result->labels[i]];
106
+ }
107
+ }
108
+
109
+ namespace {
110
+
111
+ struct IDTranslatedSelector : IDSelector {
112
+ const std::vector<int64_t>& id_map;
113
+ const IDSelector& sel;
114
+ IDTranslatedSelector(
115
+ const std::vector<int64_t>& id_map,
116
+ const IDSelector& sel)
117
+ : id_map(id_map), sel(sel) {}
118
+ bool is_member(idx_t id) const override {
119
+ return sel.is_member(id_map[id]);
120
+ }
121
+ };
122
+
123
+ } // namespace
124
+
125
+ template <typename IndexT>
126
+ size_t IndexIDMapTemplate<IndexT>::remove_ids(const IDSelector& sel) {
127
+ // remove in sub-index first
128
+ IDTranslatedSelector sel2(id_map, sel);
129
+ size_t nremove = index->remove_ids(sel2);
130
+
131
+ int64_t j = 0;
132
+ for (idx_t i = 0; i < this->ntotal; i++) {
133
+ if (sel.is_member(id_map[i])) {
134
+ // remove
135
+ } else {
136
+ id_map[j] = id_map[i];
137
+ j++;
138
+ }
139
+ }
140
+ FAISS_ASSERT(j == index->ntotal);
141
+ this->ntotal = j;
142
+ id_map.resize(this->ntotal);
143
+ return nremove;
144
+ }
145
+
146
+ template <typename IndexT>
147
+ void IndexIDMapTemplate<IndexT>::check_compatible_for_merge(
148
+ const IndexT& otherIndex) const {
149
+ auto other = dynamic_cast<const IndexIDMapTemplate<IndexT>*>(&otherIndex);
150
+ FAISS_THROW_IF_NOT(other);
151
+ index->check_compatible_for_merge(*other->index);
152
+ }
153
+
154
+ template <typename IndexT>
155
+ void IndexIDMapTemplate<IndexT>::merge_from(IndexT& otherIndex, idx_t add_id) {
156
+ check_compatible_for_merge(otherIndex);
157
+ auto other = static_cast<IndexIDMapTemplate<IndexT>*>(&otherIndex);
158
+ index->merge_from(*other->index);
159
+ for (size_t i = 0; i < other->id_map.size(); i++) {
160
+ id_map.push_back(other->id_map[i] + add_id);
161
+ }
162
+ other->id_map.resize(0);
163
+ this->ntotal = index->ntotal;
164
+ other->ntotal = 0;
165
+ }
166
+
167
+ template <typename IndexT>
168
+ IndexIDMapTemplate<IndexT>::~IndexIDMapTemplate() {
169
+ if (own_fields)
170
+ delete index;
171
+ }
172
+
173
+ /*****************************************************
174
+ * IndexIDMap2 implementation
175
+ *******************************************************/
176
+
177
+ template <typename IndexT>
178
+ IndexIDMap2Template<IndexT>::IndexIDMap2Template(IndexT* index)
179
+ : IndexIDMapTemplate<IndexT>(index) {}
180
+
181
+ template <typename IndexT>
182
+ void IndexIDMap2Template<IndexT>::add_with_ids(
183
+ idx_t n,
184
+ const typename IndexT::component_t* x,
185
+ const typename IndexT::idx_t* xids) {
186
+ size_t prev_ntotal = this->ntotal;
187
+ IndexIDMapTemplate<IndexT>::add_with_ids(n, x, xids);
188
+ for (size_t i = prev_ntotal; i < this->ntotal; i++) {
189
+ rev_map[this->id_map[i]] = i;
190
+ }
191
+ }
192
+
193
+ template <typename IndexT>
194
+ void IndexIDMap2Template<IndexT>::check_consistency() const {
195
+ FAISS_THROW_IF_NOT(rev_map.size() == this->id_map.size());
196
+ FAISS_THROW_IF_NOT(this->id_map.size() == this->ntotal);
197
+ for (size_t i = 0; i < this->ntotal; i++) {
198
+ idx_t ii = rev_map.at(this->id_map[i]);
199
+ FAISS_THROW_IF_NOT(ii == i);
200
+ }
201
+ }
202
+
203
+ template <typename IndexT>
204
+ void IndexIDMap2Template<IndexT>::merge_from(IndexT& otherIndex, idx_t add_id) {
205
+ size_t prev_ntotal = this->ntotal;
206
+ IndexIDMapTemplate<IndexT>::merge_from(otherIndex, add_id);
207
+ for (size_t i = prev_ntotal; i < this->ntotal; i++) {
208
+ rev_map[this->id_map[i]] = i;
209
+ }
210
+ static_cast<IndexIDMap2Template<IndexT>&>(otherIndex).rev_map.clear();
211
+ }
212
+
213
+ template <typename IndexT>
214
+ void IndexIDMap2Template<IndexT>::construct_rev_map() {
215
+ rev_map.clear();
216
+ for (size_t i = 0; i < this->ntotal; i++) {
217
+ rev_map[this->id_map[i]] = i;
218
+ }
219
+ }
220
+
221
+ template <typename IndexT>
222
+ size_t IndexIDMap2Template<IndexT>::remove_ids(const IDSelector& sel) {
223
+ // This is quite inefficient
224
+ size_t nremove = IndexIDMapTemplate<IndexT>::remove_ids(sel);
225
+ construct_rev_map();
226
+ return nremove;
227
+ }
228
+
229
+ template <typename IndexT>
230
+ void IndexIDMap2Template<IndexT>::reconstruct(
231
+ idx_t key,
232
+ typename IndexT::component_t* recons) const {
233
+ try {
234
+ this->index->reconstruct(rev_map.at(key), recons);
235
+ } catch (const std::out_of_range& e) {
236
+ FAISS_THROW_FMT("key %" PRId64 " not found", key);
237
+ }
238
+ }
239
+
240
+ // explicit template instantiations
241
+
242
+ template struct IndexIDMapTemplate<Index>;
243
+ template struct IndexIDMapTemplate<IndexBinary>;
244
+ template struct IndexIDMap2Template<Index>;
245
+ template struct IndexIDMap2Template<IndexBinary>;
246
+
247
+ } // namespace faiss
@@ -0,0 +1,107 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/Index.h>
11
+ #include <faiss/IndexBinary.h>
12
+
13
+ #include <unordered_map>
14
+ #include <vector>
15
+
16
+ namespace faiss {
17
+
18
+ /** Index that translates search results to ids */
19
+ template <typename IndexT>
20
+ struct IndexIDMapTemplate : IndexT {
21
+ using idx_t = typename IndexT::idx_t;
22
+ using component_t = typename IndexT::component_t;
23
+ using distance_t = typename IndexT::distance_t;
24
+
25
+ IndexT* index; ///! the sub-index
26
+ bool own_fields; ///! whether pointers are deleted in destructo
27
+ std::vector<idx_t> id_map;
28
+
29
+ explicit IndexIDMapTemplate(IndexT* index);
30
+
31
+ /// @param xids if non-null, ids to store for the vectors (size n)
32
+ void add_with_ids(idx_t n, const component_t* x, const idx_t* xids)
33
+ override;
34
+
35
+ /// this will fail. Use add_with_ids
36
+ void add(idx_t n, const component_t* x) override;
37
+
38
+ void search(
39
+ idx_t n,
40
+ const component_t* x,
41
+ idx_t k,
42
+ distance_t* distances,
43
+ idx_t* labels,
44
+ const SearchParameters* params = nullptr) const override;
45
+
46
+ void train(idx_t n, const component_t* x) override;
47
+
48
+ void reset() override;
49
+
50
+ /// remove ids adapted to IndexFlat
51
+ size_t remove_ids(const IDSelector& sel) override;
52
+
53
+ void range_search(
54
+ idx_t n,
55
+ const component_t* x,
56
+ distance_t radius,
57
+ RangeSearchResult* result,
58
+ const SearchParameters* params = nullptr) const override;
59
+
60
+ void merge_from(IndexT& otherIndex, idx_t add_id = 0) override;
61
+ void check_compatible_for_merge(const IndexT& otherIndex) const override;
62
+
63
+ ~IndexIDMapTemplate() override;
64
+ IndexIDMapTemplate() {
65
+ own_fields = false;
66
+ index = nullptr;
67
+ }
68
+ };
69
+
70
+ using IndexIDMap = IndexIDMapTemplate<Index>;
71
+ using IndexBinaryIDMap = IndexIDMapTemplate<IndexBinary>;
72
+
73
+ /** same as IndexIDMap but also provides an efficient reconstruction
74
+ * implementation via a 2-way index */
75
+ template <typename IndexT>
76
+ struct IndexIDMap2Template : IndexIDMapTemplate<IndexT> {
77
+ using idx_t = typename IndexT::idx_t;
78
+ using component_t = typename IndexT::component_t;
79
+ using distance_t = typename IndexT::distance_t;
80
+
81
+ std::unordered_map<idx_t, idx_t> rev_map;
82
+
83
+ explicit IndexIDMap2Template(IndexT* index);
84
+
85
+ /// make the rev_map from scratch
86
+ void construct_rev_map();
87
+
88
+ void add_with_ids(idx_t n, const component_t* x, const idx_t* xids)
89
+ override;
90
+
91
+ size_t remove_ids(const IDSelector& sel) override;
92
+
93
+ void reconstruct(idx_t key, component_t* recons) const override;
94
+
95
+ /// check that the rev_map and the id_map are in sync
96
+ void check_consistency() const;
97
+
98
+ void merge_from(IndexT& otherIndex, idx_t add_id = 0) override;
99
+
100
+ ~IndexIDMap2Template() override {}
101
+ IndexIDMap2Template() {}
102
+ };
103
+
104
+ using IndexIDMap2 = IndexIDMap2Template<Index>;
105
+ using IndexBinaryIDMap2 = IndexIDMap2Template<IndexBinary>;
106
+
107
+ } // namespace faiss
@@ -23,6 +23,7 @@
23
23
  #include <faiss/IndexFlat.h>
24
24
  #include <faiss/impl/AuxIndexStructures.h>
25
25
  #include <faiss/impl/FaissAssert.h>
26
+ #include <faiss/impl/IDSelector.h>
26
27
 
27
28
  namespace faiss {
28
29
 
@@ -303,14 +304,20 @@ void IndexIVF::search(
303
304
  const float* x,
304
305
  idx_t k,
305
306
  float* distances,
306
- idx_t* labels) const {
307
+ idx_t* labels,
308
+ const SearchParameters* params_in) const {
307
309
  FAISS_THROW_IF_NOT(k > 0);
308
-
309
- const size_t nprobe = std::min(nlist, this->nprobe);
310
+ const IVFSearchParameters* params = nullptr;
311
+ if (params_in) {
312
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
313
+ FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
314
+ }
315
+ const size_t nprobe =
316
+ std::min(nlist, params ? params->nprobe : this->nprobe);
310
317
  FAISS_THROW_IF_NOT(nprobe > 0);
311
318
 
312
319
  // search function for a subset of queries
313
- auto sub_search_func = [this, k, nprobe](
320
+ auto sub_search_func = [this, k, nprobe, params](
314
321
  idx_t n,
315
322
  const float* x,
316
323
  float* distances,
@@ -320,7 +327,13 @@ void IndexIVF::search(
320
327
  std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
321
328
 
322
329
  double t0 = getmillisecs();
323
- quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
330
+ quantizer->search(
331
+ n,
332
+ x,
333
+ nprobe,
334
+ coarse_dis.get(),
335
+ idx.get(),
336
+ params ? params->quantizer_params : nullptr);
324
337
 
325
338
  double t1 = getmillisecs();
326
339
  invlists->prefetch_lists(idx.get(), n * nprobe);
@@ -334,7 +347,7 @@ void IndexIVF::search(
334
347
  distances,
335
348
  labels,
336
349
  false,
337
- nullptr,
350
+ params,
338
351
  ivf_stats);
339
352
  double t2 = getmillisecs();
340
353
  ivf_stats->quantization_time += t1 - t0;
@@ -400,6 +413,19 @@ void IndexIVF::search_preassigned(
400
413
  FAISS_THROW_IF_NOT(nprobe > 0);
401
414
 
402
415
  idx_t max_codes = params ? params->max_codes : this->max_codes;
416
+ IDSelector* sel = params ? params->sel : nullptr;
417
+ const IDSelectorRange* selr = dynamic_cast<const IDSelectorRange*>(sel);
418
+ if (selr) {
419
+ if (selr->assume_sorted) {
420
+ sel = nullptr; // use special IDSelectorRange processing
421
+ } else {
422
+ selr = nullptr; // use generic processing
423
+ }
424
+ }
425
+
426
+ FAISS_THROW_IF_NOT_MSG(
427
+ !(sel && store_pairs),
428
+ "selector and store_pairs cannot be combined");
403
429
 
404
430
  size_t nlistv = 0, ndis = 0, nheap = 0;
405
431
 
@@ -421,7 +447,8 @@ void IndexIVF::search_preassigned(
421
447
 
422
448
  #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
423
449
  {
424
- InvertedListScanner* scanner = get_InvertedListScanner(store_pairs);
450
+ InvertedListScanner* scanner =
451
+ get_InvertedListScanner(store_pairs, sel);
425
452
  ScopeDeleter1<InvertedListScanner> del(scanner);
426
453
 
427
454
  /*****************************************************
@@ -492,6 +519,7 @@ void IndexIVF::search_preassigned(
492
519
 
493
520
  try {
494
521
  InvertedLists::ScopedCodes scodes(invlists, key);
522
+ const uint8_t* codes = scodes.get();
495
523
 
496
524
  std::unique_ptr<InvertedLists::ScopedIds> sids;
497
525
  const Index::idx_t* ids = nullptr;
@@ -501,8 +529,20 @@ void IndexIVF::search_preassigned(
501
529
  ids = sids->get();
502
530
  }
503
531
 
532
+ if (selr) { // IDSelectorRange
533
+ // restrict search to a section of the inverted list
534
+ size_t jmin, jmax;
535
+ selr->find_sorted_ids_bounds(list_size, ids, &jmin, &jmax);
536
+ list_size = jmax - jmin;
537
+ if (list_size == 0) {
538
+ return (size_t)0;
539
+ }
540
+ codes += jmin * code_size;
541
+ ids += jmin;
542
+ }
543
+
504
544
  nheap += scanner->scan_codes(
505
- list_size, scodes.get(), ids, simi, idxi, k);
545
+ list_size, codes, ids, simi, idxi, k);
506
546
 
507
547
  } catch (const std::exception& e) {
508
548
  std::lock_guard<std::mutex> lock(exception_mutex);
@@ -651,13 +691,23 @@ void IndexIVF::range_search(
651
691
  idx_t nx,
652
692
  const float* x,
653
693
  float radius,
654
- RangeSearchResult* result) const {
655
- const size_t nprobe = std::min(nlist, this->nprobe);
694
+ RangeSearchResult* result,
695
+ const SearchParameters* params_in) const {
696
+ const IVFSearchParameters* params = nullptr;
697
+ const SearchParameters* quantizer_params = nullptr;
698
+ if (params_in) {
699
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
700
+ FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
701
+ quantizer_params = params->quantizer_params;
702
+ }
703
+ const size_t nprobe =
704
+ std::min(nlist, params ? params->nprobe : this->nprobe);
656
705
  std::unique_ptr<idx_t[]> keys(new idx_t[nx * nprobe]);
657
706
  std::unique_ptr<float[]> coarse_dis(new float[nx * nprobe]);
658
707
 
659
708
  double t0 = getmillisecs();
660
- quantizer->search(nx, x, nprobe, coarse_dis.get(), keys.get());
709
+ quantizer->search(
710
+ nx, x, nprobe, coarse_dis.get(), keys.get(), quantizer_params);
661
711
  indexIVF_stats.quantization_time += getmillisecs() - t0;
662
712
 
663
713
  t0 = getmillisecs();
@@ -671,7 +721,7 @@ void IndexIVF::range_search(
671
721
  coarse_dis.get(),
672
722
  result,
673
723
  false,
674
- nullptr,
724
+ params,
675
725
  &indexIVF_stats);
676
726
 
677
727
  indexIVF_stats.search_time += getmillisecs() - t0;
@@ -689,7 +739,10 @@ void IndexIVF::range_search_preassigned(
689
739
  IndexIVFStats* stats) const {
690
740
  idx_t nprobe = params ? params->nprobe : this->nprobe;
691
741
  nprobe = std::min((idx_t)nlist, nprobe);
742
+ FAISS_THROW_IF_NOT(nprobe > 0);
743
+
692
744
  idx_t max_codes = params ? params->max_codes : this->max_codes;
745
+ IDSelector* sel = params ? params->sel : nullptr;
693
746
 
694
747
  size_t nlistv = 0, ndis = 0;
695
748
 
@@ -711,7 +764,7 @@ void IndexIVF::range_search_preassigned(
711
764
  {
712
765
  RangeSearchPartialResult pres(result);
713
766
  std::unique_ptr<InvertedListScanner> scanner(
714
- get_InvertedListScanner(store_pairs));
767
+ get_InvertedListScanner(store_pairs, sel));
715
768
  FAISS_THROW_IF_NOT(scanner.get());
716
769
  all_pres[omp_get_thread_num()] = &pres;
717
770
 
@@ -774,7 +827,6 @@ void IndexIVF::range_search_preassigned(
774
827
  }
775
828
  }
776
829
  } else if (parallel_mode == 2) {
777
- std::vector<RangeQueryResult*> all_qres(nx);
778
830
  RangeQueryResult* qres = nullptr;
779
831
 
780
832
  #pragma omp for schedule(dynamic)
@@ -782,7 +834,6 @@ void IndexIVF::range_search_preassigned(
782
834
  idx_t i = iik / (idx_t)nprobe;
783
835
  idx_t ik = iik % (idx_t)nprobe;
784
836
  if (qres == nullptr || qres->qno != i) {
785
- FAISS_ASSERT(!qres || i > qres->qno);
786
837
  qres = &pres.new_result(i);
787
838
  scanner->set_query(x + i * d);
788
839
  }
@@ -818,7 +869,8 @@ void IndexIVF::range_search_preassigned(
818
869
  }
819
870
 
820
871
  InvertedListScanner* IndexIVF::get_InvertedListScanner(
821
- bool /*store_pairs*/) const {
872
+ bool /*store_pairs*/,
873
+ const IDSelector* /* sel */) const {
822
874
  return nullptr;
823
875
  }
824
876
 
@@ -846,6 +898,21 @@ void IndexIVF::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
846
898
  }
847
899
  }
848
900
 
901
+ bool IndexIVF::check_ids_sorted() const {
902
+ size_t nflip = 0;
903
+
904
+ for (size_t i = 0; i < nlist; i++) {
905
+ size_t list_size = invlists->list_size(i);
906
+ InvertedLists::ScopedIds ids(invlists, i);
907
+ for (size_t j = 0; j + 1 < list_size; j++) {
908
+ if (ids[j + 1] < ids[j]) {
909
+ nflip++;
910
+ }
911
+ }
912
+ }
913
+ return nflip == 0;
914
+ }
915
+
849
916
  /* standalone codec interface */
850
917
  size_t IndexIVF::sa_code_size() const {
851
918
  size_t coarse_size = coarse_code_size();
@@ -865,10 +932,15 @@ void IndexIVF::search_and_reconstruct(
865
932
  idx_t k,
866
933
  float* distances,
867
934
  idx_t* labels,
868
- float* recons) const {
869
- FAISS_THROW_IF_NOT(k > 0);
870
-
871
- const size_t nprobe = std::min(nlist, this->nprobe);
935
+ float* recons,
936
+ const SearchParameters* params_in) const {
937
+ const IVFSearchParameters* params = nullptr;
938
+ if (params_in) {
939
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
940
+ FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
941
+ }
942
+ const size_t nprobe =
943
+ std::min(nlist, params ? params->nprobe : this->nprobe);
872
944
  FAISS_THROW_IF_NOT(nprobe > 0);
873
945
 
874
946
  idx_t* idx = new idx_t[n * nprobe];
@@ -890,7 +962,8 @@ void IndexIVF::search_and_reconstruct(
890
962
  coarse_dis,
891
963
  distances,
892
964
  labels,
893
- true /* store_pairs */);
965
+ true /* store_pairs */,
966
+ params);
894
967
  for (idx_t i = 0; i < n; ++i) {
895
968
  for (idx_t j = 0; j < k; ++j) {
896
969
  idx_t ij = i * k + j;
@@ -976,26 +1049,41 @@ void IndexIVF::train_residual(idx_t /*n*/, const float* /*x*/) {
976
1049
  // does nothing by default
977
1050
  }
978
1051
 
979
- void IndexIVF::check_compatible_for_merge(const IndexIVF& other) const {
1052
+ bool check_compatible_for_merge_expensive_check = true;
1053
+
1054
+ void IndexIVF::check_compatible_for_merge(const Index& otherIndex) const {
980
1055
  // minimal sanity checks
981
- FAISS_THROW_IF_NOT(other.d == d);
982
- FAISS_THROW_IF_NOT(other.nlist == nlist);
983
- FAISS_THROW_IF_NOT(other.code_size == code_size);
1056
+ const IndexIVF* other = dynamic_cast<const IndexIVF*>(&otherIndex);
1057
+ FAISS_THROW_IF_NOT(other);
1058
+ FAISS_THROW_IF_NOT(other->d == d);
1059
+ FAISS_THROW_IF_NOT(other->nlist == nlist);
1060
+ FAISS_THROW_IF_NOT(quantizer->ntotal == other->quantizer->ntotal);
1061
+ FAISS_THROW_IF_NOT(other->code_size == code_size);
984
1062
  FAISS_THROW_IF_NOT_MSG(
985
- typeid(*this) == typeid(other),
1063
+ typeid(*this) == typeid(*other),
986
1064
  "can only merge indexes of the same type");
987
1065
  FAISS_THROW_IF_NOT_MSG(
988
- this->direct_map.no() && other.direct_map.no(),
1066
+ this->direct_map.no() && other->direct_map.no(),
989
1067
  "merge direct_map not implemented");
990
- }
991
1068
 
992
- void IndexIVF::merge_from(IndexIVF& other, idx_t add_id) {
993
- check_compatible_for_merge(other);
1069
+ if (check_compatible_for_merge_expensive_check) {
1070
+ std::vector<float> v(d), v2(d);
1071
+ for (size_t i = 0; i < nlist; i++) {
1072
+ quantizer->reconstruct(i, v.data());
1073
+ other->quantizer->reconstruct(i, v2.data());
1074
+ FAISS_THROW_IF_NOT_MSG(
1075
+ v == v2, "coarse quantizers should be the same");
1076
+ }
1077
+ }
1078
+ }
994
1079
 
995
- invlists->merge_from(other.invlists, add_id);
1080
+ void IndexIVF::merge_from(Index& otherIndex, idx_t add_id) {
1081
+ check_compatible_for_merge(otherIndex);
1082
+ IndexIVF* other = static_cast<IndexIVF*>(&otherIndex);
1083
+ invlists->merge_from(other->invlists, add_id);
996
1084
 
997
- ntotal += other.ntotal;
998
- other.ntotal = 0;
1085
+ ntotal += other->ntotal;
1086
+ other->ntotal = 0;
999
1087
  }
1000
1088
 
1001
1089
  void IndexIVF::replace_invlists(InvertedLists* il, bool own) {