faiss 0.2.4 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -0,0 +1,145 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/Index.h>
11
+ #include <faiss/utils/AlignedTable.h>
12
+
13
+ namespace faiss {
14
+
15
+ /** Fast scan version of IndexPQ and IndexAQ. Works for 4-bit PQ and AQ for now.
16
+ *
17
+ * The codes are not stored sequentially but grouped in blocks of size bbs.
18
+ * This makes it possible to compute distances quickly with SIMD instructions.
19
+ * The trailing codes (padding codes that are added to complete the last block)
20
+ * are garbage.
21
+ *
22
+ * Implementations:
23
+ * 12: blocked loop with internal loop on Q with qbs
24
+ * 13: same with reservoir accumulator to store results
25
+ * 14: no qbs with heap accumulator
26
+ * 15: no qbs with reservoir accumulator
27
+ */
28
+
29
+ struct IndexFastScan : Index { // abstract: compute_codes and compute_float_LUT below are pure virtual
30
+ // implementation to select (defaults to 0; the numbered variants are listed in the comment above this struct)
31
+ int implem = 0;
32
+ // skip some parts of the computation (for timing); defaults to 0
33
+ int skip = 0;
34
+
35
+ // size of the kernel
36
+ int bbs; // set at build time; no in-class default, unlike implem / skip / qbs
37
+ int qbs = 0; // query block size 0 = use default
38
+
39
+ // vector quantizer (init_fastscan takes M and nbits as arguments)
40
+ size_t M;
41
+ size_t nbits;
42
+ size_t ksub; // NOTE(review): presumably derived from nbits -- confirm in init_fastscan
43
+ size_t code_size;
44
+
45
+ // packed version of the codes
46
+ size_t ntotal2; // NOTE(review): presumably ntotal padded for the blocked layout -- confirm
47
+ size_t M2; // NOTE(review): presumably M padded for the packed layout -- confirm
48
+
49
+ AlignedTable<uint8_t> codes; // code storage; per the struct comment, grouped in blocks of size bbs
50
+
51
+ // this is for testing purposes only
52
+ // (set when initialized by IndexPQ or IndexAQ); nullptr otherwise
53
+ const uint8_t* orig_codes = nullptr;
54
+
55
+ void init_fastscan( // declared here only; the definition is not part of this header
56
+ int d,
57
+ size_t M,
58
+ size_t nbits,
59
+ MetricType metric,
60
+ int bbs);
61
+
62
+ IndexFastScan();
63
+
64
+ void reset() override;
65
+
66
+ void search( // Index override; params is optional and defaults to nullptr
67
+ idx_t n,
68
+ const float* x,
69
+ idx_t k,
70
+ float* distances,
71
+ idx_t* labels,
72
+ const SearchParameters* params = nullptr) const override;
73
+
74
+ void add(idx_t n, const float* x) override;
75
+
76
+ virtual void compute_codes(uint8_t* codes, idx_t n, const float* x)
77
+ const = 0; // pure virtual: must be provided by the concrete index
78
+
79
+ virtual void compute_float_LUT(float* lut, idx_t n, const float* x)
80
+ const = 0; // pure virtual: must be provided by the concrete index
81
+
82
+ // called by search function; writes into lut and normalizers (both non-const output pointers)
83
+ void compute_quantized_LUT(
84
+ idx_t n,
85
+ const float* x,
86
+ uint8_t* lut,
87
+ float* normalizers) const;
88
+
89
+ template <bool is_max, class Scaler> // member templates: only declared in this header
90
+ void search_dispatch_implem(
91
+ idx_t n,
92
+ const float* x,
93
+ idx_t k,
94
+ float* distances,
95
+ idx_t* labels,
96
+ const Scaler& scaler) const;
97
+
98
+ template <class Cfloat, class Scaler>
99
+ void search_implem_234(
100
+ idx_t n,
101
+ const float* x,
102
+ idx_t k,
103
+ float* distances,
104
+ idx_t* labels,
105
+ const Scaler& scaler) const;
106
+
107
+ template <class C, class Scaler>
108
+ void search_implem_12(
109
+ idx_t n,
110
+ const float* x,
111
+ idx_t k,
112
+ float* distances,
113
+ idx_t* labels,
114
+ int impl, // NOTE(review): presumably distinguishes variants 12 and 13 -- confirm in the .cpp
115
+ const Scaler& scaler) const;
116
+
117
+ template <class C, class Scaler>
118
+ void search_implem_14(
119
+ idx_t n,
120
+ const float* x,
121
+ idx_t k,
122
+ float* distances,
123
+ idx_t* labels,
124
+ int impl, // NOTE(review): presumably distinguishes variants 14 and 15 -- confirm in the .cpp
125
+ const Scaler& scaler) const;
126
+
127
+ void reconstruct(idx_t key, float* recons) const override;
128
+ size_t remove_ids(const IDSelector& sel) override;
129
+ void merge_from(Index& otherIndex, idx_t add_id = 0) override;
130
+ void check_compatible_for_merge(const Index& otherIndex) const override;
131
+ };
132
+
133
+ struct FastScanStats { // plain aggregate of counters; a global instance FastScan_stats is declared below
134
+ uint64_t t0, t1, t2, t3; // NOTE(review): meaning of each slot is not documented here -- presumably timings, confirm
135
+ FastScanStats() { // construction starts with every counter at zero
136
+ reset();
137
+ }
138
+ void reset() {
139
+ memset(this, 0, sizeof(*this)); // zero every byte of the struct, i.e. t0..t3
140
+ }
141
+ };
142
+
143
+ FAISS_API extern FastScanStats FastScan_stats;
144
+
145
+ } // namespace faiss
@@ -27,18 +27,20 @@ void IndexFlat::search(
27
27
  const float* x,
28
28
  idx_t k,
29
29
  float* distances,
30
- idx_t* labels) const {
30
+ idx_t* labels,
31
+ const SearchParameters* params) const {
32
+ IDSelector* sel = params ? params->sel : nullptr;
31
33
  FAISS_THROW_IF_NOT(k > 0);
32
34
 
33
35
  // we see the distances and labels as heaps
34
-
35
36
  if (metric_type == METRIC_INNER_PRODUCT) {
36
37
  float_minheap_array_t res = {size_t(n), size_t(k), labels, distances};
37
- knn_inner_product(x, get_xb(), d, n, ntotal, &res);
38
+ knn_inner_product(x, get_xb(), d, n, ntotal, &res, sel);
38
39
  } else if (metric_type == METRIC_L2) {
39
40
  float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
40
- knn_L2sqr(x, get_xb(), d, n, ntotal, &res);
41
+ knn_L2sqr(x, get_xb(), d, n, ntotal, &res, nullptr, sel);
41
42
  } else {
43
+ FAISS_THROW_IF_NOT(!sel);
42
44
  float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
43
45
  knn_extra_metrics(
44
46
  x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
@@ -49,14 +51,17 @@ void IndexFlat::range_search(
49
51
  idx_t n,
50
52
  const float* x,
51
53
  float radius,
52
- RangeSearchResult* result) const {
54
+ RangeSearchResult* result,
55
+ const SearchParameters* params) const {
56
+ IDSelector* sel = params ? params->sel : nullptr;
57
+
53
58
  switch (metric_type) {
54
59
  case METRIC_INNER_PRODUCT:
55
60
  range_search_inner_product(
56
- x, get_xb(), d, n, ntotal, radius, result);
61
+ x, get_xb(), d, n, ntotal, radius, result, sel);
57
62
  break;
58
63
  case METRIC_L2:
59
- range_search_L2sqr(x, get_xb(), d, n, ntotal, radius, result);
64
+ range_search_L2sqr(x, get_xb(), d, n, ntotal, radius, result, sel);
60
65
  break;
61
66
  default:
62
67
  FAISS_THROW_MSG("metric type not supported");
@@ -83,16 +88,16 @@ void IndexFlat::compute_distance_subset(
83
88
 
84
89
  namespace {
85
90
 
86
- struct FlatL2Dis : DistanceComputer {
91
+ struct FlatL2Dis : FlatCodesDistanceComputer {
87
92
  size_t d;
88
93
  Index::idx_t nb;
89
94
  const float* q;
90
95
  const float* b;
91
96
  size_t ndis;
92
97
 
93
- float operator()(idx_t i) override {
98
+ float distance_to_code(const uint8_t* code) final {
94
99
  ndis++;
95
- return fvec_L2sqr(q, b + i * d, d);
100
+ return fvec_L2sqr(q, (float*)code, d);
96
101
  }
97
102
 
98
103
  float symmetric_dis(idx_t i, idx_t j) override {
@@ -100,7 +105,10 @@ struct FlatL2Dis : DistanceComputer {
100
105
  }
101
106
 
102
107
  explicit FlatL2Dis(const IndexFlat& storage, const float* q = nullptr)
103
- : d(storage.d),
108
+ : FlatCodesDistanceComputer(
109
+ storage.codes.data(),
110
+ storage.code_size),
111
+ d(storage.d),
104
112
  nb(storage.ntotal),
105
113
  q(q),
106
114
  b(storage.get_xb()),
@@ -111,24 +119,27 @@ struct FlatL2Dis : DistanceComputer {
111
119
  }
112
120
  };
113
121
 
114
- struct FlatIPDis : DistanceComputer {
122
+ struct FlatIPDis : FlatCodesDistanceComputer {
115
123
  size_t d;
116
124
  Index::idx_t nb;
117
125
  const float* q;
118
126
  const float* b;
119
127
  size_t ndis;
120
128
 
121
- float operator()(idx_t i) override {
122
- ndis++;
123
- return fvec_inner_product(q, b + i * d, d);
124
- }
125
-
126
129
  float symmetric_dis(idx_t i, idx_t j) override {
127
130
  return fvec_inner_product(b + j * d, b + i * d, d);
128
131
  }
129
132
 
133
+ float distance_to_code(const uint8_t* code) final {
134
+ ndis++;
135
+ return fvec_inner_product(q, (float*)code, d);
136
+ }
137
+
130
138
  explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr)
131
- : d(storage.d),
139
+ : FlatCodesDistanceComputer(
140
+ storage.codes.data(),
141
+ storage.code_size),
142
+ d(storage.d),
132
143
  nb(storage.ntotal),
133
144
  q(q),
134
145
  b(storage.get_xb()),
@@ -141,7 +152,7 @@ struct FlatIPDis : DistanceComputer {
141
152
 
142
153
  } // namespace
143
154
 
144
- DistanceComputer* IndexFlat::get_distance_computer() const {
155
+ FlatCodesDistanceComputer* IndexFlat::get_FlatCodesDistanceComputer() const {
145
156
  if (metric_type == METRIC_L2) {
146
157
  return new FlatL2Dis(*this);
147
158
  } else if (metric_type == METRIC_INNER_PRODUCT) {
@@ -202,9 +213,11 @@ void IndexFlat1D::search(
202
213
  const float* x,
203
214
  idx_t k,
204
215
  float* distances,
205
- idx_t* labels) const {
216
+ idx_t* labels,
217
+ const SearchParameters* params) const {
218
+ FAISS_THROW_IF_NOT_MSG(
219
+ !params, "search params not supported for this index");
206
220
  FAISS_THROW_IF_NOT(k > 0);
207
-
208
221
  FAISS_THROW_IF_NOT_MSG(
209
222
  perm.size() == ntotal, "Call update_permutation before search");
210
223
  const float* xb = get_xb();
@@ -25,13 +25,15 @@ struct IndexFlat : IndexFlatCodes {
25
25
  const float* x,
26
26
  idx_t k,
27
27
  float* distances,
28
- idx_t* labels) const override;
28
+ idx_t* labels,
29
+ const SearchParameters* params = nullptr) const override;
29
30
 
30
31
  void range_search(
31
32
  idx_t n,
32
33
  const float* x,
33
34
  float radius,
34
- RangeSearchResult* result) const override;
35
+ RangeSearchResult* result,
36
+ const SearchParameters* params = nullptr) const override;
35
37
 
36
38
  void reconstruct(idx_t key, float* recons) const override;
37
39
 
@@ -60,7 +62,7 @@ struct IndexFlat : IndexFlatCodes {
60
62
 
61
63
  IndexFlat() {}
62
64
 
63
- DistanceComputer* get_distance_computer() const override;
65
+ FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
64
66
 
65
67
  /* The standalone codec interface (just memcopies in this case) */
66
68
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
@@ -100,7 +102,8 @@ struct IndexFlat1D : IndexFlatL2 {
100
102
  const float* x,
101
103
  idx_t k,
102
104
  float* distances,
103
- idx_t* labels) const override;
105
+ idx_t* labels,
106
+ const SearchParameters* params = nullptr) const override;
104
107
  };
105
108
 
106
109
  } // namespace faiss
@@ -8,7 +8,9 @@
8
8
  #include <faiss/IndexFlatCodes.h>
9
9
 
10
10
  #include <faiss/impl/AuxIndexStructures.h>
11
+ #include <faiss/impl/DistanceComputer.h>
11
12
  #include <faiss/impl/FaissAssert.h>
13
+ #include <faiss/impl/IDSelector.h>
12
14
 
13
15
  namespace faiss {
14
16
 
@@ -19,8 +21,11 @@ IndexFlatCodes::IndexFlatCodes() : code_size(0) {}
19
21
 
20
22
  void IndexFlatCodes::add(idx_t n, const float* x) {
21
23
  FAISS_THROW_IF_NOT(is_trained);
24
+ if (n == 0) {
25
+ return;
26
+ }
22
27
  codes.resize((ntotal + n) * code_size);
23
- sa_encode(n, x, &codes[ntotal * code_size]);
28
+ sa_encode(n, x, codes.data() + (ntotal * code_size));
24
29
  ntotal += n;
25
30
  }
26
31
 
@@ -64,4 +69,33 @@ void IndexFlatCodes::reconstruct(idx_t key, float* recons) const {
64
69
  reconstruct_n(key, 1, recons);
65
70
  }
66
71
 
72
+ FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
73
+ const {
74
+ FAISS_THROW_MSG("not implemented"); // base-class default always throws; IndexFlat declares an override
75
+ }
76
+
77
+ void IndexFlatCodes::check_compatible_for_merge(const Index& otherIndex) const {
78
+ // minimal sanity checks: each failed check throws via the FAISS_THROW_IF_NOT* macros
79
+ const IndexFlatCodes* other =
80
+ dynamic_cast<const IndexFlatCodes*>(&otherIndex);
81
+ FAISS_THROW_IF_NOT(other); // otherIndex must be an IndexFlatCodes (or a subclass)
82
+ FAISS_THROW_IF_NOT(other->d == d); // same dimension
83
+ FAISS_THROW_IF_NOT(other->code_size == code_size); // same bytes per stored code
84
+ FAISS_THROW_IF_NOT_MSG(
85
+ typeid(*this) == typeid(*other), // exact dynamic type must match, not merely a common base
86
+ "can only merge indexes of the same type");
87
+ }
88
+
89
+ void IndexFlatCodes::merge_from(Index& otherIndex, idx_t add_id) { // appends otherIndex's codes to this index
90
+ FAISS_THROW_IF_NOT_MSG(add_id == 0, "cannot set ids in FlatCodes index"); // only add_id == 0 is accepted
91
+ check_compatible_for_merge(otherIndex); // throws unless type, d and code_size all match
92
+ IndexFlatCodes* other = static_cast<IndexFlatCodes*>(&otherIndex); // type was already verified by the check above
93
+ codes.resize((ntotal + other->ntotal) * code_size); // grow to hold both sets of codes
94
+ memcpy(codes.data() + (ntotal * code_size), // destination: right after this index's existing codes
95
+ other->codes.data(),
96
+ other->ntotal * code_size);
97
+ ntotal += other->ntotal;
98
+ other->reset(); // NOTE(review): presumably empties the source index -- reset() is not shown here
99
+ }
100
+
67
101
  } // namespace faiss
@@ -10,6 +10,7 @@
10
10
  #pragma once
11
11
 
12
12
  #include <faiss/Index.h>
13
+ #include <faiss/impl/DistanceComputer.h>
13
14
  #include <vector>
14
15
 
15
16
  namespace faiss {
@@ -42,6 +43,17 @@ struct IndexFlatCodes : Index {
42
43
  * indexing structure, the semantics of this operation are
43
44
  * different from the usual ones: the new ids are shifted */
44
45
  size_t remove_ids(const IDSelector& sel) override;
46
+
47
+ /** a FlatCodesDistanceComputer offers a distance_to_code method */
48
+ virtual FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const;
49
+
50
+ DistanceComputer* get_distance_computer() const override {
51
+ return get_FlatCodesDistanceComputer();
52
+ }
53
+
54
+ void check_compatible_for_merge(const Index& otherIndex) const override;
55
+
56
+ virtual void merge_from(Index& otherIndex, idx_t add_id = 0) override;
45
57
  };
46
58
 
47
59
  } // namespace faiss
@@ -202,7 +202,10 @@ void hnsw_add_vertices(
202
202
  verbose && omp_get_thread_num() == 0 ? 0 : -1;
203
203
  size_t counter = 0;
204
204
 
205
- #pragma omp for schedule(dynamic)
205
+ // here we should do schedule(dynamic) but this segfaults for
206
+ // some versions of LLVM. The performance impact should not be
207
+ // too large when (i1 - i0) / num_threads >> 1
208
+ #pragma omp for schedule(static)
206
209
  for (int i = i0; i < i1; i++) {
207
210
  storage_idx_t pt_id = order[i];
208
211
  dis->set_query(x + (pt_id - n0) * d);
@@ -219,7 +222,6 @@ void hnsw_add_vertices(
219
222
  printf(" %d / %d\r", i - i0, i1 - i0);
220
223
  fflush(stdout);
221
224
  }
222
-
223
225
  if (counter % check_period == 0) {
224
226
  if (InterruptCallback::is_interrupted()) {
225
227
  interrupt = true;
@@ -284,18 +286,24 @@ void IndexHNSW::search(
284
286
  const float* x,
285
287
  idx_t k,
286
288
  float* distances,
287
- idx_t* labels) const
288
-
289
- {
289
+ idx_t* labels,
290
+ const SearchParameters* params_in) const {
290
291
  FAISS_THROW_IF_NOT(k > 0);
291
-
292
292
  FAISS_THROW_IF_NOT_MSG(
293
293
  storage,
294
294
  "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
295
+ const SearchParametersHNSW* params = nullptr;
296
+
297
+ int efSearch = hnsw.efSearch;
298
+ if (params_in) {
299
+ params = dynamic_cast<const SearchParametersHNSW*>(params_in);
300
+ FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
301
+ efSearch = params->efSearch;
302
+ }
295
303
  size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
296
304
 
297
- idx_t check_period = InterruptCallback::get_period_hint(
298
- hnsw.max_level * d * hnsw.efSearch);
305
+ idx_t check_period =
306
+ InterruptCallback::get_period_hint(hnsw.max_level * d * efSearch);
299
307
 
300
308
  for (idx_t i0 = 0; i0 < n; i0 += check_period) {
301
309
  idx_t i1 = std::min(i0 + check_period, n);
@@ -314,7 +322,7 @@ void IndexHNSW::search(
314
322
  dis->set_query(x + i * d);
315
323
 
316
324
  maxheap_heapify(k, simi, idxi);
317
- HNSWStats stats = hnsw.search(*dis, k, idxi, simi, vt);
325
+ HNSWStats stats = hnsw.search(*dis, k, idxi, simi, vt, params);
318
326
  n1 += stats.n1;
319
327
  n2 += stats.n2;
320
328
  n3 += stats.n3;
@@ -423,16 +431,15 @@ void IndexHNSW::search_level_0(
423
431
  FAISS_THROW_IF_NOT(nprobe > 0);
424
432
 
425
433
  storage_idx_t ntotal = hnsw.levels.size();
426
- size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
427
434
 
428
435
  #pragma omp parallel
429
436
  {
430
- DistanceComputer* qdis = storage_distance_computer(storage);
431
- ScopeDeleter1<DistanceComputer> del(qdis);
432
-
437
+ std::unique_ptr<DistanceComputer> qdis(
438
+ storage_distance_computer(storage));
439
+ HNSWStats search_stats;
433
440
  VisitedTable vt(ntotal);
434
441
 
435
- #pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder)
442
+ #pragma omp for
436
443
  for (idx_t i = 0; i < n; i++) {
437
444
  idx_t* idxi = labels + i * k;
438
445
  float* simi = distances + i * k;
@@ -440,69 +447,24 @@ void IndexHNSW::search_level_0(
440
447
  qdis->set_query(x + i * d);
441
448
  maxheap_heapify(k, simi, idxi);
442
449
 
443
- if (search_type == 1) {
444
- int nres = 0;
445
-
446
- for (int j = 0; j < nprobe; j++) {
447
- storage_idx_t cj = nearest[i * nprobe + j];
448
-
449
- if (cj < 0)
450
- break;
451
-
452
- if (vt.get(cj))
453
- continue;
454
-
455
- int candidates_size = std::max(hnsw.efSearch, int(k));
456
- MinimaxHeap candidates(candidates_size);
457
-
458
- candidates.push(cj, nearest_d[i * nprobe + j]);
459
-
460
- HNSWStats search_stats;
461
- nres = hnsw.search_from_candidates(
462
- *qdis,
463
- k,
464
- idxi,
465
- simi,
466
- candidates,
467
- vt,
468
- search_stats,
469
- 0,
470
- nres);
471
- n1 += search_stats.n1;
472
- n2 += search_stats.n2;
473
- n3 += search_stats.n3;
474
- ndis += search_stats.ndis;
475
- nreorder += search_stats.nreorder;
476
- }
477
- } else if (search_type == 2) {
478
- int candidates_size = std::max(hnsw.efSearch, int(k));
479
- candidates_size = std::max(candidates_size, nprobe);
480
-
481
- MinimaxHeap candidates(candidates_size);
482
- for (int j = 0; j < nprobe; j++) {
483
- storage_idx_t cj = nearest[i * nprobe + j];
484
-
485
- if (cj < 0)
486
- break;
487
- candidates.push(cj, nearest_d[i * nprobe + j]);
488
- }
450
+ hnsw.search_level_0(
451
+ *qdis.get(),
452
+ k,
453
+ idxi,
454
+ simi,
455
+ nprobe,
456
+ nearest + i * nprobe,
457
+ nearest_d + i * nprobe,
458
+ search_type,
459
+ search_stats,
460
+ vt);
489
461
 
490
- HNSWStats search_stats;
491
- hnsw.search_from_candidates(
492
- *qdis, k, idxi, simi, candidates, vt, search_stats, 0);
493
- n1 += search_stats.n1;
494
- n2 += search_stats.n2;
495
- n3 += search_stats.n3;
496
- ndis += search_stats.ndis;
497
- nreorder += search_stats.nreorder;
498
- }
499
462
  vt.advance();
500
-
501
463
  maxheap_reorder(k, simi, idxi);
502
464
  }
465
+ #pragma omp critical
466
+ { hnsw_stats.combine(search_stats); }
503
467
  }
504
-
505
- hnsw_stats.combine({n1, n2, n3, ndis, nreorder});
506
468
  }
507
469
 
508
470
  void IndexHNSW::init_level_0_from_knngraph(
@@ -1035,8 +997,11 @@ void IndexHNSW2Level::search(
1035
997
  const float* x,
1036
998
  idx_t k,
1037
999
  float* distances,
1038
- idx_t* labels) const {
1000
+ idx_t* labels,
1001
+ const SearchParameters* params) const {
1039
1002
  FAISS_THROW_IF_NOT(k > 0);
1003
+ FAISS_THROW_IF_NOT_MSG(
1004
+ !params, "search params not supported for this index");
1040
1005
 
1041
1006
  if (dynamic_cast<const Index2Layer*>(storage)) {
1042
1007
  IndexHNSW::search(n, x, k, distances, labels);
@@ -1095,74 +1060,37 @@ void IndexHNSW2Level::search(
1095
1060
  }
1096
1061
 
1097
1062
  candidates.clear();
1098
- // copy the upper_beam elements to candidates list
1099
-
1100
- int search_policy = 2;
1101
-
1102
- if (search_policy == 1) {
1103
- for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
1104
- if (idxi[j] < 0)
1105
- break;
1106
- candidates.push(idxi[j], simi[j]);
1107
- // search_from_candidates adds them back
1108
- idxi[j] = -1;
1109
- simi[j] = HUGE_VAL;
1110
- }
1111
1063
 
1112
- // reorder from sorted to heap
1113
- maxheap_heapify(k, simi, idxi, simi, idxi, k);
1114
-
1115
- HNSWStats search_stats;
1116
- hnsw.search_from_candidates(
1117
- *dis,
1118
- k,
1119
- idxi,
1120
- simi,
1121
- candidates,
1122
- vt,
1123
- search_stats,
1124
- 0,
1125
- k);
1126
- n1 += search_stats.n1;
1127
- n2 += search_stats.n2;
1128
- n3 += search_stats.n3;
1129
- ndis += search_stats.ndis;
1130
- nreorder += search_stats.nreorder;
1131
-
1132
- vt.advance();
1133
-
1134
- } else if (search_policy == 2) {
1135
- for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
1136
- if (idxi[j] < 0)
1137
- break;
1138
- candidates.push(idxi[j], simi[j]);
1139
- }
1140
-
1141
- // reorder from sorted to heap
1142
- maxheap_heapify(k, simi, idxi, simi, idxi, k);
1143
-
1144
- HNSWStats search_stats;
1145
- search_from_candidates_2(
1146
- hnsw,
1147
- *dis,
1148
- k,
1149
- idxi,
1150
- simi,
1151
- candidates,
1152
- vt,
1153
- search_stats,
1154
- 0,
1155
- k);
1156
- n1 += search_stats.n1;
1157
- n2 += search_stats.n2;
1158
- n3 += search_stats.n3;
1159
- ndis += search_stats.ndis;
1160
- nreorder += search_stats.nreorder;
1161
-
1162
- vt.advance();
1163
- vt.advance();
1064
+ for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
1065
+ if (idxi[j] < 0)
1066
+ break;
1067
+ candidates.push(idxi[j], simi[j]);
1164
1068
  }
1165
1069
 
1070
+ // reorder from sorted to heap
1071
+ maxheap_heapify(k, simi, idxi, simi, idxi, k);
1072
+
1073
+ HNSWStats search_stats;
1074
+ search_from_candidates_2(
1075
+ hnsw,
1076
+ *dis,
1077
+ k,
1078
+ idxi,
1079
+ simi,
1080
+ candidates,
1081
+ vt,
1082
+ search_stats,
1083
+ 0,
1084
+ k);
1085
+ n1 += search_stats.n1;
1086
+ n2 += search_stats.n2;
1087
+ n3 += search_stats.n3;
1088
+ ndis += search_stats.ndis;
1089
+ nreorder += search_stats.nreorder;
1090
+
1091
+ vt.advance();
1092
+ vt.advance();
1093
+
1166
1094
  maxheap_reorder(k, simi, idxi);
1167
1095
  }
1168
1096
  }
@@ -96,7 +96,8 @@ struct IndexHNSW : Index {
96
96
  const float* x,
97
97
  idx_t k,
98
98
  float* distances,
99
- idx_t* labels) const override;
99
+ idx_t* labels,
100
+ const SearchParameters* params = nullptr) const override;
100
101
 
101
102
  void reconstruct(idx_t key, float* recons) const override;
102
103
 
@@ -180,7 +181,8 @@ struct IndexHNSW2Level : IndexHNSW {
180
181
  const float* x,
181
182
  idx_t k,
182
183
  float* distances,
183
- idx_t* labels) const override;
184
+ idx_t* labels,
185
+ const SearchParameters* params = nullptr) const override;
184
186
  };
185
187
 
186
188
  } // namespace faiss