faiss 0.2.3 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  6. data/vendor/faiss/faiss/Clustering.h +14 -0
  7. data/vendor/faiss/faiss/Index.h +1 -1
  8. data/vendor/faiss/faiss/Index2Layer.cpp +19 -92
  9. data/vendor/faiss/faiss/Index2Layer.h +2 -16
  10. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  11. data/vendor/faiss/faiss/{IndexResidual.h → IndexAdditiveQuantizer.h} +101 -58
  12. data/vendor/faiss/faiss/IndexFlat.cpp +22 -52
  13. data/vendor/faiss/faiss/IndexFlat.h +9 -15
  14. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  15. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  16. data/vendor/faiss/faiss/IndexIVF.cpp +79 -7
  17. data/vendor/faiss/faiss/IndexIVF.h +25 -7
  18. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  19. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  20. data/vendor/faiss/faiss/IndexIVFFlat.cpp +9 -12
  21. data/vendor/faiss/faiss/IndexIVFPQ.cpp +5 -4
  22. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  23. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +60 -39
  24. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +21 -6
  25. data/vendor/faiss/faiss/IndexLSH.cpp +4 -30
  26. data/vendor/faiss/faiss/IndexLSH.h +2 -15
  27. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -2
  28. data/vendor/faiss/faiss/IndexNSG.cpp +0 -2
  29. data/vendor/faiss/faiss/IndexPQ.cpp +2 -51
  30. data/vendor/faiss/faiss/IndexPQ.h +2 -17
  31. data/vendor/faiss/faiss/IndexRefine.cpp +28 -0
  32. data/vendor/faiss/faiss/IndexRefine.h +10 -0
  33. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -28
  34. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -16
  35. data/vendor/faiss/faiss/VectorTransform.cpp +2 -1
  36. data/vendor/faiss/faiss/VectorTransform.h +3 -0
  37. data/vendor/faiss/faiss/clone_index.cpp +3 -2
  38. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -2
  39. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  40. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +257 -24
  41. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +69 -9
  42. data/vendor/faiss/faiss/impl/HNSW.cpp +10 -5
  43. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +393 -210
  44. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +100 -28
  45. data/vendor/faiss/faiss/impl/NSG.cpp +0 -3
  46. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  47. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +357 -47
  48. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +65 -7
  49. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +12 -19
  50. data/vendor/faiss/faiss/impl/index_read.cpp +102 -19
  51. data/vendor/faiss/faiss/impl/index_write.cpp +66 -16
  52. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  53. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  54. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  55. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  56. data/vendor/faiss/faiss/index_factory.cpp +585 -414
  57. data/vendor/faiss/faiss/index_factory.h +3 -0
  58. data/vendor/faiss/faiss/utils/distances.cpp +4 -2
  59. data/vendor/faiss/faiss/utils/distances.h +36 -3
  60. data/vendor/faiss/faiss/utils/distances_simd.cpp +50 -0
  61. data/vendor/faiss/faiss/utils/utils.h +1 -1
  62. metadata +12 -5
  63. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
@@ -19,6 +19,9 @@ Index* index_factory(
19
19
  const char* description,
20
20
  MetricType metric = METRIC_L2);
21
21
 
22
+ /// set to > 0 to get more logs from index_factory
23
+ FAISS_API extern int index_factory_verbose;
24
+
22
25
  IndexBinary* index_binary_factory(int d, const char* description);
23
26
 
24
27
  } // namespace faiss
@@ -105,8 +105,9 @@ void exhaustive_inner_product_seq(
105
105
  size_t ny,
106
106
  ResultHandler& res) {
107
107
  using SingleResultHandler = typename ResultHandler::SingleResultHandler;
108
+ int nt = std::min(int(nx), omp_get_max_threads());
108
109
 
109
- #pragma omp parallel
110
+ #pragma omp parallel num_threads(nt)
110
111
  {
111
112
  SingleResultHandler resi(res);
112
113
  #pragma omp for
@@ -135,8 +136,9 @@ void exhaustive_L2sqr_seq(
135
136
  size_t ny,
136
137
  ResultHandler& res) {
137
138
  using SingleResultHandler = typename ResultHandler::SingleResultHandler;
139
+ int nt = std::min(int(nx), omp_get_max_threads());
138
140
 
139
- #pragma omp parallel
141
+ #pragma omp parallel num_threads(nt)
140
142
  {
141
143
  SingleResultHandler resi(res);
142
144
  #pragma omp for
@@ -40,7 +40,7 @@ float fvec_Linf(const float* x, const float* y, size_t d);
40
40
  * @param nq nb of query vectors
41
41
  * @param nb nb of database vectors
42
42
  * @param xq query vectors (size nq * d)
43
- * @param xb database vectros (size nb * d)
43
+ * @param xb database vectors (size nb * d)
44
44
  * @param dis output distances (size nq * nb)
45
45
  * @param ldq,ldb, ldd strides for the matrices
46
46
  */
@@ -63,7 +63,7 @@ void fvec_inner_products_ny(
63
63
  size_t d,
64
64
  size_t ny);
65
65
 
66
- /* compute ny square L2 distance bewteen x and a set of contiguous y vectors */
66
+ /* compute ny square L2 distance between x and a set of contiguous y vectors */
67
67
  void fvec_L2sqr_ny(
68
68
  float* dis,
69
69
  const float* x,
@@ -87,7 +87,7 @@ void fvec_norms_L2sqr(float* norms, const float* x, size_t d, size_t nx);
87
87
  /* L2-renormalize a set of vector. Nothing done if the vector is 0-normed */
88
88
  void fvec_renorm_L2(size_t d, size_t nx, float* x);
89
89
 
90
- /* This function exists because the Torch counterpart is extremly slow
90
+ /* This function exists because the Torch counterpart is extremely slow
91
91
  (not multi-threaded + unexpected overhead even in single thread).
92
92
  It is here to implement the usual property |x-y|^2=|x|^2+|y|^2-2<x|y> */
93
93
  void inner_product_to_L2sqr(
@@ -97,6 +97,39 @@ void inner_product_to_L2sqr(
97
97
  size_t n1,
98
98
  size_t n2);
99
99
 
100
+ /*********************************************************
101
+ * Vector to vector functions
102
+ *********************************************************/
103
+
104
+ /** compute c := a + b for vectors
105
+ *
106
+ * c and a can overlap, c and b can overlap
107
+ *
108
+ * @param a size d
109
+ * @param b size d
110
+ * @param c size d
111
+ */
112
+ void fvec_add(size_t d, const float* a, const float* b, float* c);
113
+
114
+ /** compute c := a + b for a, c vectors and b a scalar
115
+ *
116
+ * c and a can overlap
117
+ *
118
+ * @param a size d
119
+ * @param c size d
120
+ */
121
+ void fvec_add(size_t d, const float* a, float b, float* c);
122
+
123
+ /** compute c := a - b for vectors
124
+ *
125
+ * c and a can overlap, c and b can overlap
126
+ *
127
+ * @param a size d
128
+ * @param b size d
129
+ * @param c size d
130
+ */
131
+ void fvec_sub(size_t d, const float* a, const float* b, float* c);
132
+
100
133
  /***************************************************************************
101
134
  * Compute a subset of distances
102
135
  ***************************************************************************/
@@ -9,6 +9,7 @@
9
9
 
10
10
  #include <faiss/utils/distances.h>
11
11
 
12
+ #include <algorithm>
12
13
  #include <cassert>
13
14
  #include <cmath>
14
15
  #include <cstdio>
@@ -973,4 +974,53 @@ void compute_PQ_dis_tables_dsub2(
973
974
  }
974
975
  }
975
976
 
977
+ /*********************************************************
978
+ * Vector to vector functions
979
+ *********************************************************/
980
+
981
+ void fvec_sub(size_t d, const float* a, const float* b, float* c) {
982
+ size_t i;
983
+ for (i = 0; i + 7 < d; i += 8) {
984
+ simd8float32 ci, ai, bi;
985
+ ai.loadu(a + i);
986
+ bi.loadu(b + i);
987
+ ci = ai - bi;
988
+ ci.storeu(c + i);
989
+ }
990
+ // finish non-multiple of 8 remainder
991
+ for (; i < d; i++) {
992
+ c[i] = a[i] - b[i];
993
+ }
994
+ }
995
+
996
+ void fvec_add(size_t d, const float* a, const float* b, float* c) {
997
+ size_t i;
998
+ for (i = 0; i + 7 < d; i += 8) {
999
+ simd8float32 ci, ai, bi;
1000
+ ai.loadu(a + i);
1001
+ bi.loadu(b + i);
1002
+ ci = ai + bi;
1003
+ ci.storeu(c + i);
1004
+ }
1005
+ // finish non-multiple of 8 remainder
1006
+ for (; i < d; i++) {
1007
+ c[i] = a[i] + b[i];
1008
+ }
1009
+ }
1010
+
1011
+ void fvec_add(size_t d, const float* a, float b, float* c) {
1012
+ size_t i;
1013
+ simd8float32 bv(b);
1014
+ for (i = 0; i + 7 < d; i += 8) {
1015
+ simd8float32 ci, ai, bi;
1016
+ ai.loadu(a + i);
1017
+ ci = ai + bv;
1018
+ ci.storeu(c + i);
1019
+ }
1020
+ // finish non-multiple of 8 remainder
1021
+ for (; i < d; i++) {
1022
+ c[i] = a[i] + b;
1023
+ }
1024
+ }
1025
+
976
1026
  } // namespace faiss
@@ -80,7 +80,7 @@ void matrix_qr(int m, int n, float* a);
80
80
  /** distances are supposed to be sorted. Sorts indices with same distance*/
81
81
  void ranklist_handle_ties(int k, int64_t* idx, const float* dis);
82
82
 
83
- /** count the number of comon elements between v1 and v2
83
+ /** count the number of common elements between v1 and v2
84
84
  * algorithm = sorting + bissection to avoid double-counting duplicates
85
85
  */
86
86
  size_t ranklist_intersection_size(
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: faiss
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.3
4
+ version: 0.2.4
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2021-12-17 00:00:00.000000000 Z
11
+ date: 2022-01-10 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -71,6 +71,8 @@ files:
71
71
  - vendor/faiss/faiss/Index.h
72
72
  - vendor/faiss/faiss/Index2Layer.cpp
73
73
  - vendor/faiss/faiss/Index2Layer.h
74
+ - vendor/faiss/faiss/IndexAdditiveQuantizer.cpp
75
+ - vendor/faiss/faiss/IndexAdditiveQuantizer.h
74
76
  - vendor/faiss/faiss/IndexBinary.cpp
75
77
  - vendor/faiss/faiss/IndexBinary.h
76
78
  - vendor/faiss/faiss/IndexBinaryFlat.cpp
@@ -85,10 +87,14 @@ files:
85
87
  - vendor/faiss/faiss/IndexBinaryIVF.h
86
88
  - vendor/faiss/faiss/IndexFlat.cpp
87
89
  - vendor/faiss/faiss/IndexFlat.h
90
+ - vendor/faiss/faiss/IndexFlatCodes.cpp
91
+ - vendor/faiss/faiss/IndexFlatCodes.h
88
92
  - vendor/faiss/faiss/IndexHNSW.cpp
89
93
  - vendor/faiss/faiss/IndexHNSW.h
90
94
  - vendor/faiss/faiss/IndexIVF.cpp
91
95
  - vendor/faiss/faiss/IndexIVF.h
96
+ - vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp
97
+ - vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h
92
98
  - vendor/faiss/faiss/IndexIVFFlat.cpp
93
99
  - vendor/faiss/faiss/IndexIVFFlat.h
94
100
  - vendor/faiss/faiss/IndexIVFPQ.cpp
@@ -117,8 +123,6 @@ files:
117
123
  - vendor/faiss/faiss/IndexRefine.h
118
124
  - vendor/faiss/faiss/IndexReplicas.cpp
119
125
  - vendor/faiss/faiss/IndexReplicas.h
120
- - vendor/faiss/faiss/IndexResidual.cpp
121
- - vendor/faiss/faiss/IndexResidual.h
122
126
  - vendor/faiss/faiss/IndexScalarQuantizer.cpp
123
127
  - vendor/faiss/faiss/IndexScalarQuantizer.h
124
128
  - vendor/faiss/faiss/IndexShards.cpp
@@ -140,6 +144,7 @@ files:
140
144
  - vendor/faiss/faiss/gpu/GpuClonerOptions.h
141
145
  - vendor/faiss/faiss/gpu/GpuDistance.h
142
146
  - vendor/faiss/faiss/gpu/GpuFaissAssert.h
147
+ - vendor/faiss/faiss/gpu/GpuIcmEncoder.h
143
148
  - vendor/faiss/faiss/gpu/GpuIndex.h
144
149
  - vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h
145
150
  - vendor/faiss/faiss/gpu/GpuIndexFlat.h
@@ -209,6 +214,8 @@ files:
209
214
  - vendor/faiss/faiss/impl/io.cpp
210
215
  - vendor/faiss/faiss/impl/io.h
211
216
  - vendor/faiss/faiss/impl/io_macros.h
217
+ - vendor/faiss/faiss/impl/kmeans1d.cpp
218
+ - vendor/faiss/faiss/impl/kmeans1d.h
212
219
  - vendor/faiss/faiss/impl/lattice_Zn.cpp
213
220
  - vendor/faiss/faiss/impl/lattice_Zn.h
214
221
  - vendor/faiss/faiss/impl/platform_macros.h
@@ -278,7 +285,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
278
285
  - !ruby/object:Gem::Version
279
286
  version: '0'
280
287
  requirements: []
281
- rubygems_version: 3.2.32
288
+ rubygems_version: 3.3.3
282
289
  signing_key:
283
290
  specification_version: 4
284
291
  summary: Efficient similarity search and clustering for Ruby
@@ -1,291 +0,0 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
3
- *
4
- * This source code is licensed under the MIT license found in the
5
- * LICENSE file in the root directory of this source tree.
6
- */
7
-
8
- #include <faiss/IndexResidual.h>
9
-
10
- #include <algorithm>
11
- #include <cmath>
12
- #include <cstring>
13
-
14
- #include <faiss/impl/FaissAssert.h>
15
- #include <faiss/impl/ResultHandler.h>
16
- #include <faiss/utils/distances.h>
17
- #include <faiss/utils/extra_distances.h>
18
- #include <faiss/utils/utils.h>
19
-
20
- namespace faiss {
21
-
22
- /**************************************************************************************
23
- * IndexResidual
24
- **************************************************************************************/
25
-
26
- IndexResidual::IndexResidual(
27
- int d, ///< dimensionality of the input vectors
28
- size_t M, ///< number of subquantizers
29
- size_t nbits, ///< number of bit per subvector index
30
- MetricType metric,
31
- Search_type_t search_type_in)
32
- : Index(d, metric), rq(d, M, nbits), search_type(ST_decompress) {
33
- is_trained = false;
34
- norm_max = norm_min = NAN;
35
- set_search_type(search_type_in);
36
- }
37
-
38
- IndexResidual::IndexResidual(
39
- int d,
40
- const std::vector<size_t>& nbits,
41
- MetricType metric,
42
- Search_type_t search_type_in)
43
- : Index(d, metric), rq(d, nbits), search_type(ST_decompress) {
44
- is_trained = false;
45
- norm_max = norm_min = NAN;
46
- set_search_type(search_type_in);
47
- }
48
-
49
- IndexResidual::IndexResidual() : IndexResidual(0, 0, 0) {}
50
-
51
- void IndexResidual::set_search_type(Search_type_t new_search_type) {
52
- int norm_bits = new_search_type == ST_norm_float ? 32
53
- : new_search_type == ST_norm_qint8 ? 8
54
- : 0;
55
-
56
- FAISS_THROW_IF_NOT(ntotal == 0);
57
-
58
- search_type = new_search_type;
59
- code_size = (rq.tot_bits + norm_bits + 7) / 8;
60
- }
61
-
62
- void IndexResidual::train(idx_t n, const float* x) {
63
- rq.train(n, x);
64
-
65
- std::vector<float> norms(n);
66
- fvec_norms_L2sqr(norms.data(), x, d, n);
67
-
68
- norm_min = HUGE_VALF;
69
- norm_max = -HUGE_VALF;
70
- for (idx_t i = 0; i < n; i++) {
71
- if (norms[i] < norm_min) {
72
- norm_min = norms[i];
73
- }
74
- if (norms[i] > norm_min) {
75
- norm_max = norms[i];
76
- }
77
- }
78
-
79
- is_trained = true;
80
- }
81
-
82
- void IndexResidual::add(idx_t n, const float* x) {
83
- FAISS_THROW_IF_NOT(is_trained);
84
- codes.resize((n + ntotal) * rq.code_size);
85
- if (search_type == ST_decompress || search_type == ST_LUT_nonorm) {
86
- rq.compute_codes(x, &codes[ntotal * rq.code_size], n);
87
- } else {
88
- // should compute codes + compute and quantize norms
89
- FAISS_THROW_MSG("not implemented");
90
- }
91
- ntotal += n;
92
- }
93
-
94
- namespace {
95
-
96
- template <class VectorDistance, class ResultHandler>
97
- void search_with_decompress(
98
- const IndexResidual& ir,
99
- const float* xq,
100
- VectorDistance& vd,
101
- ResultHandler& res) {
102
- const uint8_t* codes = ir.codes.data();
103
- size_t ntotal = ir.ntotal;
104
- size_t code_size = ir.code_size;
105
-
106
- using SingleResultHandler = typename ResultHandler::SingleResultHandler;
107
-
108
- #pragma omp parallel for
109
- for (int64_t q = 0; q < res.nq; q++) {
110
- SingleResultHandler resi(res);
111
- resi.begin(q);
112
- std::vector<float> tmp(ir.d);
113
- const float* x = xq + ir.d * q;
114
- for (size_t i = 0; i < ntotal; i++) {
115
- ir.rq.decode(codes + i * code_size, tmp.data(), 1);
116
- float dis = vd(x, tmp.data());
117
- resi.add_result(dis, i);
118
- }
119
- resi.end();
120
- }
121
- }
122
-
123
- } // anonymous namespace
124
-
125
- void IndexResidual::search(
126
- idx_t n,
127
- const float* x,
128
- idx_t k,
129
- float* distances,
130
- idx_t* labels) const {
131
- if (search_type == ST_decompress) {
132
- if (metric_type == METRIC_L2) {
133
- using VD = VectorDistance<METRIC_L2>;
134
- VD vd = {size_t(d), metric_arg};
135
- HeapResultHandler<VD::C> rh(n, distances, labels, k);
136
- search_with_decompress(*this, x, vd, rh);
137
- } else if (metric_type == METRIC_INNER_PRODUCT) {
138
- using VD = VectorDistance<METRIC_INNER_PRODUCT>;
139
- VD vd = {size_t(d), metric_arg};
140
- HeapResultHandler<VD::C> rh(n, distances, labels, k);
141
- search_with_decompress(*this, x, vd, rh);
142
- }
143
- } else {
144
- FAISS_THROW_MSG("not implemented");
145
- }
146
- }
147
-
148
- void IndexResidual::reset() {
149
- codes.clear();
150
- ntotal = 0;
151
- }
152
-
153
- size_t IndexResidual::sa_code_size() const {
154
- return code_size;
155
- }
156
-
157
- void IndexResidual::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
158
- return rq.compute_codes(x, bytes, n);
159
- }
160
-
161
- void IndexResidual::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
162
- return rq.decode(bytes, x, n);
163
- }
164
-
165
- /**************************************************************************************
166
- * ResidualCoarseQuantizer
167
- **************************************************************************************/
168
-
169
- ResidualCoarseQuantizer::ResidualCoarseQuantizer(
170
- int d, ///< dimensionality of the input vectors
171
- size_t M, ///< number of subquantizers
172
- size_t nbits, ///< number of bit per subvector index
173
- MetricType metric)
174
- : Index(d, metric), rq(d, M, nbits), beam_factor(4.0) {
175
- FAISS_THROW_IF_NOT(rq.tot_bits <= 63);
176
- is_trained = false;
177
- }
178
-
179
- ResidualCoarseQuantizer::ResidualCoarseQuantizer(
180
- int d,
181
- const std::vector<size_t>& nbits,
182
- MetricType metric)
183
- : Index(d, metric), rq(d, nbits), beam_factor(4.0) {
184
- FAISS_THROW_IF_NOT(rq.tot_bits <= 63);
185
- is_trained = false;
186
- }
187
-
188
- ResidualCoarseQuantizer::ResidualCoarseQuantizer() {}
189
-
190
- void ResidualCoarseQuantizer::train(idx_t n, const float* x) {
191
- rq.train(n, x);
192
- is_trained = true;
193
- ntotal = (idx_t)1 << rq.tot_bits;
194
- }
195
-
196
- void ResidualCoarseQuantizer::add(idx_t, const float*) {
197
- FAISS_THROW_MSG("not applicable");
198
- }
199
-
200
- void ResidualCoarseQuantizer::set_beam_factor(float new_beam_factor) {
201
- centroid_norms.resize(0);
202
- beam_factor = new_beam_factor;
203
- if (new_beam_factor > 0) {
204
- FAISS_THROW_IF_NOT(new_beam_factor >= 1.0);
205
- return;
206
- }
207
-
208
- if (metric_type == METRIC_L2) {
209
- centroid_norms.resize((size_t)1 << rq.tot_bits);
210
- rq.compute_centroid_norms(centroid_norms.data());
211
- }
212
- }
213
-
214
- void ResidualCoarseQuantizer::search(
215
- idx_t n,
216
- const float* x,
217
- idx_t k,
218
- float* distances,
219
- idx_t* labels) const {
220
- if (beam_factor < 0) {
221
- if (metric_type == METRIC_INNER_PRODUCT) {
222
- rq.knn_exact_inner_product(n, x, k, distances, labels);
223
- } else if (metric_type == METRIC_L2) {
224
- FAISS_THROW_IF_NOT(centroid_norms.size() == ntotal);
225
- rq.knn_exact_L2(n, x, k, distances, labels, centroid_norms.data());
226
- }
227
- return;
228
- }
229
-
230
- int beam_size = int(k * beam_factor);
231
-
232
- size_t memory_per_point = rq.memory_per_point(beam_size);
233
-
234
- /*
235
-
236
- printf("mem per point %ld n=%d max_mem_distance=%ld mem_kb=%zd\n",
237
- memory_per_point, int(n), rq.max_mem_distances, get_mem_usage_kb());
238
- */
239
- if (n > 1 && memory_per_point * n > rq.max_mem_distances) {
240
- // then split queries to reduce temp memory
241
- idx_t bs = rq.max_mem_distances / memory_per_point;
242
- if (bs == 0) {
243
- bs = 1; // otherwise we can't do much
244
- }
245
- if (verbose) {
246
- printf("ResidualCoarseQuantizer::search: run %d searches in batches of size %d\n",
247
- int(n),
248
- int(bs));
249
- }
250
- for (idx_t i0 = 0; i0 < n; i0 += bs) {
251
- idx_t i1 = std::min(n, i0 + bs);
252
- search(i1 - i0, x + i0 * d, k, distances + i0 * k, labels + i0 * k);
253
- InterruptCallback::check();
254
- }
255
- return;
256
- }
257
-
258
- std::vector<int32_t> codes(beam_size * rq.M * n);
259
- std::vector<float> beam_distances(n * beam_size);
260
-
261
- rq.refine_beam(
262
- n, 1, x, beam_size, codes.data(), nullptr, beam_distances.data());
263
-
264
- #pragma omp parallel for if (n > 4000)
265
- for (idx_t i = 0; i < n; i++) {
266
- memcpy(distances + i * k,
267
- beam_distances.data() + beam_size * i,
268
- k * sizeof(distances[0]));
269
-
270
- const int32_t* codes_i = codes.data() + beam_size * i * rq.M;
271
- for (idx_t j = 0; j < k; j++) {
272
- idx_t l = 0;
273
- int shift = 0;
274
- for (int m = 0; m < rq.M; m++) {
275
- l |= (*codes_i++) << shift;
276
- shift += rq.nbits[m];
277
- }
278
- labels[i * k + j] = l;
279
- }
280
- }
281
- }
282
-
283
- void ResidualCoarseQuantizer::reconstruct(idx_t key, float* recons) const {
284
- rq.decode_64bit(key, recons);
285
- }
286
-
287
- void ResidualCoarseQuantizer::reset() {
288
- FAISS_THROW_MSG("not applicable");
289
- }
290
-
291
- } // namespace faiss