faiss 0.3.0 → 0.3.1

This diff shows the changes between two publicly released versions of this package, as those versions appear in the public registry they were published to. It is provided for informational purposes only.
Files changed (171):
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -14,6 +14,9 @@
14
14
 
15
15
  namespace faiss {
16
16
 
17
+ struct NormTableScaler;
18
+ struct SIMDResultHandlerToFloat;
19
+
17
20
  /** Fast scan version of IVFPQ and IVFAQ. Works for 4-bit PQ/AQ for now.
18
21
  *
19
22
  * The codes in the inverted lists are not stored sequentially but
@@ -28,6 +31,12 @@ namespace faiss {
28
31
  * 11: idem, collect results in reservoir
29
32
  * 12: optimizer int16 search, collect results in heap, uses qbs
30
33
  * 13: idem, collect results in reservoir
34
+ * 14: internally multithreaded implem over nq * nprobe
35
+ * 15: same with reservoir
36
+ *
37
+ * For range search, only 10 and 12 are supported.
38
+ * add 100 to the implem to force single-thread scanning (the coarse quantizer
39
+ * may still use multiple threads).
31
40
  */
32
41
 
33
42
  struct IndexIVFFastScan : IndexIVF {
@@ -45,7 +54,6 @@ struct IndexIVFFastScan : IndexIVF {
45
54
  int implem = 0;
46
55
  // skip some parts of the computation (for timing)
47
56
  int skip = 0;
48
- bool by_residual = false;
49
57
 
50
58
  // batching factors at search time (0 = default)
51
59
  int qbs = 0;
@@ -81,19 +89,24 @@ struct IndexIVFFastScan : IndexIVF {
81
89
 
82
90
  virtual bool lookup_table_is_3d() const = 0;
83
91
 
92
+ // compact way of conveying coarse quantization results
93
+ struct CoarseQuantized {
94
+ size_t nprobe;
95
+ const float* dis = nullptr;
96
+ const idx_t* ids = nullptr;
97
+ };
98
+
84
99
  virtual void compute_LUT(
85
100
  size_t n,
86
101
  const float* x,
87
- const idx_t* coarse_ids,
88
- const float* coarse_dis,
102
+ const CoarseQuantized& cq,
89
103
  AlignedTable<float>& dis_tables,
90
104
  AlignedTable<float>& biases) const = 0;
91
105
 
92
106
  void compute_LUT_uint8(
93
107
  size_t n,
94
108
  const float* x,
95
- const idx_t* coarse_ids,
96
- const float* coarse_dis,
109
+ const CoarseQuantized& cq,
97
110
  AlignedTable<uint8_t>& dis_tables,
98
111
  AlignedTable<uint16_t>& biases,
99
112
  float* normalizers) const;
@@ -106,7 +119,18 @@ struct IndexIVFFastScan : IndexIVF {
106
119
  idx_t* labels,
107
120
  const SearchParameters* params = nullptr) const override;
108
121
 
109
- /// will just fail
122
+ void search_preassigned(
123
+ idx_t n,
124
+ const float* x,
125
+ idx_t k,
126
+ const idx_t* assign,
127
+ const float* centroid_dis,
128
+ float* distances,
129
+ idx_t* labels,
130
+ bool store_pairs,
131
+ const IVFSearchParameters* params = nullptr,
132
+ IndexIVFStats* stats = nullptr) const override;
133
+
110
134
  void range_search(
111
135
  idx_t n,
112
136
  const float* x,
@@ -116,69 +140,75 @@ struct IndexIVFFastScan : IndexIVF {
116
140
 
117
141
  // internal search funcs
118
142
 
119
- template <bool is_max, class Scaler>
143
+ // dispatch to implementations and parallelize
120
144
  void search_dispatch_implem(
121
145
  idx_t n,
122
146
  const float* x,
123
147
  idx_t k,
124
148
  float* distances,
125
149
  idx_t* labels,
126
- const Scaler& scaler) const;
150
+ const CoarseQuantized& cq,
151
+ const NormTableScaler* scaler) const;
127
152
 
128
- template <class C, class Scaler>
153
+ void range_search_dispatch_implem(
154
+ idx_t n,
155
+ const float* x,
156
+ float radius,
157
+ RangeSearchResult& rres,
158
+ const CoarseQuantized& cq_in,
159
+ const NormTableScaler* scaler) const;
160
+
161
+ // impl 1 and 2 are just for verification
162
+ template <class C>
129
163
  void search_implem_1(
130
164
  idx_t n,
131
165
  const float* x,
132
166
  idx_t k,
133
167
  float* distances,
134
168
  idx_t* labels,
135
- const Scaler& scaler) const;
169
+ const CoarseQuantized& cq,
170
+ const NormTableScaler* scaler) const;
136
171
 
137
- template <class C, class Scaler>
172
+ template <class C>
138
173
  void search_implem_2(
139
174
  idx_t n,
140
175
  const float* x,
141
176
  idx_t k,
142
177
  float* distances,
143
178
  idx_t* labels,
144
- const Scaler& scaler) const;
179
+ const CoarseQuantized& cq,
180
+ const NormTableScaler* scaler) const;
145
181
 
146
182
  // implem 10 and 12 are not multithreaded internally, so
147
183
  // export search stats
148
- template <class C, class Scaler>
149
184
  void search_implem_10(
150
185
  idx_t n,
151
186
  const float* x,
152
- idx_t k,
153
- float* distances,
154
- idx_t* labels,
155
- int impl,
187
+ SIMDResultHandlerToFloat& handler,
188
+ const CoarseQuantized& cq,
156
189
  size_t* ndis_out,
157
190
  size_t* nlist_out,
158
- const Scaler& scaler) const;
191
+ const NormTableScaler* scaler) const;
159
192
 
160
- template <class C, class Scaler>
161
193
  void search_implem_12(
162
194
  idx_t n,
163
195
  const float* x,
164
- idx_t k,
165
- float* distances,
166
- idx_t* labels,
167
- int impl,
196
+ SIMDResultHandlerToFloat& handler,
197
+ const CoarseQuantized& cq,
168
198
  size_t* ndis_out,
169
199
  size_t* nlist_out,
170
- const Scaler& scaler) const;
200
+ const NormTableScaler* scaler) const;
171
201
 
172
202
  // implem 14 is multithreaded internally across nprobes and queries
173
- template <class C, class Scaler>
174
203
  void search_implem_14(
175
204
  idx_t n,
176
205
  const float* x,
177
206
  idx_t k,
178
207
  float* distances,
179
208
  idx_t* labels,
209
+ const CoarseQuantized& cq,
180
210
  int impl,
181
- const Scaler& scaler) const;
211
+ const NormTableScaler* scaler) const;
182
212
 
183
213
  // reconstruct vectors from packed invlists
184
214
  void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
@@ -36,15 +36,22 @@ IndexIVFFlat::IndexIVFFlat(
36
36
  MetricType metric)
37
37
  : IndexIVF(quantizer, d, nlist, sizeof(float) * d, metric) {
38
38
  code_size = sizeof(float) * d;
39
+ by_residual = false;
40
+ }
41
+
42
+ IndexIVFFlat::IndexIVFFlat() {
43
+ by_residual = false;
39
44
  }
40
45
 
41
46
  void IndexIVFFlat::add_core(
42
47
  idx_t n,
43
48
  const float* x,
44
- const int64_t* xids,
45
- const int64_t* coarse_idx) {
49
+ const idx_t* xids,
50
+ const idx_t* coarse_idx,
51
+ void* inverted_list_context) {
46
52
  FAISS_THROW_IF_NOT(is_trained);
47
53
  FAISS_THROW_IF_NOT(coarse_idx);
54
+ FAISS_THROW_IF_NOT(!by_residual);
48
55
  assert(invlists);
49
56
  direct_map.check_can_add(xids);
50
57
 
@@ -64,8 +71,8 @@ void IndexIVFFlat::add_core(
64
71
  if (list_no >= 0 && list_no % nt == rank) {
65
72
  idx_t id = xids ? xids[i] : ntotal + i;
66
73
  const float* xi = x + i * d;
67
- size_t offset =
68
- invlists->add_entry(list_no, id, (const uint8_t*)xi);
74
+ size_t offset = invlists->add_entry(
75
+ list_no, id, (const uint8_t*)xi, inverted_list_context);
69
76
  dm_adder.add(i, list_no, offset);
70
77
  n_add++;
71
78
  } else if (rank == 0 && list_no == -1) {
@@ -89,6 +96,7 @@ void IndexIVFFlat::encode_vectors(
89
96
  const idx_t* list_nos,
90
97
  uint8_t* codes,
91
98
  bool include_listnos) const {
99
+ FAISS_THROW_IF_NOT(!by_residual);
92
100
  if (!include_listnos) {
93
101
  memcpy(codes, x, code_size * n);
94
102
  } else {
@@ -123,7 +131,9 @@ struct IVFFlatScanner : InvertedListScanner {
123
131
  size_t d;
124
132
 
125
133
  IVFFlatScanner(size_t d, bool store_pairs, const IDSelector* sel)
126
- : InvertedListScanner(store_pairs, sel), d(d) {}
134
+ : InvertedListScanner(store_pairs, sel), d(d) {
135
+ keep_max = is_similarity_metric(metric);
136
+ }
127
137
 
128
138
  const float* xi;
129
139
  void set_query(const float* query) override {
@@ -32,7 +32,8 @@ struct IndexIVFFlat : IndexIVF {
32
32
  idx_t n,
33
33
  const float* x,
34
34
  const idx_t* xids,
35
- const idx_t* precomputed_idx) override;
35
+ const idx_t* precomputed_idx,
36
+ void* inverted_list_context = nullptr) override;
36
37
 
37
38
  void encode_vectors(
38
39
  idx_t n,
@@ -50,7 +51,7 @@ struct IndexIVFFlat : IndexIVF {
50
51
 
51
52
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
52
53
 
53
- IndexIVFFlat() {}
54
+ IndexIVFFlat();
54
55
  };
55
56
 
56
57
  struct IndexIVFFlatDedup : IndexIVFFlat {
@@ -0,0 +1,172 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/IndexIVFIndependentQuantizer.h>
9
+ #include <faiss/IndexIVFPQ.h>
10
+ #include <faiss/impl/FaissAssert.h>
11
+ #include <faiss/utils/utils.h>
12
+
13
+ namespace faiss {
14
+
15
+ IndexIVFIndependentQuantizer::IndexIVFIndependentQuantizer(
16
+ Index* quantizer,
17
+ IndexIVF* index_ivf,
18
+ VectorTransform* vt)
19
+ : Index(quantizer->d, index_ivf->metric_type),
20
+ quantizer(quantizer),
21
+ vt(vt),
22
+ index_ivf(index_ivf) {
23
+ if (vt) {
24
+ FAISS_THROW_IF_NOT_MSG(
25
+ vt->d_in == d && vt->d_out == index_ivf->d,
26
+ "invalid vector dimensions");
27
+ } else {
28
+ FAISS_THROW_IF_NOT_MSG(index_ivf->d == d, "invalid vector dimensions");
29
+ }
30
+
31
+ if (quantizer->is_trained && quantizer->ntotal != 0) {
32
+ FAISS_THROW_IF_NOT(quantizer->ntotal == index_ivf->nlist);
33
+ }
34
+ if (index_ivf->is_trained && vt) {
35
+ FAISS_THROW_IF_NOT(vt->is_trained);
36
+ }
37
+ ntotal = index_ivf->ntotal;
38
+ is_trained =
39
+ (quantizer->is_trained && quantizer->ntotal == index_ivf->nlist &&
40
+ (!vt || vt->is_trained) && index_ivf->is_trained);
41
+
42
+ // disable precomputed tables because they use the distances that are
43
+ // provided by the coarse quantizer (that are out of sync with the IVFPQ)
44
+ if (auto index_ivfpq = dynamic_cast<IndexIVFPQ*>(index_ivf)) {
45
+ index_ivfpq->use_precomputed_table = -1;
46
+ }
47
+ }
48
+
49
+ IndexIVFIndependentQuantizer::~IndexIVFIndependentQuantizer() {
50
+ if (own_fields) {
51
+ delete quantizer;
52
+ delete index_ivf;
53
+ delete vt;
54
+ }
55
+ }
56
+
57
+ namespace {
58
+
59
+ struct VTransformedVectors : TransformedVectors {
60
+ VTransformedVectors(const VectorTransform* vt, idx_t n, const float* x)
61
+ : TransformedVectors(x, vt ? vt->apply(n, x) : x) {}
62
+ };
63
+
64
+ struct SubsampledVectors : TransformedVectors {
65
+ SubsampledVectors(int d, idx_t* n, idx_t max_n, const float* x)
66
+ : TransformedVectors(
67
+ x,
68
+ fvecs_maybe_subsample(d, (size_t*)n, max_n, x, true)) {}
69
+ };
70
+
71
+ } // anonymous namespace
72
+
73
+ void IndexIVFIndependentQuantizer::add(idx_t n, const float* x) {
74
+ std::vector<float> D(n);
75
+ std::vector<idx_t> I(n);
76
+ quantizer->search(n, x, 1, D.data(), I.data());
77
+
78
+ VTransformedVectors tv(vt, n, x);
79
+
80
+ index_ivf->add_core(n, tv.x, nullptr, I.data());
81
+ }
82
+
83
+ void IndexIVFIndependentQuantizer::search(
84
+ idx_t n,
85
+ const float* x,
86
+ idx_t k,
87
+ float* distances,
88
+ idx_t* labels,
89
+ const SearchParameters* params) const {
90
+ FAISS_THROW_IF_NOT_MSG(!params, "search parameters not supported");
91
+ int nprobe = index_ivf->nprobe;
92
+ std::vector<float> D(n * nprobe);
93
+ std::vector<idx_t> I(n * nprobe);
94
+ quantizer->search(n, x, nprobe, D.data(), I.data());
95
+
96
+ VTransformedVectors tv(vt, n, x);
97
+
98
+ index_ivf->search_preassigned(
99
+ n, tv.x, k, I.data(), D.data(), distances, labels, false);
100
+ }
101
+
102
+ void IndexIVFIndependentQuantizer::reset() {
103
+ index_ivf->reset();
104
+ ntotal = 0;
105
+ }
106
+
107
+ void IndexIVFIndependentQuantizer::train(idx_t n, const float* x) {
108
+ // quantizer training
109
+ size_t nlist = index_ivf->nlist;
110
+ Level1Quantizer l1(quantizer, nlist);
111
+ l1.train_q1(n, x, verbose, metric_type);
112
+
113
+ // train the VectorTransform
114
+ if (vt && !vt->is_trained) {
115
+ if (verbose) {
116
+ printf("IndexIVFIndependentQuantizer: train the VectorTransform\n");
117
+ }
118
+ vt->train(n, x);
119
+ }
120
+
121
+ // get the centroids from the quantizer, transform them and
122
+ // add them to the index_ivf's quantizer
123
+ if (verbose) {
124
+ printf("IndexIVFIndependentQuantizer: extract the main quantizer centroids\n");
125
+ }
126
+ std::vector<float> centroids(nlist * d);
127
+ quantizer->reconstruct_n(0, nlist, centroids.data());
128
+ VTransformedVectors tcent(vt, nlist, centroids.data());
129
+
130
+ if (verbose) {
131
+ printf("IndexIVFIndependentQuantizer: add centroids to the secondary quantizer\n");
132
+ }
133
+ if (!index_ivf->quantizer->is_trained) {
134
+ index_ivf->quantizer->train(nlist, tcent.x);
135
+ }
136
+ index_ivf->quantizer->add(nlist, tcent.x);
137
+
138
+ // train the payload
139
+
140
+ // optional subsampling
141
+ idx_t max_nt = index_ivf->train_encoder_num_vectors();
142
+ if (max_nt <= 0) {
143
+ max_nt = (size_t)1 << 35;
144
+ }
145
+ SubsampledVectors sv(index_ivf->d, &n, max_nt, x);
146
+
147
+ // transform subsampled vectors
148
+ VTransformedVectors tv(vt, n, sv.x);
149
+
150
+ if (verbose) {
151
+ printf("IndexIVFIndependentQuantizer: train encoder\n");
152
+ }
153
+
154
+ if (index_ivf->by_residual) {
155
+ // assign with quantizer
156
+ std::vector<idx_t> assign(n);
157
+ quantizer->assign(n, sv.x, assign.data());
158
+
159
+ // compute residual with IVF quantizer
160
+ std::vector<float> residuals(n * index_ivf->d);
161
+ index_ivf->quantizer->compute_residual_n(
162
+ n, tv.x, residuals.data(), assign.data());
163
+
164
+ index_ivf->train_encoder(n, residuals.data(), assign.data());
165
+ } else {
166
+ index_ivf->train_encoder(n, tv.x, nullptr);
167
+ }
168
+ index_ivf->is_trained = true;
169
+ is_trained = true;
170
+ }
171
+
172
+ } // namespace faiss
@@ -0,0 +1,56 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/IndexIVF.h>
11
+ #include <faiss/VectorTransform.h>
12
+
13
+ namespace faiss {
14
+
15
+ /** An IVF index with a quantizer that has a different input dimension from the
16
+ * payload size. The vectors to encode are obtained from the input vectors by a
17
+ * VectorTransform.
18
+ */
19
+ struct IndexIVFIndependentQuantizer : Index {
20
+ /// quantizer is fed directly with the input vectors
21
+ Index* quantizer = nullptr;
22
+
23
+ /// transform before the IVF vectors are applied
24
+ VectorTransform* vt = nullptr;
25
+
26
+ /// the IVF index, controls nlist and nprobe
27
+ IndexIVF* index_ivf = nullptr;
28
+
29
+ /// whether *this owns the 3 fields
30
+ bool own_fields = false;
31
+
32
+ IndexIVFIndependentQuantizer(
33
+ Index* quantizer,
34
+ IndexIVF* index_ivf,
35
+ VectorTransform* vt = nullptr);
36
+
37
+ IndexIVFIndependentQuantizer() {}
38
+
39
+ void train(idx_t n, const float* x) override;
40
+
41
+ void add(idx_t n, const float* x) override;
42
+
43
+ void search(
44
+ idx_t n,
45
+ const float* x,
46
+ idx_t k,
47
+ float* distances,
48
+ idx_t* labels,
49
+ const SearchParameters* params = nullptr) const override;
50
+
51
+ void reset() override;
52
+
53
+ ~IndexIVFIndependentQuantizer() override;
54
+ };
55
+
56
+ } // namespace faiss