faiss 0.2.3 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -9,6 +9,7 @@
9
9
 
10
10
  #include <memory>
11
11
 
12
+ #include <faiss/IndexIVFFastScan.h>
12
13
  #include <faiss/IndexIVFPQ.h>
13
14
  #include <faiss/impl/ProductQuantizer.h>
14
15
  #include <faiss/utils/AlignedTable.h>
@@ -31,36 +32,20 @@ namespace faiss {
31
32
  * 13: idem, collect results in reservoir
32
33
  */
33
34
 
34
- struct IndexIVFPQFastScan : IndexIVF {
35
- bool by_residual; ///< Encode residual or plain vector?
35
+ struct IndexIVFPQFastScan : IndexIVFFastScan {
36
36
  ProductQuantizer pq; ///< produces the codes
37
37
 
38
- // size of the kernel
39
- int bbs; // set at build time
40
-
41
- // M rounded up to a multiple of 2
42
- size_t M2;
43
-
44
38
  /// precomputed tables management
45
39
  int use_precomputed_table = 0;
46
40
  /// if use_precompute_table size (nlist, pq.M, pq.ksub)
47
41
  AlignedTable<float> precomputed_table;
48
42
 
49
- // search-time implementation
50
- int implem = 0;
51
- // skip some parts of the computation (for timing)
52
- int skip = 0;
53
-
54
- // batching factors at search time (0 = default)
55
- int qbs = 0;
56
- size_t qbs2 = 0;
57
-
58
43
  IndexIVFPQFastScan(
59
44
  Index* quantizer,
60
45
  size_t d,
61
46
  size_t nlist,
62
47
  size_t M,
63
- size_t nbits_per_idx,
48
+ size_t nbits,
64
49
  MetricType metric = METRIC_L2,
65
50
  int bbs = 32);
66
51
 
@@ -69,9 +54,6 @@ struct IndexIVFPQFastScan : IndexIVF {
69
54
  // built from an IndexIVFPQ
70
55
  explicit IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs = 32);
71
56
 
72
- /// orig's inverted lists (for debugging)
73
- InvertedLists* orig_invlists = nullptr;
74
-
75
57
  void train_residual(idx_t n, const float* x) override;
76
58
 
77
59
  /// build precomputed table, possibly updating use_precomputed_table
@@ -86,106 +68,19 @@ struct IndexIVFPQFastScan : IndexIVF {
86
68
  uint8_t* codes,
87
69
  bool include_listno = false) const override;
88
70
 
89
- void add_with_ids(idx_t n, const float* x, const idx_t* xids) override;
90
-
91
- void search(
92
- idx_t n,
93
- const float* x,
94
- idx_t k,
95
- float* distances,
96
- idx_t* labels) const override;
97
-
98
71
  // prepare look-up tables
99
72
 
73
+ bool lookup_table_is_3d() const override;
74
+
100
75
  void compute_LUT(
101
76
  size_t n,
102
77
  const float* x,
103
78
  const idx_t* coarse_ids,
104
79
  const float* coarse_dis,
105
80
  AlignedTable<float>& dis_tables,
106
- AlignedTable<float>& biases) const;
107
-
108
- void compute_LUT_uint8(
109
- size_t n,
110
- const float* x,
111
- const idx_t* coarse_ids,
112
- const float* coarse_dis,
113
- AlignedTable<uint8_t>& dis_tables,
114
- AlignedTable<uint16_t>& biases,
115
- float* normalizers) const;
116
-
117
- // internal search funcs
81
+ AlignedTable<float>& biases) const override;
118
82
 
119
- template <bool is_max>
120
- void search_dispatch_implem(
121
- idx_t n,
122
- const float* x,
123
- idx_t k,
124
- float* distances,
125
- idx_t* labels) const;
126
-
127
- template <class C>
128
- void search_implem_1(
129
- idx_t n,
130
- const float* x,
131
- idx_t k,
132
- float* distances,
133
- idx_t* labels) const;
134
-
135
- template <class C>
136
- void search_implem_2(
137
- idx_t n,
138
- const float* x,
139
- idx_t k,
140
- float* distances,
141
- idx_t* labels) const;
142
-
143
- // implem 10 and 12 are not multithreaded internally, so
144
- // export search stats
145
- template <class C>
146
- void search_implem_10(
147
- idx_t n,
148
- const float* x,
149
- idx_t k,
150
- float* distances,
151
- idx_t* labels,
152
- int impl,
153
- size_t* ndis_out,
154
- size_t* nlist_out) const;
155
-
156
- template <class C>
157
- void search_implem_12(
158
- idx_t n,
159
- const float* x,
160
- idx_t k,
161
- float* distances,
162
- idx_t* labels,
163
- int impl,
164
- size_t* ndis_out,
165
- size_t* nlist_out) const;
83
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
166
84
  };
167
85
 
168
- struct IVFFastScanStats {
169
- uint64_t times[10];
170
- uint64_t t_compute_distance_tables, t_round;
171
- uint64_t t_copy_pack, t_scan, t_to_flat;
172
- uint64_t reservoir_times[4];
173
-
174
- double Mcy_at(int i) {
175
- return times[i] / (1000 * 1000.0);
176
- }
177
-
178
- double Mcy_reservoir_at(int i) {
179
- return reservoir_times[i] / (1000 * 1000.0);
180
- }
181
- IVFFastScanStats() {
182
- reset();
183
- }
184
- void reset() {
185
- memset(this, 0, sizeof(*this));
186
- }
187
- };
188
-
189
- FAISS_API extern IVFFastScanStats IVFFastScan_stats;
190
-
191
86
  } // namespace faiss
@@ -201,11 +201,11 @@ void IndexIVFPQR::reconstruct_from_offset(
201
201
  }
202
202
  }
203
203
 
204
- void IndexIVFPQR::merge_from(IndexIVF& other_in, idx_t add_id) {
205
- IndexIVFPQR* other = dynamic_cast<IndexIVFPQR*>(&other_in);
204
+ void IndexIVFPQR::merge_from(Index& otherIndex, idx_t add_id) {
205
+ IndexIVFPQR* other = dynamic_cast<IndexIVFPQR*>(&otherIndex);
206
206
  FAISS_THROW_IF_NOT(other);
207
207
 
208
- IndexIVF::merge_from(other_in, add_id);
208
+ IndexIVF::merge_from(otherIndex, add_id);
209
209
 
210
210
  refine_codes.insert(
211
211
  refine_codes.end(),
@@ -51,7 +51,7 @@ struct IndexIVFPQR : IndexIVFPQ {
51
51
  void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
52
52
  const override;
53
53
 
54
- void merge_from(IndexIVF& other, idx_t add_id) override;
54
+ void merge_from(Index& otherIndex, idx_t add_id) override;
55
55
 
56
56
  void search_preassigned(
57
57
  idx_t n,
@@ -13,6 +13,8 @@
13
13
  #include <algorithm>
14
14
  #include <memory>
15
15
 
16
+ #include <faiss/IndexLSH.h>
17
+ #include <faiss/IndexPreTransform.h>
16
18
  #include <faiss/VectorTransform.h>
17
19
  #include <faiss/impl/AuxIndexStructures.h>
18
20
  #include <faiss/impl/FaissAssert.h>
@@ -31,7 +33,6 @@ IndexIVFSpectralHash::IndexIVFSpectralHash(
31
33
  nbit(nbit),
32
34
  period(period),
33
35
  threshold_type(Thresh_global) {
34
- FAISS_THROW_IF_NOT(code_size % 4 == 0);
35
36
  RandomRotationMatrix* rr = new RandomRotationMatrix(d, nbit);
36
37
  rr->init(1234);
37
38
  vt = rr;
@@ -151,8 +152,8 @@ void binarize_with_freq(
151
152
  memset(codes, 0, (nbit + 7) / 8);
152
153
  for (size_t i = 0; i < nbit; i++) {
153
154
  float xf = (x[i] - c[i]);
154
- int xi = int(floor(xf * freq));
155
- int bit = xi & 1;
155
+ int64_t xi = int64_t(floor(xf * freq));
156
+ int64_t bit = xi & 1;
156
157
  codes[i >> 3] |= bit << (i & 7);
157
158
  }
158
159
  }
@@ -167,35 +168,33 @@ void IndexIVFSpectralHash::encode_vectors(
167
168
  bool include_listnos) const {
168
169
  FAISS_THROW_IF_NOT(is_trained);
169
170
  float freq = 2.0 / period;
170
-
171
- FAISS_THROW_IF_NOT_MSG(!include_listnos, "listnos encoding not supported");
171
+ size_t coarse_size = include_listnos ? coarse_code_size() : 0;
172
172
 
173
173
  // transform with vt
174
174
  std::unique_ptr<float[]> x(vt->apply(n, x_in));
175
175
 
176
- #pragma omp parallel
177
- {
178
- std::vector<float> zero(nbit);
176
+ std::vector<float> zero(nbit);
179
177
 
180
- // each thread takes care of a subset of lists
181
178
  #pragma omp for
182
- for (idx_t i = 0; i < n; i++) {
183
- int64_t list_no = list_nos[i];
184
-
185
- if (list_no >= 0) {
186
- const float* c;
187
- if (threshold_type == Thresh_global) {
188
- c = zero.data();
189
- } else {
190
- c = trained.data() + list_no * nbit;
191
- }
192
- binarize_with_freq(
193
- nbit,
194
- freq,
195
- x.get() + i * nbit,
196
- c,
197
- codes + i * code_size);
179
+ for (idx_t i = 0; i < n; i++) {
180
+ int64_t list_no = list_nos[i];
181
+ uint8_t* code = codes + i * (code_size + coarse_size);
182
+
183
+ if (list_no >= 0) {
184
+ if (coarse_size) {
185
+ encode_listno(list_no, code);
186
+ }
187
+ const float* c;
188
+
189
+ if (threshold_type == Thresh_global) {
190
+ c = zero.data();
191
+ } else {
192
+ c = trained.data() + list_no * nbit;
198
193
  }
194
+ binarize_with_freq(
195
+ nbit, freq, x.get() + i * nbit, c, code + coarse_size);
196
+ } else {
197
+ memset(code, 0, code_size + coarse_size);
199
198
  }
200
199
  }
201
200
  }
@@ -206,9 +205,7 @@ template <class HammingComputer>
206
205
  struct IVFScanner : InvertedListScanner {
207
206
  // copied from index structure
208
207
  const IndexIVFSpectralHash* index;
209
- size_t code_size;
210
208
  size_t nbit;
211
- bool store_pairs;
212
209
 
213
210
  float period, freq;
214
211
  std::vector<float> q;
@@ -220,15 +217,16 @@ struct IVFScanner : InvertedListScanner {
220
217
 
221
218
  IVFScanner(const IndexIVFSpectralHash* index, bool store_pairs)
222
219
  : index(index),
223
- code_size(index->code_size),
224
220
  nbit(index->nbit),
225
- store_pairs(store_pairs),
226
221
  period(index->period),
227
222
  freq(2.0 / index->period),
228
223
  q(nbit),
229
224
  zero(nbit),
230
- qcode(code_size),
231
- hc(qcode.data(), code_size) {}
225
+ qcode(index->code_size),
226
+ hc(qcode.data(), index->code_size) {
227
+ this->store_pairs = store_pairs;
228
+ this->code_size = index->code_size;
229
+ }
232
230
 
233
231
  void set_query(const float* query) override {
234
232
  FAISS_THROW_IF_NOT(query);
@@ -241,8 +239,6 @@ struct IVFScanner : InvertedListScanner {
241
239
  }
242
240
  }
243
241
 
244
- idx_t list_no;
245
-
246
242
  void set_list(idx_t list_no, float /*coarse_dis*/) override {
247
243
  this->list_no = list_no;
248
244
  if (index->threshold_type != IndexIVFSpectralHash::Thresh_global) {
@@ -297,7 +293,9 @@ struct IVFScanner : InvertedListScanner {
297
293
  } // anonymous namespace
298
294
 
299
295
  InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner(
300
- bool store_pairs) const {
296
+ bool store_pairs,
297
+ const IDSelector* sel) const {
298
+ FAISS_THROW_IF_NOT(!sel);
301
299
  switch (code_size) {
302
300
  #define HANDLE_CODE_SIZE(cs) \
303
301
  case cs: \
@@ -310,13 +308,38 @@ InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner(
310
308
  HANDLE_CODE_SIZE(64);
311
309
  #undef HANDLE_CODE_SIZE
312
310
  default:
313
- if (code_size % 4 == 0) {
314
- return new IVFScanner<HammingComputerDefault>(
315
- this, store_pairs);
316
- } else {
317
- FAISS_THROW_MSG("not supported");
318
- }
311
+ return new IVFScanner<HammingComputerDefault>(this, store_pairs);
312
+ }
313
+ }
314
+
315
+ void IndexIVFSpectralHash::replace_vt(VectorTransform* vt_in, bool own) {
316
+ FAISS_THROW_IF_NOT(vt_in->d_out == nbit);
317
+ FAISS_THROW_IF_NOT(vt_in->d_in == d);
318
+ if (own_fields) {
319
+ delete vt;
319
320
  }
321
+ vt = vt_in;
322
+ threshold_type = Thresh_global;
323
+ is_trained = quantizer->is_trained && quantizer->ntotal == nlist &&
324
+ vt->is_trained;
325
+ own_fields = own;
326
+ }
327
+
328
+ /*
329
+ Check that the encoder is a single vector transform followed by a LSH
330
+ that just does thresholding.
331
+ If this is not the case, the linear transform + threhsolds of the IndexLSH
332
+ should be merged into the VectorTransform (which is feasible).
333
+ */
334
+
335
+ void IndexIVFSpectralHash::replace_vt(IndexPreTransform* encoder, bool own) {
336
+ FAISS_THROW_IF_NOT(encoder->chain.size() == 1);
337
+ auto sub_index = dynamic_cast<IndexLSH*>(encoder->index);
338
+ FAISS_THROW_IF_NOT_MSG(sub_index, "final index should be LSH");
339
+ FAISS_THROW_IF_NOT(sub_index->nbits == nbit);
340
+ FAISS_THROW_IF_NOT(!sub_index->rotate_data);
341
+ FAISS_THROW_IF_NOT(!sub_index->train_thresholds);
342
+ replace_vt(encoder->chain[0], own);
320
343
  }
321
344
 
322
345
  } // namespace faiss
@@ -17,6 +17,7 @@
17
17
  namespace faiss {
18
18
 
19
19
  struct VectorTransform;
20
+ struct IndexPreTransform;
20
21
 
21
22
  /** Inverted list that stores binary codes of size nbit. Before the
22
23
  * binary conversion, the dimension of the vectors is transformed from
@@ -25,23 +26,29 @@ struct VectorTransform;
25
26
  * Each coordinate is subtracted from a value determined by
26
27
  * threshold_type, and split into intervals of size period. Half of
27
28
  * the interval is a 0 bit, the other half a 1.
29
+ *
28
30
  */
29
31
  struct IndexIVFSpectralHash : IndexIVF {
30
- VectorTransform* vt; // transformation from d to nbit dim
32
+ /// transformation from d to nbit dim
33
+ VectorTransform* vt;
34
+ /// own the vt
31
35
  bool own_fields;
32
36
 
37
+ /// nb of bits of the binary signature
33
38
  int nbit;
39
+ /// interval size for 0s and 1s
34
40
  float period;
35
41
 
36
42
  enum ThresholdType {
37
- Thresh_global,
38
- Thresh_centroid,
39
- Thresh_centroid_half,
40
- Thresh_median
43
+ Thresh_global, ///< global threshold at 0
44
+ Thresh_centroid, ///< compare to centroid
45
+ Thresh_centroid_half, ///< central interval around centroid
46
+ Thresh_median ///< median of training set
41
47
  };
42
48
  ThresholdType threshold_type;
43
49
 
44
- // size nlist * nbit or 0 if Thresh_global
50
+ /// Trained threshold.
51
+ /// size nlist * nbit or 0 if Thresh_global
45
52
  std::vector<float> trained;
46
53
 
47
54
  IndexIVFSpectralHash(
@@ -63,7 +70,16 @@ struct IndexIVFSpectralHash : IndexIVF {
63
70
  bool include_listnos = false) const override;
64
71
 
65
72
  InvertedListScanner* get_InvertedListScanner(
66
- bool store_pairs) const override;
73
+ bool store_pairs,
74
+ const IDSelector* sel) const override;
75
+
76
+ /** replace the vector transform for an empty (and possibly untrained) index
77
+ */
78
+ void replace_vt(VectorTransform* vt, bool own = false);
79
+
80
+ /** convenience function to get the VT from an index constucted by an
81
+ * index_factory (should end in "LSH") */
82
+ void replace_vt(IndexPreTransform* index, bool own = false);
67
83
 
68
84
  ~IndexIVFSpectralHash() override;
69
85
  };
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/IndexLSH.h>
11
9
 
12
10
  #include <cstdio>
@@ -25,15 +23,13 @@ namespace faiss {
25
23
  ***************************************************************/
26
24
 
27
25
  IndexLSH::IndexLSH(idx_t d, int nbits, bool rotate_data, bool train_thresholds)
28
- : Index(d),
26
+ : IndexFlatCodes((nbits + 7) / 8, d),
29
27
  nbits(nbits),
30
28
  rotate_data(rotate_data),
31
29
  train_thresholds(train_thresholds),
32
30
  rrot(d, nbits) {
33
31
  is_trained = !train_thresholds;
34
32
 
35
- bytes_per_vec = (nbits + 7) / 8;
36
-
37
33
  if (rotate_data) {
38
34
  rrot.init(5);
39
35
  } else {
@@ -41,11 +37,7 @@ IndexLSH::IndexLSH(idx_t d, int nbits, bool rotate_data, bool train_thresholds)
41
37
  }
42
38
  }
43
39
 
44
- IndexLSH::IndexLSH()
45
- : nbits(0),
46
- bytes_per_vec(0),
47
- rotate_data(false),
48
- train_thresholds(false) {}
40
+ IndexLSH::IndexLSH() : nbits(0), rotate_data(false), train_thresholds(false) {}
49
41
 
50
42
  const float* IndexLSH::apply_preprocess(idx_t n, const float* x) const {
51
43
  float* xt = nullptr;
@@ -106,28 +98,21 @@ void IndexLSH::train(idx_t n, const float* x) {
106
98
  is_trained = true;
107
99
  }
108
100
 
109
- void IndexLSH::add(idx_t n, const float* x) {
110
- FAISS_THROW_IF_NOT(is_trained);
111
- codes.resize((ntotal + n) * bytes_per_vec);
112
-
113
- sa_encode(n, x, &codes[ntotal * bytes_per_vec]);
114
-
115
- ntotal += n;
116
- }
117
-
118
101
  void IndexLSH::search(
119
102
  idx_t n,
120
103
  const float* x,
121
104
  idx_t k,
122
105
  float* distances,
123
- idx_t* labels) const {
106
+ idx_t* labels,
107
+ const SearchParameters* params) const {
108
+ FAISS_THROW_IF_NOT_MSG(
109
+ !params, "search params not supported for this index");
124
110
  FAISS_THROW_IF_NOT(k > 0);
125
-
126
111
  FAISS_THROW_IF_NOT(is_trained);
127
112
  const float* xt = apply_preprocess(n, x);
128
113
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
129
114
 
130
- uint8_t* qcodes = new uint8_t[n * bytes_per_vec];
115
+ uint8_t* qcodes = new uint8_t[n * code_size];
131
116
  ScopeDeleter<uint8_t> del2(qcodes);
132
117
 
133
118
  fvecs2bitvecs(xt, qcodes, nbits, n);
@@ -137,7 +122,7 @@ void IndexLSH::search(
137
122
 
138
123
  int_maxheap_array_t res = {size_t(n), size_t(k), labels, idistances};
139
124
 
140
- hammings_knn_hc(&res, qcodes, codes.data(), ntotal, bytes_per_vec, true);
125
+ hammings_knn_hc(&res, qcodes, codes.data(), ntotal, code_size, true);
141
126
 
142
127
  // convert distances to floats
143
128
  for (int i = 0; i < k * n; i++)
@@ -158,15 +143,6 @@ void IndexLSH::transfer_thresholds(LinearTransform* vt) {
158
143
  thresholds.clear();
159
144
  }
160
145
 
161
- void IndexLSH::reset() {
162
- codes.clear();
163
- ntotal = 0;
164
- }
165
-
166
- size_t IndexLSH::sa_code_size() const {
167
- return bytes_per_vec;
168
- }
169
-
170
146
  void IndexLSH::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
171
147
  FAISS_THROW_IF_NOT(is_trained);
172
148
  const float* xt = apply_preprocess(n, x);
@@ -12,17 +12,14 @@
12
12
 
13
13
  #include <vector>
14
14
 
15
- #include <faiss/Index.h>
15
+ #include <faiss/IndexFlatCodes.h>
16
16
  #include <faiss/VectorTransform.h>
17
17
 
18
18
  namespace faiss {
19
19
 
20
20
  /** The sign of each vector component is put in a binary signature */
21
- struct IndexLSH : Index {
22
- typedef unsigned char uint8_t;
23
-
21
+ struct IndexLSH : IndexFlatCodes {
24
22
  int nbits; ///< nb of bits per vector
25
- int bytes_per_vec; ///< nb of 8-bits per encoded vector
26
23
  bool rotate_data; ///< whether to apply a random rotation to input
27
24
  bool train_thresholds; ///< whether we train thresholds or use 0
28
25
 
@@ -30,9 +27,6 @@ struct IndexLSH : Index {
30
27
 
31
28
  std::vector<float> thresholds; ///< thresholds to compare with
32
29
 
33
- /// encoded dataset
34
- std::vector<uint8_t> codes;
35
-
36
30
  IndexLSH(
37
31
  idx_t d,
38
32
  int nbits,
@@ -50,16 +44,13 @@ struct IndexLSH : Index {
50
44
 
51
45
  void train(idx_t n, const float* x) override;
52
46
 
53
- void add(idx_t n, const float* x) override;
54
-
55
47
  void search(
56
48
  idx_t n,
57
49
  const float* x,
58
50
  idx_t k,
59
51
  float* distances,
60
- idx_t* labels) const override;
61
-
62
- void reset() override;
52
+ idx_t* labels,
53
+ const SearchParameters* params = nullptr) const override;
63
54
 
64
55
  /// transfer the thresholds to a pre-processing stage (and unset
65
56
  /// train_thresholds)
@@ -72,9 +63,6 @@ struct IndexLSH : Index {
72
63
  /* standalone codec interface.
73
64
  *
74
65
  * The vectors are decoded to +/- 1 (not 0, 1) */
75
-
76
- size_t sa_code_size() const override;
77
-
78
66
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
79
67
 
80
68
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
@@ -118,7 +118,13 @@ void IndexLattice::add(idx_t, const float*) {
118
118
  FAISS_THROW_MSG("not implemented");
119
119
  }
120
120
 
121
- void IndexLattice::search(idx_t, const float*, idx_t, float*, idx_t*) const {
121
+ void IndexLattice::search(
122
+ idx_t,
123
+ const float*,
124
+ idx_t,
125
+ float*,
126
+ idx_t*,
127
+ const SearchParameters*) const {
122
128
  FAISS_THROW_MSG("not implemented");
123
129
  }
124
130
 
@@ -54,7 +54,9 @@ struct IndexLattice : Index {
54
54
  const float* x,
55
55
  idx_t k,
56
56
  float* distances,
57
- idx_t* labels) const override;
57
+ idx_t* labels,
58
+ const SearchParameters* params = nullptr) const override;
59
+
58
60
  void reset() override;
59
61
  };
60
62
 
@@ -135,9 +135,10 @@ void IndexNNDescent::search(
135
135
  const float* x,
136
136
  idx_t k,
137
137
  float* distances,
138
- idx_t* labels) const
139
-
140
- {
138
+ idx_t* labels,
139
+ const SearchParameters* params) const {
140
+ FAISS_THROW_IF_NOT_MSG(
141
+ !params, "search params not supported for this index");
141
142
  FAISS_THROW_IF_NOT_MSG(
142
143
  storage,
143
144
  "Please use IndexNNDescentFlat (or variants) "
@@ -167,9 +168,7 @@ void IndexNNDescent::search(
167
168
  float* simi = distances + i * k;
168
169
  dis->set_query(x + i * d);
169
170
 
170
- maxheap_heapify(k, simi, idxi);
171
171
  nndescent.search(*dis, k, idxi, simi, vt);
172
- maxheap_reorder(k, simi, idxi);
173
172
  }
174
173
  }
175
174
  InterruptCallback::check();
@@ -53,7 +53,8 @@ struct IndexNNDescent : Index {
53
53
  const float* x,
54
54
  idx_t k,
55
55
  float* distances,
56
- idx_t* labels) const override;
56
+ idx_t* labels,
57
+ const SearchParameters* params = nullptr) const override;
57
58
 
58
59
  void reconstruct(idx_t key, float* recons) const override;
59
60