faiss 0.2.3 → 0.2.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -0,0 +1,145 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/Index.h>
11
+ #include <faiss/utils/AlignedTable.h>
12
+
13
+ namespace faiss {
14
+
15
+ /** Fast scan version of IndexPQ and IndexAQ. Works for 4-bit PQ and AQ for now.
16
+ *
17
+ * The codes are not stored sequentially but grouped in blocks of size bbs.
18
+ * This makes it possible to compute distances quickly with SIMD instructions.
19
+ * The trailing codes (padding codes that are added to complete the last code)
20
+ * are garbage.
21
+ *
22
+ * Implementations:
23
+ * 12: blocked loop with internal loop on Q with qbs
24
+ * 13: same with reservoir accumulator to store results
25
+ * 14: no qbs with heap accumulator
26
+ * 15: no qbs with reservoir accumulator
27
+ */
28
+
29
+ struct IndexFastScan : Index {
30
+ // implementation to select
31
+ int implem = 0;
32
+ // skip some parts of the computation (for timing)
33
+ int skip = 0;
34
+
35
+ // size of the kernel
36
+ int bbs; // set at build time
37
+ int qbs = 0; // query block size 0 = use default
38
+
39
+ // vector quantizer
40
+ size_t M;
41
+ size_t nbits;
42
+ size_t ksub;
43
+ size_t code_size;
44
+
45
+ // packed version of the codes
46
+ size_t ntotal2;
47
+ size_t M2;
48
+
49
+ AlignedTable<uint8_t> codes;
50
+
51
+ // this is for testing purposes only
52
+ // (set when initialized by IndexPQ or IndexAQ)
53
+ const uint8_t* orig_codes = nullptr;
54
+
55
+ void init_fastscan(
56
+ int d,
57
+ size_t M,
58
+ size_t nbits,
59
+ MetricType metric,
60
+ int bbs);
61
+
62
+ IndexFastScan();
63
+
64
+ void reset() override;
65
+
66
+ void search(
67
+ idx_t n,
68
+ const float* x,
69
+ idx_t k,
70
+ float* distances,
71
+ idx_t* labels,
72
+ const SearchParameters* params = nullptr) const override;
73
+
74
+ void add(idx_t n, const float* x) override;
75
+
76
+ virtual void compute_codes(uint8_t* codes, idx_t n, const float* x)
77
+ const = 0;
78
+
79
+ virtual void compute_float_LUT(float* lut, idx_t n, const float* x)
80
+ const = 0;
81
+
82
+ // called by search function
83
+ void compute_quantized_LUT(
84
+ idx_t n,
85
+ const float* x,
86
+ uint8_t* lut,
87
+ float* normalizers) const;
88
+
89
+ template <bool is_max, class Scaler>
90
+ void search_dispatch_implem(
91
+ idx_t n,
92
+ const float* x,
93
+ idx_t k,
94
+ float* distances,
95
+ idx_t* labels,
96
+ const Scaler& scaler) const;
97
+
98
+ template <class Cfloat, class Scaler>
99
+ void search_implem_234(
100
+ idx_t n,
101
+ const float* x,
102
+ idx_t k,
103
+ float* distances,
104
+ idx_t* labels,
105
+ const Scaler& scaler) const;
106
+
107
+ template <class C, class Scaler>
108
+ void search_implem_12(
109
+ idx_t n,
110
+ const float* x,
111
+ idx_t k,
112
+ float* distances,
113
+ idx_t* labels,
114
+ int impl,
115
+ const Scaler& scaler) const;
116
+
117
+ template <class C, class Scaler>
118
+ void search_implem_14(
119
+ idx_t n,
120
+ const float* x,
121
+ idx_t k,
122
+ float* distances,
123
+ idx_t* labels,
124
+ int impl,
125
+ const Scaler& scaler) const;
126
+
127
+ void reconstruct(idx_t key, float* recons) const override;
128
+ size_t remove_ids(const IDSelector& sel) override;
129
+ void merge_from(Index& otherIndex, idx_t add_id = 0) override;
130
+ void check_compatible_for_merge(const Index& otherIndex) const override;
131
+ };
132
+
133
+ struct FastScanStats {
134
+ uint64_t t0, t1, t2, t3;
135
+ FastScanStats() {
136
+ reset();
137
+ }
138
+ void reset() {
139
+ memset(this, 0, sizeof(*this));
140
+ }
141
+ };
142
+
143
+ FAISS_API extern FastScanStats FastScan_stats;
144
+
145
+ } // namespace faiss
@@ -19,38 +19,31 @@
19
19
 
20
20
  namespace faiss {
21
21
 
22
- IndexFlat::IndexFlat(idx_t d, MetricType metric) : Index(d, metric) {}
23
-
24
- void IndexFlat::add(idx_t n, const float* x) {
25
- xb.insert(xb.end(), x, x + n * d);
26
- ntotal += n;
27
- }
28
-
29
- void IndexFlat::reset() {
30
- xb.clear();
31
- ntotal = 0;
32
- }
22
+ IndexFlat::IndexFlat(idx_t d, MetricType metric)
23
+ : IndexFlatCodes(sizeof(float) * d, d, metric) {}
33
24
 
34
25
  void IndexFlat::search(
35
26
  idx_t n,
36
27
  const float* x,
37
28
  idx_t k,
38
29
  float* distances,
39
- idx_t* labels) const {
30
+ idx_t* labels,
31
+ const SearchParameters* params) const {
32
+ IDSelector* sel = params ? params->sel : nullptr;
40
33
  FAISS_THROW_IF_NOT(k > 0);
41
34
 
42
35
  // we see the distances and labels as heaps
43
-
44
36
  if (metric_type == METRIC_INNER_PRODUCT) {
45
37
  float_minheap_array_t res = {size_t(n), size_t(k), labels, distances};
46
- knn_inner_product(x, xb.data(), d, n, ntotal, &res);
38
+ knn_inner_product(x, get_xb(), d, n, ntotal, &res, sel);
47
39
  } else if (metric_type == METRIC_L2) {
48
40
  float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
49
- knn_L2sqr(x, xb.data(), d, n, ntotal, &res);
41
+ knn_L2sqr(x, get_xb(), d, n, ntotal, &res, nullptr, sel);
50
42
  } else {
43
+ FAISS_THROW_IF_NOT(!sel);
51
44
  float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
52
45
  knn_extra_metrics(
53
- x, xb.data(), d, n, ntotal, metric_type, metric_arg, &res);
46
+ x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
54
47
  }
55
48
  }
56
49
 
@@ -58,14 +51,17 @@ void IndexFlat::range_search(
58
51
  idx_t n,
59
52
  const float* x,
60
53
  float radius,
61
- RangeSearchResult* result) const {
54
+ RangeSearchResult* result,
55
+ const SearchParameters* params) const {
56
+ IDSelector* sel = params ? params->sel : nullptr;
57
+
62
58
  switch (metric_type) {
63
59
  case METRIC_INNER_PRODUCT:
64
60
  range_search_inner_product(
65
- x, xb.data(), d, n, ntotal, radius, result);
61
+ x, get_xb(), d, n, ntotal, radius, result, sel);
66
62
  break;
67
63
  case METRIC_L2:
68
- range_search_L2sqr(x, xb.data(), d, n, ntotal, radius, result);
64
+ range_search_L2sqr(x, get_xb(), d, n, ntotal, radius, result, sel);
69
65
  break;
70
66
  default:
71
67
  FAISS_THROW_MSG("metric type not supported");
@@ -80,49 +76,28 @@ void IndexFlat::compute_distance_subset(
80
76
  const idx_t* labels) const {
81
77
  switch (metric_type) {
82
78
  case METRIC_INNER_PRODUCT:
83
- fvec_inner_products_by_idx(
84
- distances, x, xb.data(), labels, d, n, k);
79
+ fvec_inner_products_by_idx(distances, x, get_xb(), labels, d, n, k);
85
80
  break;
86
81
  case METRIC_L2:
87
- fvec_L2sqr_by_idx(distances, x, xb.data(), labels, d, n, k);
82
+ fvec_L2sqr_by_idx(distances, x, get_xb(), labels, d, n, k);
88
83
  break;
89
84
  default:
90
85
  FAISS_THROW_MSG("metric type not supported");
91
86
  }
92
87
  }
93
88
 
94
- size_t IndexFlat::remove_ids(const IDSelector& sel) {
95
- idx_t j = 0;
96
- for (idx_t i = 0; i < ntotal; i++) {
97
- if (sel.is_member(i)) {
98
- // should be removed
99
- } else {
100
- if (i > j) {
101
- memmove(&xb[d * j], &xb[d * i], sizeof(xb[0]) * d);
102
- }
103
- j++;
104
- }
105
- }
106
- size_t nremove = ntotal - j;
107
- if (nremove > 0) {
108
- ntotal = j;
109
- xb.resize(ntotal * d);
110
- }
111
- return nremove;
112
- }
113
-
114
89
  namespace {
115
90
 
116
- struct FlatL2Dis : DistanceComputer {
91
+ struct FlatL2Dis : FlatCodesDistanceComputer {
117
92
  size_t d;
118
93
  Index::idx_t nb;
119
94
  const float* q;
120
95
  const float* b;
121
96
  size_t ndis;
122
97
 
123
- float operator()(idx_t i) override {
98
+ float distance_to_code(const uint8_t* code) final {
124
99
  ndis++;
125
- return fvec_L2sqr(q, b + i * d, d);
100
+ return fvec_L2sqr(q, (float*)code, d);
126
101
  }
127
102
 
128
103
  float symmetric_dis(idx_t i, idx_t j) override {
@@ -130,10 +105,13 @@ struct FlatL2Dis : DistanceComputer {
130
105
  }
131
106
 
132
107
  explicit FlatL2Dis(const IndexFlat& storage, const float* q = nullptr)
133
- : d(storage.d),
108
+ : FlatCodesDistanceComputer(
109
+ storage.codes.data(),
110
+ storage.code_size),
111
+ d(storage.d),
134
112
  nb(storage.ntotal),
135
113
  q(q),
136
- b(storage.xb.data()),
114
+ b(storage.get_xb()),
137
115
  ndis(0) {}
138
116
 
139
117
  void set_query(const float* x) override {
@@ -141,27 +119,30 @@ struct FlatL2Dis : DistanceComputer {
141
119
  }
142
120
  };
143
121
 
144
- struct FlatIPDis : DistanceComputer {
122
+ struct FlatIPDis : FlatCodesDistanceComputer {
145
123
  size_t d;
146
124
  Index::idx_t nb;
147
125
  const float* q;
148
126
  const float* b;
149
127
  size_t ndis;
150
128
 
151
- float operator()(idx_t i) override {
152
- ndis++;
153
- return fvec_inner_product(q, b + i * d, d);
154
- }
155
-
156
129
  float symmetric_dis(idx_t i, idx_t j) override {
157
130
  return fvec_inner_product(b + j * d, b + i * d, d);
158
131
  }
159
132
 
133
+ float distance_to_code(const uint8_t* code) final {
134
+ ndis++;
135
+ return fvec_inner_product(q, (float*)code, d);
136
+ }
137
+
160
138
  explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr)
161
- : d(storage.d),
139
+ : FlatCodesDistanceComputer(
140
+ storage.codes.data(),
141
+ storage.code_size),
142
+ d(storage.d),
162
143
  nb(storage.ntotal),
163
144
  q(q),
164
- b(storage.xb.data()),
145
+ b(storage.get_xb()),
165
146
  ndis(0) {}
166
147
 
167
148
  void set_query(const float* x) override {
@@ -171,32 +152,31 @@ struct FlatIPDis : DistanceComputer {
171
152
 
172
153
  } // namespace
173
154
 
174
- DistanceComputer* IndexFlat::get_distance_computer() const {
155
+ FlatCodesDistanceComputer* IndexFlat::get_FlatCodesDistanceComputer() const {
175
156
  if (metric_type == METRIC_L2) {
176
157
  return new FlatL2Dis(*this);
177
158
  } else if (metric_type == METRIC_INNER_PRODUCT) {
178
159
  return new FlatIPDis(*this);
179
160
  } else {
180
161
  return get_extra_distance_computer(
181
- d, metric_type, metric_arg, ntotal, xb.data());
162
+ d, metric_type, metric_arg, ntotal, get_xb());
182
163
  }
183
164
  }
184
165
 
185
166
  void IndexFlat::reconstruct(idx_t key, float* recons) const {
186
- memcpy(recons, &(xb[key * d]), sizeof(*recons) * d);
187
- }
188
-
189
- /* The standalone codec interface */
190
- size_t IndexFlat::sa_code_size() const {
191
- return sizeof(float) * d;
167
+ memcpy(recons, &(codes[key * code_size]), code_size);
192
168
  }
193
169
 
194
170
  void IndexFlat::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
195
- memcpy(bytes, x, sizeof(float) * d * n);
171
+ if (n > 0) {
172
+ memcpy(bytes, x, sizeof(float) * d * n);
173
+ }
196
174
  }
197
175
 
198
176
  void IndexFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
199
- memcpy(x, bytes, sizeof(float) * d * n);
177
+ if (n > 0) {
178
+ memcpy(x, bytes, sizeof(float) * d * n);
179
+ }
200
180
  }
201
181
 
202
182
  /***************************************************
@@ -211,9 +191,9 @@ IndexFlat1D::IndexFlat1D(bool continuous_update)
211
191
  void IndexFlat1D::update_permutation() {
212
192
  perm.resize(ntotal);
213
193
  if (ntotal < 1000000) {
214
- fvec_argsort(ntotal, xb.data(), (size_t*)perm.data());
194
+ fvec_argsort(ntotal, get_xb(), (size_t*)perm.data());
215
195
  } else {
216
- fvec_argsort_parallel(ntotal, xb.data(), (size_t*)perm.data());
196
+ fvec_argsort_parallel(ntotal, get_xb(), (size_t*)perm.data());
217
197
  }
218
198
  }
219
199
 
@@ -233,11 +213,14 @@ void IndexFlat1D::search(
233
213
  const float* x,
234
214
  idx_t k,
235
215
  float* distances,
236
- idx_t* labels) const {
216
+ idx_t* labels,
217
+ const SearchParameters* params) const {
218
+ FAISS_THROW_IF_NOT_MSG(
219
+ !params, "search params not supported for this index");
237
220
  FAISS_THROW_IF_NOT(k > 0);
238
-
239
221
  FAISS_THROW_IF_NOT_MSG(
240
222
  perm.size() == ntotal, "Call update_permutation before search");
223
+ const float* xb = get_xb();
241
224
 
242
225
  #pragma omp parallel for
243
226
  for (idx_t i = 0; i < n; i++) {
@@ -12,33 +12,28 @@
12
12
 
13
13
  #include <vector>
14
14
 
15
- #include <faiss/Index.h>
15
+ #include <faiss/IndexFlatCodes.h>
16
16
 
17
17
  namespace faiss {
18
18
 
19
19
  /** Index that stores the full vectors and performs exhaustive search */
20
- struct IndexFlat : Index {
21
- /// database vectors, size ntotal * d
22
- std::vector<float> xb;
23
-
20
+ struct IndexFlat : IndexFlatCodes {
24
21
  explicit IndexFlat(idx_t d, MetricType metric = METRIC_L2);
25
22
 
26
- void add(idx_t n, const float* x) override;
27
-
28
- void reset() override;
29
-
30
23
  void search(
31
24
  idx_t n,
32
25
  const float* x,
33
26
  idx_t k,
34
27
  float* distances,
35
- idx_t* labels) const override;
28
+ idx_t* labels,
29
+ const SearchParameters* params = nullptr) const override;
36
30
 
37
31
  void range_search(
38
32
  idx_t n,
39
33
  const float* x,
40
34
  float radius,
41
- RangeSearchResult* result) const override;
35
+ RangeSearchResult* result,
36
+ const SearchParameters* params = nullptr) const override;
42
37
 
43
38
  void reconstruct(idx_t key, float* recons) const override;
44
39
 
@@ -57,18 +52,19 @@ struct IndexFlat : Index {
57
52
  float* distances,
58
53
  const idx_t* labels) const;
59
54
 
60
- /** remove some ids. NB that Because of the structure of the
61
- * indexing structure, the semantics of this operation are
62
- * different from the usual ones: the new ids are shifted */
63
- size_t remove_ids(const IDSelector& sel) override;
55
+ // get pointer to the floating point data
56
+ float* get_xb() {
57
+ return (float*)codes.data();
58
+ }
59
+ const float* get_xb() const {
60
+ return (const float*)codes.data();
61
+ }
64
62
 
65
63
  IndexFlat() {}
66
64
 
67
- DistanceComputer* get_distance_computer() const override;
65
+ FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
68
66
 
69
67
  /* The stanadlone codec interface (just memcopies in this case) */
70
- size_t sa_code_size() const override;
71
-
72
68
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
73
69
 
74
70
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
@@ -106,7 +102,8 @@ struct IndexFlat1D : IndexFlatL2 {
106
102
  const float* x,
107
103
  idx_t k,
108
104
  float* distances,
109
- idx_t* labels) const override;
105
+ idx_t* labels,
106
+ const SearchParameters* params = nullptr) const override;
110
107
  };
111
108
 
112
109
  } // namespace faiss
@@ -0,0 +1,101 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/IndexFlatCodes.h>
9
+
10
+ #include <faiss/impl/AuxIndexStructures.h>
11
+ #include <faiss/impl/DistanceComputer.h>
12
+ #include <faiss/impl/FaissAssert.h>
13
+ #include <faiss/impl/IDSelector.h>
14
+
15
+ namespace faiss {
16
+
17
+ IndexFlatCodes::IndexFlatCodes(size_t code_size, idx_t d, MetricType metric)
18
+ : Index(d, metric), code_size(code_size) {}
19
+
20
+ IndexFlatCodes::IndexFlatCodes() : code_size(0) {}
21
+
22
+ void IndexFlatCodes::add(idx_t n, const float* x) {
23
+ FAISS_THROW_IF_NOT(is_trained);
24
+ if (n == 0) {
25
+ return;
26
+ }
27
+ codes.resize((ntotal + n) * code_size);
28
+ sa_encode(n, x, codes.data() + (ntotal * code_size));
29
+ ntotal += n;
30
+ }
31
+
32
+ void IndexFlatCodes::reset() {
33
+ codes.clear();
34
+ ntotal = 0;
35
+ }
36
+
37
+ size_t IndexFlatCodes::sa_code_size() const {
38
+ return code_size;
39
+ }
40
+
41
+ size_t IndexFlatCodes::remove_ids(const IDSelector& sel) {
42
+ idx_t j = 0;
43
+ for (idx_t i = 0; i < ntotal; i++) {
44
+ if (sel.is_member(i)) {
45
+ // should be removed
46
+ } else {
47
+ if (i > j) {
48
+ memmove(&codes[code_size * j],
49
+ &codes[code_size * i],
50
+ code_size);
51
+ }
52
+ j++;
53
+ }
54
+ }
55
+ size_t nremove = ntotal - j;
56
+ if (nremove > 0) {
57
+ ntotal = j;
58
+ codes.resize(ntotal * code_size);
59
+ }
60
+ return nremove;
61
+ }
62
+
63
+ void IndexFlatCodes::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
64
+ FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
65
+ sa_decode(ni, codes.data() + i0 * code_size, recons);
66
+ }
67
+
68
+ void IndexFlatCodes::reconstruct(idx_t key, float* recons) const {
69
+ reconstruct_n(key, 1, recons);
70
+ }
71
+
72
+ FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
73
+ const {
74
+ FAISS_THROW_MSG("not implemented");
75
+ }
76
+
77
+ void IndexFlatCodes::check_compatible_for_merge(const Index& otherIndex) const {
78
+ // minimal sanity checks
79
+ const IndexFlatCodes* other =
80
+ dynamic_cast<const IndexFlatCodes*>(&otherIndex);
81
+ FAISS_THROW_IF_NOT(other);
82
+ FAISS_THROW_IF_NOT(other->d == d);
83
+ FAISS_THROW_IF_NOT(other->code_size == code_size);
84
+ FAISS_THROW_IF_NOT_MSG(
85
+ typeid(*this) == typeid(*other),
86
+ "can only merge indexes of the same type");
87
+ }
88
+
89
+ void IndexFlatCodes::merge_from(Index& otherIndex, idx_t add_id) {
90
+ FAISS_THROW_IF_NOT_MSG(add_id == 0, "cannot set ids in FlatCodes index");
91
+ check_compatible_for_merge(otherIndex);
92
+ IndexFlatCodes* other = static_cast<IndexFlatCodes*>(&otherIndex);
93
+ codes.resize((ntotal + other->ntotal) * code_size);
94
+ memcpy(codes.data() + (ntotal * code_size),
95
+ other->codes.data(),
96
+ other->ntotal * code_size);
97
+ ntotal += other->ntotal;
98
+ other->reset();
99
+ }
100
+
101
+ } // namespace faiss
@@ -0,0 +1,59 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #pragma once
11
+
12
+ #include <faiss/Index.h>
13
+ #include <faiss/impl/DistanceComputer.h>
14
+ #include <vector>
15
+
16
+ namespace faiss {
17
+
18
+ /** Index that encodes all vectors as fixed-size codes (size code_size). Storage
19
+ * is in the codes vector */
20
+ struct IndexFlatCodes : Index {
21
+ size_t code_size;
22
+
23
+ /// encoded dataset, size ntotal * code_size
24
+ std::vector<uint8_t> codes;
25
+
26
+ IndexFlatCodes();
27
+
28
+ IndexFlatCodes(size_t code_size, idx_t d, MetricType metric = METRIC_L2);
29
+
30
+ /// default add uses sa_encode
31
+ void add(idx_t n, const float* x) override;
32
+
33
+ void reset() override;
34
+
35
+ /// reconstruction using the codec interface
36
+ void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
37
+
38
+ void reconstruct(idx_t key, float* recons) const override;
39
+
40
+ size_t sa_code_size() const override;
41
+
42
+ /** remove some ids. NB that Because of the structure of the
43
+ * indexing structure, the semantics of this operation are
44
+ * different from the usual ones: the new ids are shifted */
45
+ size_t remove_ids(const IDSelector& sel) override;
46
+
47
+ /** a FlatCodesDistanceComputer offers a distance_to_code method */
48
+ virtual FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const;
49
+
50
+ DistanceComputer* get_distance_computer() const override {
51
+ return get_FlatCodesDistanceComputer();
52
+ }
53
+
54
+ void check_compatible_for_merge(const Index& otherIndex) const override;
55
+
56
+ virtual void merge_from(Index& otherIndex, idx_t add_id = 0) override;
57
+ };
58
+
59
+ } // namespace faiss