faiss 0.2.3 → 0.2.4

Files changed (63)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  6. data/vendor/faiss/faiss/Clustering.h +14 -0
  7. data/vendor/faiss/faiss/Index.h +1 -1
  8. data/vendor/faiss/faiss/Index2Layer.cpp +19 -92
  9. data/vendor/faiss/faiss/Index2Layer.h +2 -16
  10. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  11. data/vendor/faiss/faiss/{IndexResidual.h → IndexAdditiveQuantizer.h} +101 -58
  12. data/vendor/faiss/faiss/IndexFlat.cpp +22 -52
  13. data/vendor/faiss/faiss/IndexFlat.h +9 -15
  14. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  15. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  16. data/vendor/faiss/faiss/IndexIVF.cpp +79 -7
  17. data/vendor/faiss/faiss/IndexIVF.h +25 -7
  18. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  19. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  20. data/vendor/faiss/faiss/IndexIVFFlat.cpp +9 -12
  21. data/vendor/faiss/faiss/IndexIVFPQ.cpp +5 -4
  22. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  23. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +60 -39
  24. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +21 -6
  25. data/vendor/faiss/faiss/IndexLSH.cpp +4 -30
  26. data/vendor/faiss/faiss/IndexLSH.h +2 -15
  27. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -2
  28. data/vendor/faiss/faiss/IndexNSG.cpp +0 -2
  29. data/vendor/faiss/faiss/IndexPQ.cpp +2 -51
  30. data/vendor/faiss/faiss/IndexPQ.h +2 -17
  31. data/vendor/faiss/faiss/IndexRefine.cpp +28 -0
  32. data/vendor/faiss/faiss/IndexRefine.h +10 -0
  33. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -28
  34. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -16
  35. data/vendor/faiss/faiss/VectorTransform.cpp +2 -1
  36. data/vendor/faiss/faiss/VectorTransform.h +3 -0
  37. data/vendor/faiss/faiss/clone_index.cpp +3 -2
  38. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -2
  39. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  40. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +257 -24
  41. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +69 -9
  42. data/vendor/faiss/faiss/impl/HNSW.cpp +10 -5
  43. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +393 -210
  44. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +100 -28
  45. data/vendor/faiss/faiss/impl/NSG.cpp +0 -3
  46. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  47. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +357 -47
  48. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +65 -7
  49. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +12 -19
  50. data/vendor/faiss/faiss/impl/index_read.cpp +102 -19
  51. data/vendor/faiss/faiss/impl/index_write.cpp +66 -16
  52. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  53. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  54. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  55. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  56. data/vendor/faiss/faiss/index_factory.cpp +585 -414
  57. data/vendor/faiss/faiss/index_factory.h +3 -0
  58. data/vendor/faiss/faiss/utils/distances.cpp +4 -2
  59. data/vendor/faiss/faiss/utils/distances.h +36 -3
  60. data/vendor/faiss/faiss/utils/distances_simd.cpp +50 -0
  61. data/vendor/faiss/faiss/utils/utils.h +1 -1
  62. metadata +12 -5
  63. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
data/vendor/faiss/faiss/index_factory.h CHANGED
@@ -19,6 +19,9 @@ Index* index_factory(
  const char* description,
  MetricType metric = METRIC_L2);

+ /// set to > 0 to get more logs from index_factory
+ FAISS_API extern int index_factory_verbose;
+
  IndexBinary* index_binary_factory(int d, const char* description);

  } // namespace faiss
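The new verbosity flag is a plain global declared above; a minimal sketch of how it could be used from C++ (the dimension and the "IVF256,PQ16" description string are arbitrary examples, not taken from this diff):

#include <faiss/Index.h>
#include <faiss/index_factory.h>

#include <memory>

int main() {
    // A value > 0 makes index_factory print how it parses the description
    // string (per the declaration added in this release).
    faiss::index_factory_verbose = 1;

    // Hypothetical example index: 128-dimensional IVF+PQ with the L2 metric.
    std::unique_ptr<faiss::Index> index(
            faiss::index_factory(128, "IVF256,PQ16", faiss::METRIC_L2));
    return 0;
}
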
data/vendor/faiss/faiss/utils/distances.cpp CHANGED
@@ -105,8 +105,9 @@ void exhaustive_inner_product_seq(
  size_t ny,
  ResultHandler& res) {
  using SingleResultHandler = typename ResultHandler::SingleResultHandler;
+ int nt = std::min(int(nx), omp_get_max_threads());

- #pragma omp parallel
+ #pragma omp parallel num_threads(nt)
  {
  SingleResultHandler resi(res);
  #pragma omp for
@@ -135,8 +136,9 @@ void exhaustive_L2sqr_seq(
  size_t ny,
  ResultHandler& res) {
  using SingleResultHandler = typename ResultHandler::SingleResultHandler;
+ int nt = std::min(int(nx), omp_get_max_threads());

- #pragma omp parallel
+ #pragma omp parallel num_threads(nt)
  {
  SingleResultHandler resi(res);
  #pragma omp for
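The added num_threads(nt) clause caps the OpenMP team size at the number of queries nx, so small batches do not spawn threads that would immediately idle at the omp for. A standalone sketch of the same pattern under the assumption nx > 0 (the function and variable names here are illustrative, not faiss code):

#include <omp.h>

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Run one parallel region with at most nx threads, mirroring the
// num_threads(nt) clause added to exhaustive_inner_product_seq and
// exhaustive_L2sqr_seq. Assumes nx > 0 (num_threads must be positive).
void process_queries(size_t nx) {
    int nt = std::min(int(nx), omp_get_max_threads());
#pragma omp parallel num_threads(nt)
    {
#pragma omp for
        for (long q = 0; q < long(nx); q++) {
            // per-query work would go here
        }
    }
    std::printf("processed %zu queries with at most %d threads\n", nx, nt);
}
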
data/vendor/faiss/faiss/utils/distances.h CHANGED
@@ -40,7 +40,7 @@ float fvec_Linf(const float* x, const float* y, size_t d);
  * @param nq nb of query vectors
  * @param nb nb of database vectors
  * @param xq query vectors (size nq * d)
- * @param xb database vectros (size nb * d)
+ * @param xb database vectors (size nb * d)
  * @param dis output distances (size nq * nb)
  * @param ldq,ldb, ldd strides for the matrices
  */
@@ -63,7 +63,7 @@ void fvec_inner_products_ny(
  size_t d,
  size_t ny);

- /* compute ny square L2 distance bewteen x and a set of contiguous y vectors */
+ /* compute ny square L2 distance between x and a set of contiguous y vectors */
  void fvec_L2sqr_ny(
  float* dis,
  const float* x,
@@ -87,7 +87,7 @@ void fvec_norms_L2sqr(float* norms, const float* x, size_t d, size_t nx);
  /* L2-renormalize a set of vector. Nothing done if the vector is 0-normed */
  void fvec_renorm_L2(size_t d, size_t nx, float* x);

- /* This function exists because the Torch counterpart is extremly slow
+ /* This function exists because the Torch counterpart is extremely slow
  (not multi-threaded + unexpected overhead even in single thread).
  It is here to implement the usual property |x-y|^2=|x|^2+|y|^2-2<x|y> */
  void inner_product_to_L2sqr(
@@ -97,6 +97,39 @@ void inner_product_to_L2sqr(
  size_t n1,
  size_t n2);

+ /*********************************************************
+ * Vector to vector functions
+ *********************************************************/
+
+ /** compute c := a + b for vectors
+ *
+ * c and a can overlap, c and b can overlap
+ *
+ * @param a size d
+ * @param b size d
+ * @param c size d
+ */
+ void fvec_add(size_t d, const float* a, const float* b, float* c);
+
+ /** compute c := a + b for a, c vectors and b a scalar
+ *
+ * c and a can overlap
+ *
+ * @param a size d
+ * @param c size d
+ */
+ void fvec_add(size_t d, const float* a, float b, float* c);
+
+ /** compute c := a - b for vectors
+ *
+ * c and a can overlap, c and b can overlap
+ *
+ * @param a size d
+ * @param b size d
+ * @param c size d
+ */
+ void fvec_sub(size_t d, const float* a, const float* b, float* c);
+
  /***************************************************************************
  * Compute a subset of distances
  ***************************************************************************/
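A minimal sketch of calling the element-wise helpers declared above; the vector length and values are arbitrary, and d is deliberately not a multiple of 8 so the scalar remainder path in the implementation is exercised as well:

#include <faiss/utils/distances.h>

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    const size_t d = 10;
    std::vector<float> a(d, 1.0f), b(d, 2.0f), c(d);

    faiss::fvec_add(d, a.data(), b.data(), c.data()); // c := a + b        -> 3.0
    faiss::fvec_sub(d, c.data(), b.data(), c.data()); // in place: c -= b  -> 1.0
    faiss::fvec_add(d, c.data(), 0.5f, c.data());     // scalar overload   -> 1.5

    std::printf("c[0] = %.1f, c[%zu] = %.1f\n", c[0], d - 1, c[d - 1]);
    return 0;
}

The in-place calls rely on the documented guarantee that c may overlap a (and b for the two-vector variants).
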
data/vendor/faiss/faiss/utils/distances_simd.cpp CHANGED
@@ -9,6 +9,7 @@

  #include <faiss/utils/distances.h>

+ #include <algorithm>
  #include <cassert>
  #include <cmath>
  #include <cstdio>
@@ -973,4 +974,53 @@ void compute_PQ_dis_tables_dsub2(
  }
  }

+ /*********************************************************
+ * Vector to vector functions
+ *********************************************************/
+
+ void fvec_sub(size_t d, const float* a, const float* b, float* c) {
+ size_t i;
+ for (i = 0; i + 7 < d; i += 8) {
+ simd8float32 ci, ai, bi;
+ ai.loadu(a + i);
+ bi.loadu(b + i);
+ ci = ai - bi;
+ ci.storeu(c + i);
+ }
+ // finish non-multiple of 8 remainder
+ for (; i < d; i++) {
+ c[i] = a[i] - b[i];
+ }
+ }
+
+ void fvec_add(size_t d, const float* a, const float* b, float* c) {
+ size_t i;
+ for (i = 0; i + 7 < d; i += 8) {
+ simd8float32 ci, ai, bi;
+ ai.loadu(a + i);
+ bi.loadu(b + i);
+ ci = ai + bi;
+ ci.storeu(c + i);
+ }
+ // finish non-multiple of 8 remainder
+ for (; i < d; i++) {
+ c[i] = a[i] + b[i];
+ }
+ }
+
+ void fvec_add(size_t d, const float* a, float b, float* c) {
+ size_t i;
+ simd8float32 bv(b);
+ for (i = 0; i + 7 < d; i += 8) {
+ simd8float32 ci, ai, bi;
+ ai.loadu(a + i);
+ ci = ai + bv;
+ ci.storeu(c + i);
+ }
+ // finish non-multiple of 8 remainder
+ for (; i < d; i++) {
+ c[i] = a[i] + b;
+ }
+ }
+
  } // namespace faiss
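simd8float32 is faiss's portable 8-lane float32 wrapper, so the main loop handles 8 elements per iteration and the trailing loop covers the remainder when d is not a multiple of 8. For reference, the two-vector fvec_add above is behaviorally equivalent to this plain scalar version (the _reference name is only for illustration):

#include <cstddef>

// Scalar reference for fvec_add(d, a, b, c): c[i] = a[i] + b[i] for all i.
void fvec_add_reference(size_t d, const float* a, const float* b, float* c) {
    for (size_t i = 0; i < d; i++) {
        c[i] = a[i] + b[i];
    }
}
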
data/vendor/faiss/faiss/utils/utils.h CHANGED
@@ -80,7 +80,7 @@ void matrix_qr(int m, int n, float* a);
  /** distances are supposed to be sorted. Sorts indices with same distance*/
  void ranklist_handle_ties(int k, int64_t* idx, const float* dis);

- /** count the number of comon elements between v1 and v2
+ /** count the number of common elements between v1 and v2
  * algorithm = sorting + bissection to avoid double-counting duplicates
  */
  size_t ranklist_intersection_size(
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: faiss
  version: !ruby/object:Gem::Version
- version: 0.2.3
+ version: 0.2.4
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2021-12-17 00:00:00.000000000 Z
+ date: 2022-01-10 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
  name: rice
@@ -71,6 +71,8 @@ files:
  - vendor/faiss/faiss/Index.h
  - vendor/faiss/faiss/Index2Layer.cpp
  - vendor/faiss/faiss/Index2Layer.h
+ - vendor/faiss/faiss/IndexAdditiveQuantizer.cpp
+ - vendor/faiss/faiss/IndexAdditiveQuantizer.h
  - vendor/faiss/faiss/IndexBinary.cpp
  - vendor/faiss/faiss/IndexBinary.h
  - vendor/faiss/faiss/IndexBinaryFlat.cpp
@@ -85,10 +87,14 @@ files:
  - vendor/faiss/faiss/IndexBinaryIVF.h
  - vendor/faiss/faiss/IndexFlat.cpp
  - vendor/faiss/faiss/IndexFlat.h
+ - vendor/faiss/faiss/IndexFlatCodes.cpp
+ - vendor/faiss/faiss/IndexFlatCodes.h
  - vendor/faiss/faiss/IndexHNSW.cpp
  - vendor/faiss/faiss/IndexHNSW.h
  - vendor/faiss/faiss/IndexIVF.cpp
  - vendor/faiss/faiss/IndexIVF.h
+ - vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp
+ - vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h
  - vendor/faiss/faiss/IndexIVFFlat.cpp
  - vendor/faiss/faiss/IndexIVFFlat.h
  - vendor/faiss/faiss/IndexIVFPQ.cpp
@@ -117,8 +123,6 @@ files:
  - vendor/faiss/faiss/IndexRefine.h
  - vendor/faiss/faiss/IndexReplicas.cpp
  - vendor/faiss/faiss/IndexReplicas.h
- - vendor/faiss/faiss/IndexResidual.cpp
- - vendor/faiss/faiss/IndexResidual.h
  - vendor/faiss/faiss/IndexScalarQuantizer.cpp
  - vendor/faiss/faiss/IndexScalarQuantizer.h
  - vendor/faiss/faiss/IndexShards.cpp
@@ -140,6 +144,7 @@ files:
  - vendor/faiss/faiss/gpu/GpuClonerOptions.h
  - vendor/faiss/faiss/gpu/GpuDistance.h
  - vendor/faiss/faiss/gpu/GpuFaissAssert.h
+ - vendor/faiss/faiss/gpu/GpuIcmEncoder.h
  - vendor/faiss/faiss/gpu/GpuIndex.h
  - vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h
  - vendor/faiss/faiss/gpu/GpuIndexFlat.h
@@ -209,6 +214,8 @@ files:
  - vendor/faiss/faiss/impl/io.cpp
  - vendor/faiss/faiss/impl/io.h
  - vendor/faiss/faiss/impl/io_macros.h
+ - vendor/faiss/faiss/impl/kmeans1d.cpp
+ - vendor/faiss/faiss/impl/kmeans1d.h
  - vendor/faiss/faiss/impl/lattice_Zn.cpp
  - vendor/faiss/faiss/impl/lattice_Zn.h
  - vendor/faiss/faiss/impl/platform_macros.h
@@ -278,7 +285,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
  - !ruby/object:Gem::Version
  version: '0'
  requirements: []
- rubygems_version: 3.2.32
+ rubygems_version: 3.3.3
  signing_key:
  specification_version: 4
  summary: Efficient similarity search and clustering for Ruby
data/vendor/faiss/faiss/IndexResidual.cpp DELETED
@@ -1,291 +0,0 @@
- /**
- * Copyright (c) Facebook, Inc. and its affiliates.
- *
- * This source code is licensed under the MIT license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
- #include <faiss/IndexResidual.h>
-
- #include <algorithm>
- #include <cmath>
- #include <cstring>
-
- #include <faiss/impl/FaissAssert.h>
- #include <faiss/impl/ResultHandler.h>
- #include <faiss/utils/distances.h>
- #include <faiss/utils/extra_distances.h>
- #include <faiss/utils/utils.h>
-
- namespace faiss {
-
- /**************************************************************************************
- * IndexResidual
- **************************************************************************************/
-
- IndexResidual::IndexResidual(
- int d, ///< dimensionality of the input vectors
- size_t M, ///< number of subquantizers
- size_t nbits, ///< number of bit per subvector index
- MetricType metric,
- Search_type_t search_type_in)
- : Index(d, metric), rq(d, M, nbits), search_type(ST_decompress) {
- is_trained = false;
- norm_max = norm_min = NAN;
- set_search_type(search_type_in);
- }
-
- IndexResidual::IndexResidual(
- int d,
- const std::vector<size_t>& nbits,
- MetricType metric,
- Search_type_t search_type_in)
- : Index(d, metric), rq(d, nbits), search_type(ST_decompress) {
- is_trained = false;
- norm_max = norm_min = NAN;
- set_search_type(search_type_in);
- }
-
- IndexResidual::IndexResidual() : IndexResidual(0, 0, 0) {}
-
- void IndexResidual::set_search_type(Search_type_t new_search_type) {
- int norm_bits = new_search_type == ST_norm_float ? 32
- : new_search_type == ST_norm_qint8 ? 8
- : 0;
-
- FAISS_THROW_IF_NOT(ntotal == 0);
-
- search_type = new_search_type;
- code_size = (rq.tot_bits + norm_bits + 7) / 8;
- }
-
- void IndexResidual::train(idx_t n, const float* x) {
- rq.train(n, x);
-
- std::vector<float> norms(n);
- fvec_norms_L2sqr(norms.data(), x, d, n);
-
- norm_min = HUGE_VALF;
- norm_max = -HUGE_VALF;
- for (idx_t i = 0; i < n; i++) {
- if (norms[i] < norm_min) {
- norm_min = norms[i];
- }
- if (norms[i] > norm_min) {
- norm_max = norms[i];
- }
- }
-
- is_trained = true;
- }
-
- void IndexResidual::add(idx_t n, const float* x) {
- FAISS_THROW_IF_NOT(is_trained);
- codes.resize((n + ntotal) * rq.code_size);
- if (search_type == ST_decompress || search_type == ST_LUT_nonorm) {
- rq.compute_codes(x, &codes[ntotal * rq.code_size], n);
- } else {
- // should compute codes + compute and quantize norms
- FAISS_THROW_MSG("not implemented");
- }
- ntotal += n;
- }
-
- namespace {
-
- template <class VectorDistance, class ResultHandler>
- void search_with_decompress(
- const IndexResidual& ir,
- const float* xq,
- VectorDistance& vd,
- ResultHandler& res) {
- const uint8_t* codes = ir.codes.data();
- size_t ntotal = ir.ntotal;
- size_t code_size = ir.code_size;
-
- using SingleResultHandler = typename ResultHandler::SingleResultHandler;
-
- #pragma omp parallel for
- for (int64_t q = 0; q < res.nq; q++) {
- SingleResultHandler resi(res);
- resi.begin(q);
- std::vector<float> tmp(ir.d);
- const float* x = xq + ir.d * q;
- for (size_t i = 0; i < ntotal; i++) {
- ir.rq.decode(codes + i * code_size, tmp.data(), 1);
- float dis = vd(x, tmp.data());
- resi.add_result(dis, i);
- }
- resi.end();
- }
- }
-
- } // anonymous namespace
-
- void IndexResidual::search(
- idx_t n,
- const float* x,
- idx_t k,
- float* distances,
- idx_t* labels) const {
- if (search_type == ST_decompress) {
- if (metric_type == METRIC_L2) {
- using VD = VectorDistance<METRIC_L2>;
- VD vd = {size_t(d), metric_arg};
- HeapResultHandler<VD::C> rh(n, distances, labels, k);
- search_with_decompress(*this, x, vd, rh);
- } else if (metric_type == METRIC_INNER_PRODUCT) {
- using VD = VectorDistance<METRIC_INNER_PRODUCT>;
- VD vd = {size_t(d), metric_arg};
- HeapResultHandler<VD::C> rh(n, distances, labels, k);
- search_with_decompress(*this, x, vd, rh);
- }
- } else {
- FAISS_THROW_MSG("not implemented");
- }
- }
-
- void IndexResidual::reset() {
- codes.clear();
- ntotal = 0;
- }
-
- size_t IndexResidual::sa_code_size() const {
- return code_size;
- }
-
- void IndexResidual::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
- return rq.compute_codes(x, bytes, n);
- }
-
- void IndexResidual::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
- return rq.decode(bytes, x, n);
- }
-
- /**************************************************************************************
- * ResidualCoarseQuantizer
- **************************************************************************************/
-
- ResidualCoarseQuantizer::ResidualCoarseQuantizer(
- int d, ///< dimensionality of the input vectors
- size_t M, ///< number of subquantizers
- size_t nbits, ///< number of bit per subvector index
- MetricType metric)
- : Index(d, metric), rq(d, M, nbits), beam_factor(4.0) {
- FAISS_THROW_IF_NOT(rq.tot_bits <= 63);
- is_trained = false;
- }
-
- ResidualCoarseQuantizer::ResidualCoarseQuantizer(
- int d,
- const std::vector<size_t>& nbits,
- MetricType metric)
- : Index(d, metric), rq(d, nbits), beam_factor(4.0) {
- FAISS_THROW_IF_NOT(rq.tot_bits <= 63);
- is_trained = false;
- }
-
- ResidualCoarseQuantizer::ResidualCoarseQuantizer() {}
-
- void ResidualCoarseQuantizer::train(idx_t n, const float* x) {
- rq.train(n, x);
- is_trained = true;
- ntotal = (idx_t)1 << rq.tot_bits;
- }
-
- void ResidualCoarseQuantizer::add(idx_t, const float*) {
- FAISS_THROW_MSG("not applicable");
- }
-
- void ResidualCoarseQuantizer::set_beam_factor(float new_beam_factor) {
- centroid_norms.resize(0);
- beam_factor = new_beam_factor;
- if (new_beam_factor > 0) {
- FAISS_THROW_IF_NOT(new_beam_factor >= 1.0);
- return;
- }
-
- if (metric_type == METRIC_L2) {
- centroid_norms.resize((size_t)1 << rq.tot_bits);
- rq.compute_centroid_norms(centroid_norms.data());
- }
- }
-
- void ResidualCoarseQuantizer::search(
- idx_t n,
- const float* x,
- idx_t k,
- float* distances,
- idx_t* labels) const {
- if (beam_factor < 0) {
- if (metric_type == METRIC_INNER_PRODUCT) {
- rq.knn_exact_inner_product(n, x, k, distances, labels);
- } else if (metric_type == METRIC_L2) {
- FAISS_THROW_IF_NOT(centroid_norms.size() == ntotal);
- rq.knn_exact_L2(n, x, k, distances, labels, centroid_norms.data());
- }
- return;
- }
-
- int beam_size = int(k * beam_factor);
-
- size_t memory_per_point = rq.memory_per_point(beam_size);
-
- /*
-
- printf("mem per point %ld n=%d max_mem_distance=%ld mem_kb=%zd\n",
- memory_per_point, int(n), rq.max_mem_distances, get_mem_usage_kb());
- */
- if (n > 1 && memory_per_point * n > rq.max_mem_distances) {
- // then split queries to reduce temp memory
- idx_t bs = rq.max_mem_distances / memory_per_point;
- if (bs == 0) {
- bs = 1; // otherwise we can't do much
- }
- if (verbose) {
- printf("ResidualCoarseQuantizer::search: run %d searches in batches of size %d\n",
- int(n),
- int(bs));
- }
- for (idx_t i0 = 0; i0 < n; i0 += bs) {
- idx_t i1 = std::min(n, i0 + bs);
- search(i1 - i0, x + i0 * d, k, distances + i0 * k, labels + i0 * k);
- InterruptCallback::check();
- }
- return;
- }
-
- std::vector<int32_t> codes(beam_size * rq.M * n);
- std::vector<float> beam_distances(n * beam_size);
-
- rq.refine_beam(
- n, 1, x, beam_size, codes.data(), nullptr, beam_distances.data());
-
- #pragma omp parallel for if (n > 4000)
- for (idx_t i = 0; i < n; i++) {
- memcpy(distances + i * k,
- beam_distances.data() + beam_size * i,
- k * sizeof(distances[0]));
-
- const int32_t* codes_i = codes.data() + beam_size * i * rq.M;
- for (idx_t j = 0; j < k; j++) {
- idx_t l = 0;
- int shift = 0;
- for (int m = 0; m < rq.M; m++) {
- l |= (*codes_i++) << shift;
- shift += rq.nbits[m];
- }
- labels[i * k + j] = l;
- }
- }
- }
-
- void ResidualCoarseQuantizer::reconstruct(idx_t key, float* recons) const {
- rq.decode_64bit(key, recons);
- }
-
- void ResidualCoarseQuantizer::reset() {
- FAISS_THROW_MSG("not applicable");
- }
-
- } // namespace faiss