faiss 0.2.7 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (172) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/lib/faiss.rb +1 -1
  11. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  12. data/vendor/faiss/faiss/AutoTune.h +0 -1
  13. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  14. data/vendor/faiss/faiss/Clustering.h +31 -21
  15. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  16. data/vendor/faiss/faiss/Index.cpp +1 -1
  17. data/vendor/faiss/faiss/Index.h +20 -5
  18. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  21. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  22. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  23. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  34. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  38. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  59. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  60. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  61. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  62. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  63. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  64. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  65. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  66. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  67. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  69. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  70. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  71. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  72. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  73. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  74. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  75. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  76. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  77. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  78. data/vendor/faiss/faiss/clone_index.h +3 -0
  79. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  80. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  81. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  82. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  90. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  92. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  93. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  97. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  98. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  99. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  101. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  103. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  104. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  105. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  106. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  107. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  108. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  109. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  110. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  111. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  113. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  119. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  125. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  126. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  127. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  128. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  129. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  133. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  135. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  136. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  137. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  138. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  139. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  140. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  141. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  142. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  143. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  144. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  145. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  146. data/vendor/faiss/faiss/utils/distances.h +81 -4
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  148. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  150. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  152. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  153. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  154. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  155. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  156. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  157. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  158. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  159. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  160. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  161. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  162. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  163. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  164. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  165. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  166. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  167. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  168. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  169. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  170. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  171. data/vendor/faiss/faiss/utils/utils.h +57 -20
  172. metadata +11 -4
@@ -21,49 +21,6 @@ namespace faiss {
21
21
 
22
22
  struct IndexHNSW;
23
23
 
24
- struct ReconstructFromNeighbors {
25
- typedef HNSW::storage_idx_t storage_idx_t;
26
-
27
- const IndexHNSW& index;
28
- size_t M; // number of neighbors
29
- size_t k; // number of codebook entries
30
- size_t nsq; // number of subvectors
31
- size_t code_size;
32
- int k_reorder; // nb to reorder. -1 = all
33
-
34
- std::vector<float> codebook; // size nsq * k * (M + 1)
35
-
36
- std::vector<uint8_t> codes; // size ntotal * code_size
37
- size_t ntotal;
38
- size_t d, dsub; // derived values
39
-
40
- explicit ReconstructFromNeighbors(
41
- const IndexHNSW& index,
42
- size_t k = 256,
43
- size_t nsq = 1);
44
-
45
- /// codes must be added in the correct order and the IndexHNSW
46
- /// must be populated and sorted
47
- void add_codes(size_t n, const float* x);
48
-
49
- size_t compute_distances(
50
- size_t n,
51
- const idx_t* shortlist,
52
- const float* query,
53
- float* distances) const;
54
-
55
- /// called by add_codes
56
- void estimate_code(const float* x, storage_idx_t i, uint8_t* code) const;
57
-
58
- /// called by compute_distances
59
- void reconstruct(storage_idx_t i, float* x, float* tmp) const;
60
-
61
- void reconstruct_n(storage_idx_t n0, storage_idx_t ni, float* x) const;
62
-
63
- /// get the M+1 -by-d table for neighbor coordinates for vector i
64
- void get_neighbor_table(storage_idx_t i, float* out) const;
65
- };
66
-
67
24
  /** The HNSW index is a normal random-access index with a HNSW
68
25
  * link structure built on top */
69
26
 
@@ -74,10 +31,8 @@ struct IndexHNSW : Index {
74
31
  HNSW hnsw;
75
32
 
76
33
  // the sequential storage
77
- bool own_fields;
78
- Index* storage;
79
-
80
- ReconstructFromNeighbors* reconstruct_from_neighbors;
34
+ bool own_fields = false;
35
+ Index* storage = nullptr;
81
36
 
82
37
  explicit IndexHNSW(int d = 0, int M = 32, MetricType metric = METRIC_L2);
83
38
  explicit IndexHNSW(Index* storage, int M = 32);
@@ -98,6 +53,13 @@ struct IndexHNSW : Index {
98
53
  idx_t* labels,
99
54
  const SearchParameters* params = nullptr) const override;
100
55
 
56
+ void range_search(
57
+ idx_t n,
58
+ const float* x,
59
+ float radius,
60
+ RangeSearchResult* result,
61
+ const SearchParameters* params = nullptr) const override;
62
+
101
63
  void reconstruct(idx_t key, float* recons) const override;
102
64
 
103
65
  void reset() override;
@@ -134,6 +96,8 @@ struct IndexHNSW : Index {
134
96
  void reorder_links();
135
97
 
136
98
  void link_singletons();
99
+
100
+ void permute_entries(const idx_t* perm);
137
101
  };
138
102
 
139
103
  /** Flat index topped with with a HNSW structure to access elements
@@ -150,7 +114,7 @@ struct IndexHNSWFlat : IndexHNSW {
150
114
  */
151
115
  struct IndexHNSWPQ : IndexHNSW {
152
116
  IndexHNSWPQ();
153
- IndexHNSWPQ(int d, int pq_m, int M);
117
+ IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits = 8);
154
118
  void train(idx_t n, const float* x) override;
155
119
  };
156
120
 
@@ -9,31 +9,43 @@
9
9
 
10
10
  #include <faiss/IndexIDMap.h>
11
11
 
12
- #include <stdint.h>
13
12
  #include <cinttypes>
13
+ #include <cstdint>
14
14
  #include <cstdio>
15
15
  #include <limits>
16
16
 
17
17
  #include <faiss/impl/AuxIndexStructures.h>
18
18
  #include <faiss/impl/FaissAssert.h>
19
- #include <faiss/impl/IDSelector.h>
20
19
  #include <faiss/utils/Heap.h>
21
20
  #include <faiss/utils/WorkerThread.h>
22
21
 
23
22
  namespace faiss {
24
23
 
24
+ namespace {
25
+
26
+ // IndexBinary needs to update the code_size when d is set...
27
+
28
+ void sync_d(Index* index) {}
29
+
30
+ void sync_d(IndexBinary* index) {
31
+ FAISS_THROW_IF_NOT(index->d % 8 == 0);
32
+ index->code_size = index->d / 8;
33
+ }
34
+
35
+ } // anonymous namespace
36
+
25
37
  /*****************************************************
26
38
  * IndexIDMap implementation
27
39
  *******************************************************/
28
40
 
29
41
  template <typename IndexT>
30
- IndexIDMapTemplate<IndexT>::IndexIDMapTemplate(IndexT* index)
31
- : index(index), own_fields(false) {
42
+ IndexIDMapTemplate<IndexT>::IndexIDMapTemplate(IndexT* index) : index(index) {
32
43
  FAISS_THROW_IF_NOT_MSG(index->ntotal == 0, "index must be empty on input");
33
44
  this->is_trained = index->is_trained;
34
45
  this->metric_type = index->metric_type;
35
46
  this->verbose = index->verbose;
36
47
  this->d = index->d;
48
+ sync_d(this);
37
49
  }
38
50
 
39
51
  template <typename IndexT>
@@ -71,6 +83,27 @@ void IndexIDMapTemplate<IndexT>::add_with_ids(
71
83
  this->ntotal = index->ntotal;
72
84
  }
73
85
 
86
+ namespace {
87
+
88
+ /// RAII object to reset the IDSelector in the params object
89
+ struct ScopedSelChange {
90
+ SearchParameters* params = nullptr;
91
+ IDSelector* old_sel = nullptr;
92
+
93
+ void set(SearchParameters* params_2, IDSelector* new_sel) {
94
+ this->params = params_2;
95
+ old_sel = params_2->sel;
96
+ params_2->sel = new_sel;
97
+ }
98
+ ~ScopedSelChange() {
99
+ if (params) {
100
+ params->sel = old_sel;
101
+ }
102
+ }
103
+ };
104
+
105
+ } // namespace
106
+
74
107
  template <typename IndexT>
75
108
  void IndexIDMapTemplate<IndexT>::search(
76
109
  idx_t n,
@@ -79,9 +112,26 @@ void IndexIDMapTemplate<IndexT>::search(
79
112
  typename IndexT::distance_t* distances,
80
113
  idx_t* labels,
81
114
  const SearchParameters* params) const {
82
- FAISS_THROW_IF_NOT_MSG(
83
- !params, "search params not supported for this index");
84
- index->search(n, x, k, distances, labels);
115
+ IDSelectorTranslated this_idtrans(this->id_map, nullptr);
116
+ ScopedSelChange sel_change;
117
+
118
+ if (params && params->sel) {
119
+ auto idtrans = dynamic_cast<const IDSelectorTranslated*>(params->sel);
120
+
121
+ if (!idtrans) {
122
+ /*
123
+ FAISS_THROW_IF_NOT_MSG(
124
+ idtrans,
125
+ "IndexIDMap requires an IDSelectorTranslated on input");
126
+ */
127
+ // then make an idtrans and force it into the SearchParameters
128
+ // (hence the const_cast)
129
+ auto params_non_const = const_cast<SearchParameters*>(params);
130
+ this_idtrans.sel = params->sel;
131
+ sel_change.set(params_non_const, &this_idtrans);
132
+ }
133
+ }
134
+ index->search(n, x, k, distances, labels, params);
85
135
  idx_t* li = labels;
86
136
  #pragma omp parallel for
87
137
  for (idx_t i = 0; i < n * k; i++) {
@@ -96,9 +146,16 @@ void IndexIDMapTemplate<IndexT>::range_search(
96
146
  typename IndexT::distance_t radius,
97
147
  RangeSearchResult* result,
98
148
  const SearchParameters* params) const {
99
- FAISS_THROW_IF_NOT_MSG(
100
- !params, "search params not supported for this index");
101
- index->range_search(n, x, radius, result);
149
+ if (params) {
150
+ SearchParameters internal_search_parameters;
151
+ IDSelectorTranslated id_selector_translated(id_map, params->sel);
152
+ internal_search_parameters.sel = &id_selector_translated;
153
+
154
+ index->range_search(n, x, radius, result, &internal_search_parameters);
155
+ } else {
156
+ index->range_search(n, x, radius, result);
157
+ }
158
+
102
159
  #pragma omp parallel for
103
160
  for (idx_t i = 0; i < result->lims[result->nq]; i++) {
104
161
  result->labels[i] = result->labels[i] < 0 ? result->labels[i]
@@ -106,26 +163,10 @@ void IndexIDMapTemplate<IndexT>::range_search(
106
163
  }
107
164
  }
108
165
 
109
- namespace {
110
-
111
- struct IDTranslatedSelector : IDSelector {
112
- const std::vector<int64_t>& id_map;
113
- const IDSelector& sel;
114
- IDTranslatedSelector(
115
- const std::vector<int64_t>& id_map,
116
- const IDSelector& sel)
117
- : id_map(id_map), sel(sel) {}
118
- bool is_member(idx_t id) const override {
119
- return sel.is_member(id_map[id]);
120
- }
121
- };
122
-
123
- } // namespace
124
-
125
166
  template <typename IndexT>
126
167
  size_t IndexIDMapTemplate<IndexT>::remove_ids(const IDSelector& sel) {
127
168
  // remove in sub-index first
128
- IDTranslatedSelector sel2(id_map, sel);
169
+ IDSelectorTranslated sel2(id_map, &sel);
129
170
  size_t nremove = index->remove_ids(sel2);
130
171
 
131
172
  int64_t j = 0;
@@ -232,7 +273,7 @@ void IndexIDMap2Template<IndexT>::reconstruct(
232
273
  typename IndexT::component_t* recons) const {
233
274
  try {
234
275
  this->index->reconstruct(rev_map.at(key), recons);
235
- } catch (const std::out_of_range& e) {
276
+ } catch (const std::out_of_range&) {
236
277
  FAISS_THROW_FMT("key %" PRId64 " not found", key);
237
278
  }
238
279
  }
@@ -9,6 +9,7 @@
9
9
 
10
10
  #include <faiss/Index.h>
11
11
  #include <faiss/IndexBinary.h>
12
+ #include <faiss/impl/IDSelector.h>
12
13
 
13
14
  #include <unordered_map>
14
15
  #include <vector>
@@ -21,8 +22,8 @@ struct IndexIDMapTemplate : IndexT {
21
22
  using component_t = typename IndexT::component_t;
22
23
  using distance_t = typename IndexT::distance_t;
23
24
 
24
- IndexT* index; ///! the sub-index
25
- bool own_fields; ///! whether pointers are deleted in destructo
25
+ IndexT* index = nullptr; ///! the sub-index
26
+ bool own_fields = false; ///! whether pointers are deleted in destructo
26
27
  std::vector<idx_t> id_map;
27
28
 
28
29
  explicit IndexIDMapTemplate(IndexT* index);
@@ -102,4 +103,25 @@ struct IndexIDMap2Template : IndexIDMapTemplate<IndexT> {
102
103
  using IndexIDMap2 = IndexIDMap2Template<Index>;
103
104
  using IndexBinaryIDMap2 = IndexIDMap2Template<IndexBinary>;
104
105
 
106
+ // IDSelector that translates the ids using an IDMap
107
+ struct IDSelectorTranslated : IDSelector {
108
+ const std::vector<int64_t>& id_map;
109
+ const IDSelector* sel;
110
+
111
+ IDSelectorTranslated(
112
+ const std::vector<int64_t>& id_map,
113
+ const IDSelector* sel)
114
+ : id_map(id_map), sel(sel) {}
115
+
116
+ IDSelectorTranslated(IndexBinaryIDMap& index_idmap, const IDSelector* sel)
117
+ : id_map(index_idmap.id_map), sel(sel) {}
118
+
119
+ IDSelectorTranslated(IndexIDMap& index_idmap, const IDSelector* sel)
120
+ : id_map(index_idmap.id_map), sel(sel) {}
121
+
122
+ bool is_member(idx_t id) const override {
123
+ return sel->is_member(id_map[id]);
124
+ }
125
+ };
126
+
105
127
  } // namespace faiss
@@ -11,6 +11,7 @@
11
11
 
12
12
  #include <omp.h>
13
13
  #include <cstdint>
14
+ #include <memory>
14
15
  #include <mutex>
15
16
 
16
17
  #include <algorithm>
@@ -45,7 +46,7 @@ Level1Quantizer::Level1Quantizer(Index* quantizer, size_t nlist)
45
46
  cp.niter = 10;
46
47
  }
47
48
 
48
- Level1Quantizer::Level1Quantizer() {}
49
+ Level1Quantizer::Level1Quantizer() = default;
49
50
 
50
51
  Level1Quantizer::~Level1Quantizer() {
51
52
  if (own_fields) {
@@ -172,7 +173,7 @@ IndexIVF::IndexIVF(
172
173
  }
173
174
  }
174
175
 
175
- IndexIVF::IndexIVF() {}
176
+ IndexIVF::IndexIVF() = default;
176
177
 
177
178
  void IndexIVF::add(idx_t n, const float* x) {
178
179
  add_with_ids(n, x, nullptr);
@@ -202,7 +203,8 @@ void IndexIVF::add_core(
202
203
  idx_t n,
203
204
  const float* x,
204
205
  const idx_t* xids,
205
- const idx_t* coarse_idx) {
206
+ const idx_t* coarse_idx,
207
+ void* inverted_list_context) {
206
208
  // do some blocking to avoid excessive allocs
207
209
  idx_t bs = 65536;
208
210
  if (n > bs) {
@@ -217,7 +219,8 @@ void IndexIVF::add_core(
217
219
  i1 - i0,
218
220
  x + i0 * d,
219
221
  xids ? xids + i0 : nullptr,
220
- coarse_idx + i0);
222
+ coarse_idx + i0,
223
+ inverted_list_context);
221
224
  }
222
225
  return;
223
226
  }
@@ -248,7 +251,10 @@ void IndexIVF::add_core(
248
251
  if (list_no >= 0 && list_no % nt == rank) {
249
252
  idx_t id = xids ? xids[i] : ntotal + i;
250
253
  size_t ofs = invlists->add_entry(
251
- list_no, id, flat_codes.get() + i * code_size);
254
+ list_no,
255
+ id,
256
+ flat_codes.get() + i * code_size,
257
+ inverted_list_context);
252
258
 
253
259
  dm_adder.add(i, list_no, ofs);
254
260
 
@@ -375,7 +381,7 @@ void IndexIVF::search(
375
381
  indexIVF_stats.add(stats[slice]);
376
382
  }
377
383
  } else {
378
- // handle paralellization at level below (or don't run in parallel at
384
+ // handle parallelization at level below (or don't run in parallel at
379
385
  // all)
380
386
  sub_search_func(n, x, distances, labels, &indexIVF_stats);
381
387
  }
@@ -444,11 +450,13 @@ void IndexIVF::search_preassigned(
444
450
  : pmode == 1 ? nprobe > 1
445
451
  : nprobe * n > 1);
446
452
 
453
+ void* inverted_list_context =
454
+ params ? params->inverted_list_context : nullptr;
455
+
447
456
  #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
448
457
  {
449
- InvertedListScanner* scanner =
450
- get_InvertedListScanner(store_pairs, sel);
451
- ScopeDeleter1<InvertedListScanner> del(scanner);
458
+ std::unique_ptr<InvertedListScanner> scanner(
459
+ get_InvertedListScanner(store_pairs, sel));
452
460
 
453
461
  /*****************************************************
454
462
  * Depending on parallel_mode, there are two possible ways
@@ -507,7 +515,7 @@ void IndexIVF::search_preassigned(
507
515
  nlist);
508
516
 
509
517
  // don't waste time on empty lists
510
- if (invlists->is_empty(key)) {
518
+ if (invlists->is_empty(key, inverted_list_context)) {
511
519
  return (size_t)0;
512
520
  }
513
521
 
@@ -520,7 +528,7 @@ void IndexIVF::search_preassigned(
520
528
  size_t list_size = 0;
521
529
 
522
530
  std::unique_ptr<InvertedListsIterator> it(
523
- invlists->get_iterator(key));
531
+ invlists->get_iterator(key, inverted_list_context));
524
532
 
525
533
  nheap += scanner->iterate_codes(
526
534
  it.get(), simi, idxi, k, list_size);
@@ -539,7 +547,8 @@ void IndexIVF::search_preassigned(
539
547
  const idx_t* ids = nullptr;
540
548
 
541
549
  if (!store_pairs) {
542
- sids.reset(new InvertedLists::ScopedIds(invlists, key));
550
+ sids = std::make_unique<InvertedLists::ScopedIds>(
551
+ invlists, key);
543
552
  ids = sids->get();
544
553
  }
545
554
 
@@ -659,7 +668,6 @@ void IndexIVF::search_preassigned(
659
668
  #pragma omp for schedule(dynamic)
660
669
  for (int64_t ij = 0; ij < n * nprobe; ij++) {
661
670
  size_t i = ij / nprobe;
662
- size_t j = ij % nprobe;
663
671
 
664
672
  scanner->set_query(x + i * d);
665
673
  init_result(local_dis.data(), local_idx.data());
@@ -696,12 +704,13 @@ void IndexIVF::search_preassigned(
696
704
  }
697
705
  }
698
706
 
699
- if (ivf_stats) {
700
- ivf_stats->nq += n;
701
- ivf_stats->nlist += nlistv;
702
- ivf_stats->ndis += ndis;
703
- ivf_stats->nheap_updates += nheap;
707
+ if (ivf_stats == nullptr) {
708
+ ivf_stats = &indexIVF_stats;
704
709
  }
710
+ ivf_stats->nq += n;
711
+ ivf_stats->nlist += nlistv;
712
+ ivf_stats->ndis += ndis;
713
+ ivf_stats->nheap_updates += nheap;
705
714
  }
706
715
 
707
716
  void IndexIVF::range_search(
@@ -781,6 +790,9 @@ void IndexIVF::range_search_preassigned(
781
790
  : pmode == 1 ? nprobe > 1
782
791
  : nprobe * nx > 1);
783
792
 
793
+ void* inverted_list_context =
794
+ params ? params->inverted_list_context : nullptr;
795
+
784
796
  #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis)
785
797
  {
786
798
  RangeSearchPartialResult pres(result);
@@ -802,7 +814,7 @@ void IndexIVF::range_search_preassigned(
802
814
  ik,
803
815
  nlist);
804
816
 
805
- if (invlists->is_empty(key)) {
817
+ if (invlists->is_empty(key, inverted_list_context)) {
806
818
  return;
807
819
  }
808
820
 
@@ -811,7 +823,7 @@ void IndexIVF::range_search_preassigned(
811
823
  scanner->set_list(key, coarse_dis[i * nprobe + ik]);
812
824
  if (invlists->use_iterator) {
813
825
  std::unique_ptr<InvertedListsIterator> it(
814
- invlists->get_iterator(key));
826
+ invlists->get_iterator(key, inverted_list_context));
815
827
 
816
828
  scanner->iterate_codes_range(
817
829
  it.get(), radius, qres, list_size);
@@ -891,17 +903,18 @@ void IndexIVF::range_search_preassigned(
891
903
  }
892
904
  }
893
905
 
894
- if (stats) {
895
- stats->nq += nx;
896
- stats->nlist += nlistv;
897
- stats->ndis += ndis;
906
+ if (stats == nullptr) {
907
+ stats = &indexIVF_stats;
898
908
  }
909
+ stats->nq += nx;
910
+ stats->nlist += nlistv;
911
+ stats->ndis += ndis;
899
912
  }
900
913
 
901
914
  InvertedListScanner* IndexIVF::get_InvertedListScanner(
902
915
  bool /*store_pairs*/,
903
916
  const IDSelector* /* sel */) const {
904
- return nullptr;
917
+ FAISS_THROW_MSG("get_InvertedListScanner not implemented");
905
918
  }
906
919
 
907
920
  void IndexIVF::reconstruct(idx_t key, float* recons) const {
@@ -973,14 +986,12 @@ void IndexIVF::search_and_reconstruct(
973
986
  std::min(nlist, params ? params->nprobe : this->nprobe);
974
987
  FAISS_THROW_IF_NOT(nprobe > 0);
975
988
 
976
- idx_t* idx = new idx_t[n * nprobe];
977
- ScopeDeleter<idx_t> del(idx);
978
- float* coarse_dis = new float[n * nprobe];
979
- ScopeDeleter<float> del2(coarse_dis);
989
+ std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
990
+ std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
980
991
 
981
- quantizer->search(n, x, nprobe, coarse_dis, idx);
992
+ quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
982
993
 
983
- invlists->prefetch_lists(idx, n * nprobe);
994
+ invlists->prefetch_lists(idx.get(), n * nprobe);
984
995
 
985
996
  // search_preassigned() with `store_pairs` enabled to obtain the list_no
986
997
  // and offset into `codes` for reconstruction
@@ -988,29 +999,94 @@ void IndexIVF::search_and_reconstruct(
988
999
  n,
989
1000
  x,
990
1001
  k,
991
- idx,
992
- coarse_dis,
1002
+ idx.get(),
1003
+ coarse_dis.get(),
993
1004
  distances,
994
1005
  labels,
995
1006
  true /* store_pairs */,
996
1007
  params);
997
- for (idx_t i = 0; i < n; ++i) {
998
- for (idx_t j = 0; j < k; ++j) {
999
- idx_t ij = i * k + j;
1000
- idx_t key = labels[ij];
1001
- float* reconstructed = recons + ij * d;
1002
- if (key < 0) {
1003
- // Fill with NaNs
1004
- memset(reconstructed, -1, sizeof(*reconstructed) * d);
1005
- } else {
1006
- int list_no = lo_listno(key);
1007
- int offset = lo_offset(key);
1008
+ #pragma omp parallel for if (n * k > 1000)
1009
+ for (idx_t ij = 0; ij < n * k; ij++) {
1010
+ idx_t key = labels[ij];
1011
+ float* reconstructed = recons + ij * d;
1012
+ if (key < 0) {
1013
+ // Fill with NaNs
1014
+ memset(reconstructed, -1, sizeof(*reconstructed) * d);
1015
+ } else {
1016
+ int list_no = lo_listno(key);
1017
+ int offset = lo_offset(key);
1018
+
1019
+ // Update label to the actual id
1020
+ labels[ij] = invlists->get_single_id(list_no, offset);
1021
+
1022
+ reconstruct_from_offset(list_no, offset, reconstructed);
1023
+ }
1024
+ }
1025
+ }
1026
+
1027
+ void IndexIVF::search_and_return_codes(
1028
+ idx_t n,
1029
+ const float* x,
1030
+ idx_t k,
1031
+ float* distances,
1032
+ idx_t* labels,
1033
+ uint8_t* codes,
1034
+ bool include_listno,
1035
+ const SearchParameters* params_in) const {
1036
+ const IVFSearchParameters* params = nullptr;
1037
+ if (params_in) {
1038
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
1039
+ FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
1040
+ }
1041
+ const size_t nprobe =
1042
+ std::min(nlist, params ? params->nprobe : this->nprobe);
1043
+ FAISS_THROW_IF_NOT(nprobe > 0);
1044
+
1045
+ std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
1046
+ std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
1047
+
1048
+ quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
1049
+
1050
+ invlists->prefetch_lists(idx.get(), n * nprobe);
1051
+
1052
+ // search_preassigned() with `store_pairs` enabled to obtain the list_no
1053
+ // and offset into `codes` for reconstruction
1054
+ search_preassigned(
1055
+ n,
1056
+ x,
1057
+ k,
1058
+ idx.get(),
1059
+ coarse_dis.get(),
1060
+ distances,
1061
+ labels,
1062
+ true /* store_pairs */,
1063
+ params);
1064
+
1065
+ size_t code_size_1 = code_size;
1066
+ if (include_listno) {
1067
+ code_size_1 += coarse_code_size();
1068
+ }
1008
1069
 
1009
- // Update label to the actual id
1010
- labels[ij] = invlists->get_single_id(list_no, offset);
1070
+ #pragma omp parallel for if (n * k > 1000)
1071
+ for (idx_t ij = 0; ij < n * k; ij++) {
1072
+ idx_t key = labels[ij];
1073
+ uint8_t* code1 = codes + ij * code_size_1;
1011
1074
 
1012
- reconstruct_from_offset(list_no, offset, reconstructed);
1075
+ if (key < 0) {
1076
+ // Fill with 0xff
1077
+ memset(code1, -1, code_size_1);
1078
+ } else {
1079
+ int list_no = lo_listno(key);
1080
+ int offset = lo_offset(key);
1081
+ const uint8_t* cc = invlists->get_single_code(list_no, offset);
1082
+
1083
+ labels[ij] = invlists->get_single_id(list_no, offset);
1084
+
1085
+ if (include_listno) {
1086
+ encode_listno(list_no, code1);
1087
+ code1 += code_size_1 - code_size;
1013
1088
  }
1089
+ memcpy(code1, cc, code_size);
1014
1090
  }
1015
1091
  }
1016
1092
  }
@@ -1061,22 +1137,52 @@ void IndexIVF::update_vectors(int n, const idx_t* new_ids, const float* x) {
1061
1137
  }
1062
1138
 
1063
1139
  void IndexIVF::train(idx_t n, const float* x) {
1064
- if (verbose)
1140
+ if (verbose) {
1065
1141
  printf("Training level-1 quantizer\n");
1142
+ }
1066
1143
 
1067
1144
  train_q1(n, x, verbose, metric_type);
1068
1145
 
1069
- if (verbose)
1146
+ if (verbose) {
1070
1147
  printf("Training IVF residual\n");
1148
+ }
1149
+
1150
+ // optional subsampling
1151
+ idx_t max_nt = train_encoder_num_vectors();
1152
+ if (max_nt <= 0) {
1153
+ max_nt = (size_t)1 << 35;
1154
+ }
1155
+
1156
+ TransformedVectors tv(
1157
+ x, fvecs_maybe_subsample(d, (size_t*)&n, max_nt, x, verbose));
1158
+
1159
+ if (by_residual) {
1160
+ std::vector<idx_t> assign(n);
1161
+ quantizer->assign(n, tv.x, assign.data());
1162
+
1163
+ std::vector<float> residuals(n * d);
1164
+ quantizer->compute_residual_n(n, tv.x, residuals.data(), assign.data());
1165
+
1166
+ train_encoder(n, residuals.data(), assign.data());
1167
+ } else {
1168
+ train_encoder(n, tv.x, nullptr);
1169
+ }
1071
1170
 
1072
- train_residual(n, x);
1073
1171
  is_trained = true;
1074
1172
  }
1075
1173
 
1076
- void IndexIVF::train_residual(idx_t /*n*/, const float* /*x*/) {
1077
- if (verbose)
1078
- printf("IndexIVF: no residual training\n");
1174
+ idx_t IndexIVF::train_encoder_num_vectors() const {
1175
+ return 0;
1176
+ }
1177
+
1178
+ void IndexIVF::train_encoder(
1179
+ idx_t /*n*/,
1180
+ const float* /*x*/,
1181
+ const idx_t* assign) {
1079
1182
  // does nothing by default
1183
+ if (verbose) {
1184
+ printf("IndexIVF: no residual training\n");
1185
+ }
1080
1186
  }
1081
1187
 
1082
1188
  bool check_compatible_for_merge_expensive_check = true;