faiss 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -21,49 +21,6 @@ namespace faiss {
21
21
 
22
22
  struct IndexHNSW;
23
23
 
24
- struct ReconstructFromNeighbors {
25
- typedef HNSW::storage_idx_t storage_idx_t;
26
-
27
- const IndexHNSW& index;
28
- size_t M; // number of neighbors
29
- size_t k; // number of codebook entries
30
- size_t nsq; // number of subvectors
31
- size_t code_size;
32
- int k_reorder; // nb to reorder. -1 = all
33
-
34
- std::vector<float> codebook; // size nsq * k * (M + 1)
35
-
36
- std::vector<uint8_t> codes; // size ntotal * code_size
37
- size_t ntotal;
38
- size_t d, dsub; // derived values
39
-
40
- explicit ReconstructFromNeighbors(
41
- const IndexHNSW& index,
42
- size_t k = 256,
43
- size_t nsq = 1);
44
-
45
- /// codes must be added in the correct order and the IndexHNSW
46
- /// must be populated and sorted
47
- void add_codes(size_t n, const float* x);
48
-
49
- size_t compute_distances(
50
- size_t n,
51
- const idx_t* shortlist,
52
- const float* query,
53
- float* distances) const;
54
-
55
- /// called by add_codes
56
- void estimate_code(const float* x, storage_idx_t i, uint8_t* code) const;
57
-
58
- /// called by compute_distances
59
- void reconstruct(storage_idx_t i, float* x, float* tmp) const;
60
-
61
- void reconstruct_n(storage_idx_t n0, storage_idx_t ni, float* x) const;
62
-
63
- /// get the M+1 -by-d table for neighbor coordinates for vector i
64
- void get_neighbor_table(storage_idx_t i, float* out) const;
65
- };
66
-
67
24
  /** The HNSW index is a normal random-access index with a HNSW
68
25
  * link structure built on top */
69
26
 
@@ -74,10 +31,8 @@ struct IndexHNSW : Index {
74
31
  HNSW hnsw;
75
32
 
76
33
  // the sequential storage
77
- bool own_fields;
78
- Index* storage;
79
-
80
- ReconstructFromNeighbors* reconstruct_from_neighbors;
34
+ bool own_fields = false;
35
+ Index* storage = nullptr;
81
36
 
82
37
  explicit IndexHNSW(int d = 0, int M = 32, MetricType metric = METRIC_L2);
83
38
  explicit IndexHNSW(Index* storage, int M = 32);
@@ -98,6 +53,13 @@ struct IndexHNSW : Index {
98
53
  idx_t* labels,
99
54
  const SearchParameters* params = nullptr) const override;
100
55
 
56
+ void range_search(
57
+ idx_t n,
58
+ const float* x,
59
+ float radius,
60
+ RangeSearchResult* result,
61
+ const SearchParameters* params = nullptr) const override;
62
+
101
63
  void reconstruct(idx_t key, float* recons) const override;
102
64
 
103
65
  void reset() override;
@@ -134,6 +96,8 @@ struct IndexHNSW : Index {
134
96
  void reorder_links();
135
97
 
136
98
  void link_singletons();
99
+
100
+ void permute_entries(const idx_t* perm);
137
101
  };
138
102
 
139
103
  /** Flat index topped with with a HNSW structure to access elements
@@ -150,7 +114,7 @@ struct IndexHNSWFlat : IndexHNSW {
150
114
  */
151
115
  struct IndexHNSWPQ : IndexHNSW {
152
116
  IndexHNSWPQ();
153
- IndexHNSWPQ(int d, int pq_m, int M);
117
+ IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits = 8);
154
118
  void train(idx_t n, const float* x) override;
155
119
  };
156
120
 
@@ -9,31 +9,43 @@
9
9
 
10
10
  #include <faiss/IndexIDMap.h>
11
11
 
12
- #include <stdint.h>
13
12
  #include <cinttypes>
13
+ #include <cstdint>
14
14
  #include <cstdio>
15
15
  #include <limits>
16
16
 
17
17
  #include <faiss/impl/AuxIndexStructures.h>
18
18
  #include <faiss/impl/FaissAssert.h>
19
- #include <faiss/impl/IDSelector.h>
20
19
  #include <faiss/utils/Heap.h>
21
20
  #include <faiss/utils/WorkerThread.h>
22
21
 
23
22
  namespace faiss {
24
23
 
24
+ namespace {
25
+
26
+ // IndexBinary needs to update the code_size when d is set...
27
+
28
+ void sync_d(Index* index) {}
29
+
30
+ void sync_d(IndexBinary* index) {
31
+ FAISS_THROW_IF_NOT(index->d % 8 == 0);
32
+ index->code_size = index->d / 8;
33
+ }
34
+
35
+ } // anonymous namespace
36
+
25
37
  /*****************************************************
26
38
  * IndexIDMap implementation
27
39
  *******************************************************/
28
40
 
29
41
  template <typename IndexT>
30
- IndexIDMapTemplate<IndexT>::IndexIDMapTemplate(IndexT* index)
31
- : index(index), own_fields(false) {
42
+ IndexIDMapTemplate<IndexT>::IndexIDMapTemplate(IndexT* index) : index(index) {
32
43
  FAISS_THROW_IF_NOT_MSG(index->ntotal == 0, "index must be empty on input");
33
44
  this->is_trained = index->is_trained;
34
45
  this->metric_type = index->metric_type;
35
46
  this->verbose = index->verbose;
36
47
  this->d = index->d;
48
+ sync_d(this);
37
49
  }
38
50
 
39
51
  template <typename IndexT>
@@ -71,6 +83,27 @@ void IndexIDMapTemplate<IndexT>::add_with_ids(
71
83
  this->ntotal = index->ntotal;
72
84
  }
73
85
 
86
+ namespace {
87
+
88
+ /// RAII object to reset the IDSelector in the params object
89
+ struct ScopedSelChange {
90
+ SearchParameters* params = nullptr;
91
+ IDSelector* old_sel = nullptr;
92
+
93
+ void set(SearchParameters* params_2, IDSelector* new_sel) {
94
+ this->params = params_2;
95
+ old_sel = params_2->sel;
96
+ params_2->sel = new_sel;
97
+ }
98
+ ~ScopedSelChange() {
99
+ if (params) {
100
+ params->sel = old_sel;
101
+ }
102
+ }
103
+ };
104
+
105
+ } // namespace
106
+
74
107
  template <typename IndexT>
75
108
  void IndexIDMapTemplate<IndexT>::search(
76
109
  idx_t n,
@@ -79,9 +112,26 @@ void IndexIDMapTemplate<IndexT>::search(
79
112
  typename IndexT::distance_t* distances,
80
113
  idx_t* labels,
81
114
  const SearchParameters* params) const {
82
- FAISS_THROW_IF_NOT_MSG(
83
- !params, "search params not supported for this index");
84
- index->search(n, x, k, distances, labels);
115
+ IDSelectorTranslated this_idtrans(this->id_map, nullptr);
116
+ ScopedSelChange sel_change;
117
+
118
+ if (params && params->sel) {
119
+ auto idtrans = dynamic_cast<const IDSelectorTranslated*>(params->sel);
120
+
121
+ if (!idtrans) {
122
+ /*
123
+ FAISS_THROW_IF_NOT_MSG(
124
+ idtrans,
125
+ "IndexIDMap requires an IDSelectorTranslated on input");
126
+ */
127
+ // then make an idtrans and force it into the SearchParameters
128
+ // (hence the const_cast)
129
+ auto params_non_const = const_cast<SearchParameters*>(params);
130
+ this_idtrans.sel = params->sel;
131
+ sel_change.set(params_non_const, &this_idtrans);
132
+ }
133
+ }
134
+ index->search(n, x, k, distances, labels, params);
85
135
  idx_t* li = labels;
86
136
  #pragma omp parallel for
87
137
  for (idx_t i = 0; i < n * k; i++) {
@@ -96,9 +146,16 @@ void IndexIDMapTemplate<IndexT>::range_search(
96
146
  typename IndexT::distance_t radius,
97
147
  RangeSearchResult* result,
98
148
  const SearchParameters* params) const {
99
- FAISS_THROW_IF_NOT_MSG(
100
- !params, "search params not supported for this index");
101
- index->range_search(n, x, radius, result);
149
+ if (params) {
150
+ SearchParameters internal_search_parameters;
151
+ IDSelectorTranslated id_selector_translated(id_map, params->sel);
152
+ internal_search_parameters.sel = &id_selector_translated;
153
+
154
+ index->range_search(n, x, radius, result, &internal_search_parameters);
155
+ } else {
156
+ index->range_search(n, x, radius, result);
157
+ }
158
+
102
159
  #pragma omp parallel for
103
160
  for (idx_t i = 0; i < result->lims[result->nq]; i++) {
104
161
  result->labels[i] = result->labels[i] < 0 ? result->labels[i]
@@ -106,26 +163,10 @@ void IndexIDMapTemplate<IndexT>::range_search(
106
163
  }
107
164
  }
108
165
 
109
- namespace {
110
-
111
- struct IDTranslatedSelector : IDSelector {
112
- const std::vector<int64_t>& id_map;
113
- const IDSelector& sel;
114
- IDTranslatedSelector(
115
- const std::vector<int64_t>& id_map,
116
- const IDSelector& sel)
117
- : id_map(id_map), sel(sel) {}
118
- bool is_member(idx_t id) const override {
119
- return sel.is_member(id_map[id]);
120
- }
121
- };
122
-
123
- } // namespace
124
-
125
166
  template <typename IndexT>
126
167
  size_t IndexIDMapTemplate<IndexT>::remove_ids(const IDSelector& sel) {
127
168
  // remove in sub-index first
128
- IDTranslatedSelector sel2(id_map, sel);
169
+ IDSelectorTranslated sel2(id_map, &sel);
129
170
  size_t nremove = index->remove_ids(sel2);
130
171
 
131
172
  int64_t j = 0;
@@ -232,7 +273,7 @@ void IndexIDMap2Template<IndexT>::reconstruct(
232
273
  typename IndexT::component_t* recons) const {
233
274
  try {
234
275
  this->index->reconstruct(rev_map.at(key), recons);
235
- } catch (const std::out_of_range& e) {
276
+ } catch (const std::out_of_range&) {
236
277
  FAISS_THROW_FMT("key %" PRId64 " not found", key);
237
278
  }
238
279
  }
@@ -9,6 +9,7 @@
9
9
 
10
10
  #include <faiss/Index.h>
11
11
  #include <faiss/IndexBinary.h>
12
+ #include <faiss/impl/IDSelector.h>
12
13
 
13
14
  #include <unordered_map>
14
15
  #include <vector>
@@ -21,8 +22,8 @@ struct IndexIDMapTemplate : IndexT {
21
22
  using component_t = typename IndexT::component_t;
22
23
  using distance_t = typename IndexT::distance_t;
23
24
 
24
- IndexT* index; ///! the sub-index
25
- bool own_fields; ///! whether pointers are deleted in destructo
25
+ IndexT* index = nullptr; ///! the sub-index
26
+ bool own_fields = false; ///! whether pointers are deleted in destructo
26
27
  std::vector<idx_t> id_map;
27
28
 
28
29
  explicit IndexIDMapTemplate(IndexT* index);
@@ -102,4 +103,25 @@ struct IndexIDMap2Template : IndexIDMapTemplate<IndexT> {
102
103
  using IndexIDMap2 = IndexIDMap2Template<Index>;
103
104
  using IndexBinaryIDMap2 = IndexIDMap2Template<IndexBinary>;
104
105
 
106
+ // IDSelector that translates the ids using an IDMap
107
+ struct IDSelectorTranslated : IDSelector {
108
+ const std::vector<int64_t>& id_map;
109
+ const IDSelector* sel;
110
+
111
+ IDSelectorTranslated(
112
+ const std::vector<int64_t>& id_map,
113
+ const IDSelector* sel)
114
+ : id_map(id_map), sel(sel) {}
115
+
116
+ IDSelectorTranslated(IndexBinaryIDMap& index_idmap, const IDSelector* sel)
117
+ : id_map(index_idmap.id_map), sel(sel) {}
118
+
119
+ IDSelectorTranslated(IndexIDMap& index_idmap, const IDSelector* sel)
120
+ : id_map(index_idmap.id_map), sel(sel) {}
121
+
122
+ bool is_member(idx_t id) const override {
123
+ return sel->is_member(id_map[id]);
124
+ }
125
+ };
126
+
105
127
  } // namespace faiss
@@ -11,6 +11,7 @@
11
11
 
12
12
  #include <omp.h>
13
13
  #include <cstdint>
14
+ #include <memory>
14
15
  #include <mutex>
15
16
 
16
17
  #include <algorithm>
@@ -45,7 +46,7 @@ Level1Quantizer::Level1Quantizer(Index* quantizer, size_t nlist)
45
46
  cp.niter = 10;
46
47
  }
47
48
 
48
- Level1Quantizer::Level1Quantizer() {}
49
+ Level1Quantizer::Level1Quantizer() = default;
49
50
 
50
51
  Level1Quantizer::~Level1Quantizer() {
51
52
  if (own_fields) {
@@ -172,7 +173,7 @@ IndexIVF::IndexIVF(
172
173
  }
173
174
  }
174
175
 
175
- IndexIVF::IndexIVF() {}
176
+ IndexIVF::IndexIVF() = default;
176
177
 
177
178
  void IndexIVF::add(idx_t n, const float* x) {
178
179
  add_with_ids(n, x, nullptr);
@@ -202,7 +203,8 @@ void IndexIVF::add_core(
202
203
  idx_t n,
203
204
  const float* x,
204
205
  const idx_t* xids,
205
- const idx_t* coarse_idx) {
206
+ const idx_t* coarse_idx,
207
+ void* inverted_list_context) {
206
208
  // do some blocking to avoid excessive allocs
207
209
  idx_t bs = 65536;
208
210
  if (n > bs) {
@@ -217,7 +219,8 @@ void IndexIVF::add_core(
217
219
  i1 - i0,
218
220
  x + i0 * d,
219
221
  xids ? xids + i0 : nullptr,
220
- coarse_idx + i0);
222
+ coarse_idx + i0,
223
+ inverted_list_context);
221
224
  }
222
225
  return;
223
226
  }
@@ -248,7 +251,10 @@ void IndexIVF::add_core(
248
251
  if (list_no >= 0 && list_no % nt == rank) {
249
252
  idx_t id = xids ? xids[i] : ntotal + i;
250
253
  size_t ofs = invlists->add_entry(
251
- list_no, id, flat_codes.get() + i * code_size);
254
+ list_no,
255
+ id,
256
+ flat_codes.get() + i * code_size,
257
+ inverted_list_context);
252
258
 
253
259
  dm_adder.add(i, list_no, ofs);
254
260
 
@@ -375,7 +381,7 @@ void IndexIVF::search(
375
381
  indexIVF_stats.add(stats[slice]);
376
382
  }
377
383
  } else {
378
- // handle paralellization at level below (or don't run in parallel at
384
+ // handle parallelization at level below (or don't run in parallel at
379
385
  // all)
380
386
  sub_search_func(n, x, distances, labels, &indexIVF_stats);
381
387
  }
@@ -444,11 +450,13 @@ void IndexIVF::search_preassigned(
444
450
  : pmode == 1 ? nprobe > 1
445
451
  : nprobe * n > 1);
446
452
 
453
+ void* inverted_list_context =
454
+ params ? params->inverted_list_context : nullptr;
455
+
447
456
  #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
448
457
  {
449
- InvertedListScanner* scanner =
450
- get_InvertedListScanner(store_pairs, sel);
451
- ScopeDeleter1<InvertedListScanner> del(scanner);
458
+ std::unique_ptr<InvertedListScanner> scanner(
459
+ get_InvertedListScanner(store_pairs, sel));
452
460
 
453
461
  /*****************************************************
454
462
  * Depending on parallel_mode, there are two possible ways
@@ -507,7 +515,7 @@ void IndexIVF::search_preassigned(
507
515
  nlist);
508
516
 
509
517
  // don't waste time on empty lists
510
- if (invlists->is_empty(key)) {
518
+ if (invlists->is_empty(key, inverted_list_context)) {
511
519
  return (size_t)0;
512
520
  }
513
521
 
@@ -520,7 +528,7 @@ void IndexIVF::search_preassigned(
520
528
  size_t list_size = 0;
521
529
 
522
530
  std::unique_ptr<InvertedListsIterator> it(
523
- invlists->get_iterator(key));
531
+ invlists->get_iterator(key, inverted_list_context));
524
532
 
525
533
  nheap += scanner->iterate_codes(
526
534
  it.get(), simi, idxi, k, list_size);
@@ -539,7 +547,8 @@ void IndexIVF::search_preassigned(
539
547
  const idx_t* ids = nullptr;
540
548
 
541
549
  if (!store_pairs) {
542
- sids.reset(new InvertedLists::ScopedIds(invlists, key));
550
+ sids = std::make_unique<InvertedLists::ScopedIds>(
551
+ invlists, key);
543
552
  ids = sids->get();
544
553
  }
545
554
 
@@ -659,7 +668,6 @@ void IndexIVF::search_preassigned(
659
668
  #pragma omp for schedule(dynamic)
660
669
  for (int64_t ij = 0; ij < n * nprobe; ij++) {
661
670
  size_t i = ij / nprobe;
662
- size_t j = ij % nprobe;
663
671
 
664
672
  scanner->set_query(x + i * d);
665
673
  init_result(local_dis.data(), local_idx.data());
@@ -696,12 +704,13 @@ void IndexIVF::search_preassigned(
696
704
  }
697
705
  }
698
706
 
699
- if (ivf_stats) {
700
- ivf_stats->nq += n;
701
- ivf_stats->nlist += nlistv;
702
- ivf_stats->ndis += ndis;
703
- ivf_stats->nheap_updates += nheap;
707
+ if (ivf_stats == nullptr) {
708
+ ivf_stats = &indexIVF_stats;
704
709
  }
710
+ ivf_stats->nq += n;
711
+ ivf_stats->nlist += nlistv;
712
+ ivf_stats->ndis += ndis;
713
+ ivf_stats->nheap_updates += nheap;
705
714
  }
706
715
 
707
716
  void IndexIVF::range_search(
@@ -781,6 +790,9 @@ void IndexIVF::range_search_preassigned(
781
790
  : pmode == 1 ? nprobe > 1
782
791
  : nprobe * nx > 1);
783
792
 
793
+ void* inverted_list_context =
794
+ params ? params->inverted_list_context : nullptr;
795
+
784
796
  #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis)
785
797
  {
786
798
  RangeSearchPartialResult pres(result);
@@ -802,7 +814,7 @@ void IndexIVF::range_search_preassigned(
802
814
  ik,
803
815
  nlist);
804
816
 
805
- if (invlists->is_empty(key)) {
817
+ if (invlists->is_empty(key, inverted_list_context)) {
806
818
  return;
807
819
  }
808
820
 
@@ -811,7 +823,7 @@ void IndexIVF::range_search_preassigned(
811
823
  scanner->set_list(key, coarse_dis[i * nprobe + ik]);
812
824
  if (invlists->use_iterator) {
813
825
  std::unique_ptr<InvertedListsIterator> it(
814
- invlists->get_iterator(key));
826
+ invlists->get_iterator(key, inverted_list_context));
815
827
 
816
828
  scanner->iterate_codes_range(
817
829
  it.get(), radius, qres, list_size);
@@ -891,17 +903,18 @@ void IndexIVF::range_search_preassigned(
891
903
  }
892
904
  }
893
905
 
894
- if (stats) {
895
- stats->nq += nx;
896
- stats->nlist += nlistv;
897
- stats->ndis += ndis;
906
+ if (stats == nullptr) {
907
+ stats = &indexIVF_stats;
898
908
  }
909
+ stats->nq += nx;
910
+ stats->nlist += nlistv;
911
+ stats->ndis += ndis;
899
912
  }
900
913
 
901
914
  InvertedListScanner* IndexIVF::get_InvertedListScanner(
902
915
  bool /*store_pairs*/,
903
916
  const IDSelector* /* sel */) const {
904
- return nullptr;
917
+ FAISS_THROW_MSG("get_InvertedListScanner not implemented");
905
918
  }
906
919
 
907
920
  void IndexIVF::reconstruct(idx_t key, float* recons) const {
@@ -973,14 +986,12 @@ void IndexIVF::search_and_reconstruct(
973
986
  std::min(nlist, params ? params->nprobe : this->nprobe);
974
987
  FAISS_THROW_IF_NOT(nprobe > 0);
975
988
 
976
- idx_t* idx = new idx_t[n * nprobe];
977
- ScopeDeleter<idx_t> del(idx);
978
- float* coarse_dis = new float[n * nprobe];
979
- ScopeDeleter<float> del2(coarse_dis);
989
+ std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
990
+ std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
980
991
 
981
- quantizer->search(n, x, nprobe, coarse_dis, idx);
992
+ quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
982
993
 
983
- invlists->prefetch_lists(idx, n * nprobe);
994
+ invlists->prefetch_lists(idx.get(), n * nprobe);
984
995
 
985
996
  // search_preassigned() with `store_pairs` enabled to obtain the list_no
986
997
  // and offset into `codes` for reconstruction
@@ -988,29 +999,94 @@ void IndexIVF::search_and_reconstruct(
988
999
  n,
989
1000
  x,
990
1001
  k,
991
- idx,
992
- coarse_dis,
1002
+ idx.get(),
1003
+ coarse_dis.get(),
993
1004
  distances,
994
1005
  labels,
995
1006
  true /* store_pairs */,
996
1007
  params);
997
- for (idx_t i = 0; i < n; ++i) {
998
- for (idx_t j = 0; j < k; ++j) {
999
- idx_t ij = i * k + j;
1000
- idx_t key = labels[ij];
1001
- float* reconstructed = recons + ij * d;
1002
- if (key < 0) {
1003
- // Fill with NaNs
1004
- memset(reconstructed, -1, sizeof(*reconstructed) * d);
1005
- } else {
1006
- int list_no = lo_listno(key);
1007
- int offset = lo_offset(key);
1008
+ #pragma omp parallel for if (n * k > 1000)
1009
+ for (idx_t ij = 0; ij < n * k; ij++) {
1010
+ idx_t key = labels[ij];
1011
+ float* reconstructed = recons + ij * d;
1012
+ if (key < 0) {
1013
+ // Fill with NaNs
1014
+ memset(reconstructed, -1, sizeof(*reconstructed) * d);
1015
+ } else {
1016
+ int list_no = lo_listno(key);
1017
+ int offset = lo_offset(key);
1018
+
1019
+ // Update label to the actual id
1020
+ labels[ij] = invlists->get_single_id(list_no, offset);
1021
+
1022
+ reconstruct_from_offset(list_no, offset, reconstructed);
1023
+ }
1024
+ }
1025
+ }
1026
+
1027
+ void IndexIVF::search_and_return_codes(
1028
+ idx_t n,
1029
+ const float* x,
1030
+ idx_t k,
1031
+ float* distances,
1032
+ idx_t* labels,
1033
+ uint8_t* codes,
1034
+ bool include_listno,
1035
+ const SearchParameters* params_in) const {
1036
+ const IVFSearchParameters* params = nullptr;
1037
+ if (params_in) {
1038
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
1039
+ FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
1040
+ }
1041
+ const size_t nprobe =
1042
+ std::min(nlist, params ? params->nprobe : this->nprobe);
1043
+ FAISS_THROW_IF_NOT(nprobe > 0);
1044
+
1045
+ std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
1046
+ std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
1047
+
1048
+ quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
1049
+
1050
+ invlists->prefetch_lists(idx.get(), n * nprobe);
1051
+
1052
+ // search_preassigned() with `store_pairs` enabled to obtain the list_no
1053
+ // and offset into `codes` for reconstruction
1054
+ search_preassigned(
1055
+ n,
1056
+ x,
1057
+ k,
1058
+ idx.get(),
1059
+ coarse_dis.get(),
1060
+ distances,
1061
+ labels,
1062
+ true /* store_pairs */,
1063
+ params);
1064
+
1065
+ size_t code_size_1 = code_size;
1066
+ if (include_listno) {
1067
+ code_size_1 += coarse_code_size();
1068
+ }
1008
1069
 
1009
- // Update label to the actual id
1010
- labels[ij] = invlists->get_single_id(list_no, offset);
1070
+ #pragma omp parallel for if (n * k > 1000)
1071
+ for (idx_t ij = 0; ij < n * k; ij++) {
1072
+ idx_t key = labels[ij];
1073
+ uint8_t* code1 = codes + ij * code_size_1;
1011
1074
 
1012
- reconstruct_from_offset(list_no, offset, reconstructed);
1075
+ if (key < 0) {
1076
+ // Fill with 0xff
1077
+ memset(code1, -1, code_size_1);
1078
+ } else {
1079
+ int list_no = lo_listno(key);
1080
+ int offset = lo_offset(key);
1081
+ const uint8_t* cc = invlists->get_single_code(list_no, offset);
1082
+
1083
+ labels[ij] = invlists->get_single_id(list_no, offset);
1084
+
1085
+ if (include_listno) {
1086
+ encode_listno(list_no, code1);
1087
+ code1 += code_size_1 - code_size;
1013
1088
  }
1089
+ memcpy(code1, cc, code_size);
1014
1090
  }
1015
1091
  }
1016
1092
  }
@@ -1061,22 +1137,52 @@ void IndexIVF::update_vectors(int n, const idx_t* new_ids, const float* x) {
1061
1137
  }
1062
1138
 
1063
1139
  void IndexIVF::train(idx_t n, const float* x) {
1064
- if (verbose)
1140
+ if (verbose) {
1065
1141
  printf("Training level-1 quantizer\n");
1142
+ }
1066
1143
 
1067
1144
  train_q1(n, x, verbose, metric_type);
1068
1145
 
1069
- if (verbose)
1146
+ if (verbose) {
1070
1147
  printf("Training IVF residual\n");
1148
+ }
1149
+
1150
+ // optional subsampling
1151
+ idx_t max_nt = train_encoder_num_vectors();
1152
+ if (max_nt <= 0) {
1153
+ max_nt = (size_t)1 << 35;
1154
+ }
1155
+
1156
+ TransformedVectors tv(
1157
+ x, fvecs_maybe_subsample(d, (size_t*)&n, max_nt, x, verbose));
1158
+
1159
+ if (by_residual) {
1160
+ std::vector<idx_t> assign(n);
1161
+ quantizer->assign(n, tv.x, assign.data());
1162
+
1163
+ std::vector<float> residuals(n * d);
1164
+ quantizer->compute_residual_n(n, tv.x, residuals.data(), assign.data());
1165
+
1166
+ train_encoder(n, residuals.data(), assign.data());
1167
+ } else {
1168
+ train_encoder(n, tv.x, nullptr);
1169
+ }
1071
1170
 
1072
- train_residual(n, x);
1073
1171
  is_trained = true;
1074
1172
  }
1075
1173
 
1076
- void IndexIVF::train_residual(idx_t /*n*/, const float* /*x*/) {
1077
- if (verbose)
1078
- printf("IndexIVF: no residual training\n");
1174
+ idx_t IndexIVF::train_encoder_num_vectors() const {
1175
+ return 0;
1176
+ }
1177
+
1178
+ void IndexIVF::train_encoder(
1179
+ idx_t /*n*/,
1180
+ const float* /*x*/,
1181
+ const idx_t* assign) {
1079
1182
  // does nothing by default
1183
+ if (verbose) {
1184
+ printf("IndexIVF: no residual training\n");
1185
+ }
1080
1186
  }
1081
1187
 
1082
1188
  bool check_compatible_for_merge_expensive_check = true;