faiss 0.2.3 → 0.2.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (63) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  6. data/vendor/faiss/faiss/Clustering.h +14 -0
  7. data/vendor/faiss/faiss/Index.h +1 -1
  8. data/vendor/faiss/faiss/Index2Layer.cpp +19 -92
  9. data/vendor/faiss/faiss/Index2Layer.h +2 -16
  10. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  11. data/vendor/faiss/faiss/{IndexResidual.h → IndexAdditiveQuantizer.h} +101 -58
  12. data/vendor/faiss/faiss/IndexFlat.cpp +22 -52
  13. data/vendor/faiss/faiss/IndexFlat.h +9 -15
  14. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  15. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  16. data/vendor/faiss/faiss/IndexIVF.cpp +79 -7
  17. data/vendor/faiss/faiss/IndexIVF.h +25 -7
  18. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  19. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  20. data/vendor/faiss/faiss/IndexIVFFlat.cpp +9 -12
  21. data/vendor/faiss/faiss/IndexIVFPQ.cpp +5 -4
  22. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  23. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +60 -39
  24. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +21 -6
  25. data/vendor/faiss/faiss/IndexLSH.cpp +4 -30
  26. data/vendor/faiss/faiss/IndexLSH.h +2 -15
  27. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -2
  28. data/vendor/faiss/faiss/IndexNSG.cpp +0 -2
  29. data/vendor/faiss/faiss/IndexPQ.cpp +2 -51
  30. data/vendor/faiss/faiss/IndexPQ.h +2 -17
  31. data/vendor/faiss/faiss/IndexRefine.cpp +28 -0
  32. data/vendor/faiss/faiss/IndexRefine.h +10 -0
  33. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -28
  34. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -16
  35. data/vendor/faiss/faiss/VectorTransform.cpp +2 -1
  36. data/vendor/faiss/faiss/VectorTransform.h +3 -0
  37. data/vendor/faiss/faiss/clone_index.cpp +3 -2
  38. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -2
  39. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  40. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +257 -24
  41. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +69 -9
  42. data/vendor/faiss/faiss/impl/HNSW.cpp +10 -5
  43. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +393 -210
  44. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +100 -28
  45. data/vendor/faiss/faiss/impl/NSG.cpp +0 -3
  46. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  47. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +357 -47
  48. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +65 -7
  49. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +12 -19
  50. data/vendor/faiss/faiss/impl/index_read.cpp +102 -19
  51. data/vendor/faiss/faiss/impl/index_write.cpp +66 -16
  52. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  53. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  54. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  55. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  56. data/vendor/faiss/faiss/index_factory.cpp +585 -414
  57. data/vendor/faiss/faiss/index_factory.h +3 -0
  58. data/vendor/faiss/faiss/utils/distances.cpp +4 -2
  59. data/vendor/faiss/faiss/utils/distances.h +36 -3
  60. data/vendor/faiss/faiss/utils/distances_simd.cpp +50 -0
  61. data/vendor/faiss/faiss/utils/utils.h +1 -1
  62. metadata +12 -5
  63. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
@@ -5,9 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
- #include <faiss/impl/FaissAssert.h>
11
8
  #include <faiss/impl/LocalSearchQuantizer.h>
12
9
 
13
10
  #include <cstddef>
@@ -18,6 +15,9 @@
18
15
 
19
16
  #include <algorithm>
20
17
 
18
+ #include <faiss/Clustering.h>
19
+ #include <faiss/impl/AuxIndexStructures.h>
20
+ #include <faiss/impl/FaissAssert.h>
21
21
  #include <faiss/utils/distances.h>
22
22
  #include <faiss/utils/hamming.h> // BitstringWriter
23
23
  #include <faiss/utils/utils.h>
@@ -42,18 +42,6 @@ void sgetri_(
42
42
  FINTEGER* lwork,
43
43
  FINTEGER* info);
44
44
 
45
- // solves a system of linear equations
46
- void sgetrs_(
47
- const char* trans,
48
- FINTEGER* n,
49
- FINTEGER* nrhs,
50
- float* A,
51
- FINTEGER* lda,
52
- FINTEGER* ipiv,
53
- float* b,
54
- FINTEGER* ldb,
55
- FINTEGER* info);
56
-
57
45
  // general matrix multiplication
58
46
  int sgemm_(
59
47
  const char* transa,
@@ -69,26 +57,73 @@ int sgemm_(
69
57
  float* beta,
70
58
  float* c,
71
59
  FINTEGER* ldc);
60
+
61
+ // LU decomposition of a general matrix
62
+ void dgetrf_(
63
+ FINTEGER* m,
64
+ FINTEGER* n,
65
+ double* a,
66
+ FINTEGER* lda,
67
+ FINTEGER* ipiv,
68
+ FINTEGER* info);
69
+
70
+ // generate inverse of a matrix given its LU decomposition
71
+ void dgetri_(
72
+ FINTEGER* n,
73
+ double* a,
74
+ FINTEGER* lda,
75
+ FINTEGER* ipiv,
76
+ double* work,
77
+ FINTEGER* lwork,
78
+ FINTEGER* info);
79
+
80
+ // general matrix multiplication
81
+ int dgemm_(
82
+ const char* transa,
83
+ const char* transb,
84
+ FINTEGER* m,
85
+ FINTEGER* n,
86
+ FINTEGER* k,
87
+ const double* alpha,
88
+ const double* a,
89
+ FINTEGER* lda,
90
+ const double* b,
91
+ FINTEGER* ldb,
92
+ double* beta,
93
+ double* c,
94
+ FINTEGER* ldc);
72
95
  }
73
96
 
74
97
  namespace {
75
98
 
99
+ void fmat_inverse(float* a, int n) {
100
+ int info;
101
+ int lwork = n * n;
102
+ std::vector<int> ipiv(n);
103
+ std::vector<float> workspace(lwork);
104
+
105
+ sgetrf_(&n, &n, a, &n, ipiv.data(), &info);
106
+ FAISS_THROW_IF_NOT(info == 0);
107
+ sgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
108
+ FAISS_THROW_IF_NOT(info == 0);
109
+ }
110
+
76
111
  // c and a and b can overlap
77
- void fvec_add(size_t d, const float* a, const float* b, float* c) {
112
+ void dfvec_add(size_t d, const double* a, const float* b, double* c) {
78
113
  for (size_t i = 0; i < d; i++) {
79
114
  c[i] = a[i] + b[i];
80
115
  }
81
116
  }
82
117
 
83
- void fmat_inverse(float* a, int n) {
118
+ void dmat_inverse(double* a, int n) {
84
119
  int info;
85
120
  int lwork = n * n;
86
121
  std::vector<int> ipiv(n);
87
- std::vector<float> workspace(lwork);
122
+ std::vector<double> workspace(lwork);
88
123
 
89
- sgetrf_(&n, &n, a, &n, ipiv.data(), &info);
124
+ dgetrf_(&n, &n, a, &n, ipiv.data(), &info);
90
125
  FAISS_THROW_IF_NOT(info == 0);
91
- sgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
126
+ dgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
92
127
  FAISS_THROW_IF_NOT(info == 0);
93
128
  }
94
129
 
@@ -107,18 +142,15 @@ void random_int32(
107
142
 
108
143
  namespace faiss {
109
144
 
110
- LSQTimer lsq_timer;
111
-
112
- LocalSearchQuantizer::LocalSearchQuantizer(size_t d, size_t M, size_t nbits) {
113
- FAISS_THROW_IF_NOT((M * nbits) % 8 == 0);
114
-
115
- this->d = d;
116
- this->M = M;
117
- this->nbits = std::vector<size_t>(M, nbits);
118
-
119
- // set derived values
120
- set_derived_values();
145
+ lsq::LSQTimer lsq_timer;
146
+ using lsq::LSQTimerScope;
121
147
 
148
+ LocalSearchQuantizer::LocalSearchQuantizer(
149
+ size_t d,
150
+ size_t M,
151
+ size_t nbits,
152
+ Search_type_t search_type)
153
+ : AdditiveQuantizer(d, std::vector<size_t>(M, nbits), search_type) {
122
154
  is_trained = false;
123
155
  verbose = false;
124
156
 
@@ -138,15 +170,23 @@ LocalSearchQuantizer::LocalSearchQuantizer(size_t d, size_t M, size_t nbits) {
138
170
 
139
171
  random_seed = 0x12345;
140
172
  std::srand(random_seed);
173
+
174
+ icm_encoder_factory = nullptr;
141
175
  }
142
176
 
177
+ LocalSearchQuantizer::~LocalSearchQuantizer() {
178
+ delete icm_encoder_factory;
179
+ }
180
+
181
+ LocalSearchQuantizer::LocalSearchQuantizer() : LocalSearchQuantizer(0, 0, 0) {}
182
+
143
183
  void LocalSearchQuantizer::train(size_t n, const float* x) {
144
184
  FAISS_THROW_IF_NOT(K == (1 << nbits[0]));
145
185
  FAISS_THROW_IF_NOT(nperts <= M);
146
186
 
147
187
  lsq_timer.reset();
188
+ LSQTimerScope scope(&lsq_timer, "train");
148
189
  if (verbose) {
149
- lsq_timer.start("train");
150
190
  printf("Training LSQ, with %zd subcodes on %zd %zdD vectors\n",
151
191
  M,
152
192
  n,
@@ -209,7 +249,7 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
209
249
  }
210
250
 
211
251
  // refine codes
212
- icm_encode(x, codes.data(), n, train_ils_iters, gen);
252
+ icm_encode(codes.data(), x, n, train_ils_iters, gen);
213
253
 
214
254
  if (verbose) {
215
255
  float obj = evaluate(codes.data(), x, n);
@@ -217,25 +257,52 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
217
257
  }
218
258
  }
219
259
 
260
+ is_trained = true;
261
+ {
262
+ std::vector<float> x_recons(n * d);
263
+ std::vector<float> norms(n);
264
+ decode_unpacked(codes.data(), x_recons.data(), n);
265
+ fvec_norms_L2sqr(norms.data(), x_recons.data(), d, n);
266
+
267
+ norm_min = HUGE_VALF;
268
+ norm_max = -HUGE_VALF;
269
+ for (idx_t i = 0; i < n; i++) {
270
+ if (norms[i] < norm_min) {
271
+ norm_min = norms[i];
272
+ }
273
+ if (norms[i] > norm_max) {
274
+ norm_max = norms[i];
275
+ }
276
+ }
277
+
278
+ if (search_type == ST_norm_cqint8 || search_type == ST_norm_cqint4) {
279
+ size_t k = (1 << 8);
280
+ if (search_type == ST_norm_cqint4) {
281
+ k = (1 << 4);
282
+ }
283
+ Clustering1D clus(k);
284
+ clus.train_exact(n, norms.data());
285
+ qnorm.add(clus.k, clus.centroids.data());
286
+ }
287
+ }
288
+
220
289
  if (verbose) {
221
- lsq_timer.end("train");
222
290
  float obj = evaluate(codes.data(), x, n);
291
+ scope.finish();
223
292
  printf("After training: obj = %lf\n", obj);
224
293
 
225
294
  printf("Time statistic:\n");
226
- for (const auto& it : lsq_timer.duration) {
227
- printf("\t%s time: %lf s\n", it.first.data(), it.second);
295
+ for (const auto& it : lsq_timer.t) {
296
+ printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
228
297
  }
229
298
  }
230
-
231
- is_trained = true;
232
299
  }
233
300
 
234
301
  void LocalSearchQuantizer::perturb_codebooks(
235
302
  float T,
236
303
  const std::vector<float>& stddev,
237
304
  std::mt19937& gen) {
238
- lsq_timer.start("perturb_codebooks");
305
+ LSQTimerScope scope(&lsq_timer, "perturb_codebooks");
239
306
 
240
307
  std::vector<std::normal_distribution<float>> distribs;
241
308
  for (size_t i = 0; i < d; i++) {
@@ -249,8 +316,6 @@ void LocalSearchQuantizer::perturb_codebooks(
249
316
  }
250
317
  }
251
318
  }
252
-
253
- lsq_timer.end("perturb_codebooks");
254
319
  }
255
320
 
256
321
  void LocalSearchQuantizer::compute_codes(
@@ -258,23 +323,26 @@ void LocalSearchQuantizer::compute_codes(
258
323
  uint8_t* codes_out,
259
324
  size_t n) const {
260
325
  FAISS_THROW_IF_NOT_MSG(is_trained, "LSQ is not trained yet.");
326
+
327
+ lsq_timer.reset();
328
+ LSQTimerScope scope(&lsq_timer, "encode");
261
329
  if (verbose) {
262
- lsq_timer.reset();
263
330
  printf("Encoding %zd vectors...\n", n);
264
- lsq_timer.start("encode");
265
331
  }
266
332
 
267
333
  std::vector<int32_t> codes(n * M);
268
334
  std::mt19937 gen(random_seed);
269
335
  random_int32(codes, 0, K - 1, gen);
270
336
 
271
- icm_encode(x, codes.data(), n, encode_ils_iters, gen);
337
+ icm_encode(codes.data(), x, n, encode_ils_iters, gen);
272
338
  pack_codes(n, codes.data(), codes_out);
273
339
 
274
340
  if (verbose) {
275
- lsq_timer.end("encode");
276
- double t = lsq_timer.get("encode");
277
- printf("Time to encode %zd vectors: %lf s\n", n, t);
341
+ scope.finish();
342
+ printf("Time statistic:\n");
343
+ for (const auto& it : lsq_timer.t) {
344
+ printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
345
+ }
278
346
  }
279
347
  }
280
348
 
@@ -298,73 +366,144 @@ void LocalSearchQuantizer::update_codebooks(
298
366
  const float* x,
299
367
  const int32_t* codes,
300
368
  size_t n) {
301
- lsq_timer.start("update_codebooks");
369
+ LSQTimerScope scope(&lsq_timer, "update_codebooks");
370
+
371
+ if (!update_codebooks_with_double) {
372
+ // allocate memory
373
+ // bb = B'B, bx = BX
374
+ std::vector<float> bb(M * K * M * K, 0.0f); // [M * K, M * K]
375
+ std::vector<float> bx(M * K * d, 0.0f); // [M * K, d]
376
+
377
+ // compute B'B
378
+ for (size_t i = 0; i < n; i++) {
379
+ for (size_t m = 0; m < M; m++) {
380
+ int32_t code1 = codes[i * M + m];
381
+ int32_t idx1 = m * K + code1;
382
+ bb[idx1 * M * K + idx1] += 1;
383
+
384
+ for (size_t m2 = m + 1; m2 < M; m2++) {
385
+ int32_t code2 = codes[i * M + m2];
386
+ int32_t idx2 = m2 * K + code2;
387
+ bb[idx1 * M * K + idx2] += 1;
388
+ bb[idx2 * M * K + idx1] += 1;
389
+ }
390
+ }
391
+ }
302
392
 
303
- // allocate memory
304
- // bb = B'B, bx = BX
305
- std::vector<float> bb(M * K * M * K, 0.0f); // [M * K, M * K]
306
- std::vector<float> bx(M * K * d, 0.0f); // [M * K, d]
393
+ // add a regularization term to B'B
394
+ for (int64_t i = 0; i < M * K; i++) {
395
+ bb[i * (M * K) + i] += lambd;
396
+ }
307
397
 
308
- // compute B'B
309
- for (size_t i = 0; i < n; i++) {
310
- for (size_t m = 0; m < M; m++) {
311
- int32_t code1 = codes[i * M + m];
312
- int32_t idx1 = m * K + code1;
313
- bb[idx1 * M * K + idx1] += 1;
314
-
315
- for (size_t m2 = m + 1; m2 < M; m2++) {
316
- int32_t code2 = codes[i * M + m2];
317
- int32_t idx2 = m2 * K + code2;
318
- bb[idx1 * M * K + idx2] += 1;
319
- bb[idx2 * M * K + idx1] += 1;
398
+ // compute (B'B)^(-1)
399
+ fmat_inverse(bb.data(), M * K); // [M*K, M*K]
400
+
401
+ // compute BX
402
+ for (size_t i = 0; i < n; i++) {
403
+ for (size_t m = 0; m < M; m++) {
404
+ int32_t code = codes[i * M + m];
405
+ float* data = bx.data() + (m * K + code) * d;
406
+ fvec_add(d, data, x + i * d, data);
320
407
  }
321
408
  }
322
- }
323
409
 
324
- // add a regularization term to B'B
325
- for (int64_t i = 0; i < M * K; i++) {
326
- bb[i * (M * K) + i] += lambd;
327
- }
410
+ // compute C = (B'B)^(-1) @ BX
411
+ //
412
+ // NOTE: LAPACK uses column-major order
413
+ // out = alpha * op(A) * op(B) + beta * C
414
+ FINTEGER nrows_A = d;
415
+ FINTEGER ncols_A = M * K;
416
+
417
+ FINTEGER nrows_B = M * K;
418
+ FINTEGER ncols_B = M * K;
419
+
420
+ float alpha = 1.0f;
421
+ float beta = 0.0f;
422
+ sgemm_("Not Transposed",
423
+ "Not Transposed",
424
+ &nrows_A, // nrows of op(A)
425
+ &ncols_B, // ncols of op(B)
426
+ &ncols_A, // ncols of op(A)
427
+ &alpha,
428
+ bx.data(),
429
+ &nrows_A, // nrows of A
430
+ bb.data(),
431
+ &nrows_B, // nrows of B
432
+ &beta,
433
+ codebooks.data(),
434
+ &nrows_A); // nrows of output
435
+
436
+ } else {
437
+ // allocate memory
438
+ // bb = B'B, bx = BX
439
+ std::vector<double> bb(M * K * M * K, 0.0f); // [M * K, M * K]
440
+ std::vector<double> bx(M * K * d, 0.0f); // [M * K, d]
441
+
442
+ // compute B'B
443
+ for (size_t i = 0; i < n; i++) {
444
+ for (size_t m = 0; m < M; m++) {
445
+ int32_t code1 = codes[i * M + m];
446
+ int32_t idx1 = m * K + code1;
447
+ bb[idx1 * M * K + idx1] += 1;
448
+
449
+ for (size_t m2 = m + 1; m2 < M; m2++) {
450
+ int32_t code2 = codes[i * M + m2];
451
+ int32_t idx2 = m2 * K + code2;
452
+ bb[idx1 * M * K + idx2] += 1;
453
+ bb[idx2 * M * K + idx1] += 1;
454
+ }
455
+ }
456
+ }
328
457
 
329
- // compute (B'B)^(-1)
330
- fmat_inverse(bb.data(), M * K); // [M*K, M*K]
458
+ // add a regularization term to B'B
459
+ for (int64_t i = 0; i < M * K; i++) {
460
+ bb[i * (M * K) + i] += lambd;
461
+ }
331
462
 
332
- // compute BX
333
- for (size_t i = 0; i < n; i++) {
334
- for (size_t m = 0; m < M; m++) {
335
- int32_t code = codes[i * M + m];
336
- float* data = bx.data() + (m * K + code) * d;
337
- fvec_add(d, data, x + i * d, data);
463
+ // compute (B'B)^(-1)
464
+ dmat_inverse(bb.data(), M * K); // [M*K, M*K]
465
+
466
+ // compute BX
467
+ for (size_t i = 0; i < n; i++) {
468
+ for (size_t m = 0; m < M; m++) {
469
+ int32_t code = codes[i * M + m];
470
+ double* data = bx.data() + (m * K + code) * d;
471
+ dfvec_add(d, data, x + i * d, data);
472
+ }
338
473
  }
339
- }
340
474
 
341
- // compute C = (B'B)^(-1) @ BX
342
- //
343
- // NOTE: LAPACK use column major order
344
- // out = alpha * op(A) * op(B) + beta * C
345
- FINTEGER nrows_A = d;
346
- FINTEGER ncols_A = M * K;
347
-
348
- FINTEGER nrows_B = M * K;
349
- FINTEGER ncols_B = M * K;
350
-
351
- float alpha = 1.0f;
352
- float beta = 0.0f;
353
- sgemm_("Not Transposed",
354
- "Not Transposed",
355
- &nrows_A, // nrows of op(A)
356
- &ncols_B, // ncols of op(B)
357
- &ncols_A, // ncols of op(A)
358
- &alpha,
359
- bx.data(),
360
- &nrows_A, // nrows of A
361
- bb.data(),
362
- &nrows_B, // nrows of B
363
- &beta,
364
- codebooks.data(),
365
- &nrows_A); // nrows of output
366
-
367
- lsq_timer.end("update_codebooks");
475
+ // compute C = (B'B)^(-1) @ BX
476
+ //
477
+ // NOTE: LAPACK uses column-major order
478
+ // out = alpha * op(A) * op(B) + beta * C
479
+ FINTEGER nrows_A = d;
480
+ FINTEGER ncols_A = M * K;
481
+
482
+ FINTEGER nrows_B = M * K;
483
+ FINTEGER ncols_B = M * K;
484
+
485
+ std::vector<double> d_codebooks(M * K * d);
486
+
487
+ double alpha = 1.0f;
488
+ double beta = 0.0f;
489
+ dgemm_("Not Transposed",
490
+ "Not Transposed",
491
+ &nrows_A, // nrows of op(A)
492
+ &ncols_B, // ncols of op(B)
493
+ &ncols_A, // ncols of op(A)
494
+ &alpha,
495
+ bx.data(),
496
+ &nrows_A, // nrows of A
497
+ bb.data(),
498
+ &nrows_B, // nrows of B
499
+ &beta,
500
+ d_codebooks.data(),
501
+ &nrows_A); // nrows of output
502
+
503
+ for (size_t i = 0; i < M * K * d; i++) {
504
+ codebooks[i] = (float)d_codebooks[i];
505
+ }
506
+ }
368
507
  }
369
508
 
370
509
  /** encode using iterative conditional mode
@@ -386,15 +525,23 @@ void LocalSearchQuantizer::update_codebooks(
386
525
  * These two terms can be precomputed and store in a look up table.
387
526
  */
388
527
  void LocalSearchQuantizer::icm_encode(
389
- const float* x,
390
528
  int32_t* codes,
529
+ const float* x,
391
530
  size_t n,
392
531
  size_t ils_iters,
393
532
  std::mt19937& gen) const {
394
- lsq_timer.start("icm_encode");
533
+ LSQTimerScope scope(&lsq_timer, "icm_encode");
534
+
535
+ auto factory = icm_encoder_factory;
536
+ std::unique_ptr<lsq::IcmEncoder> icm_encoder;
537
+ if (factory == nullptr) {
538
+ icm_encoder.reset(lsq::IcmEncoderFactory().get(this));
539
+ } else {
540
+ icm_encoder.reset(factory->get(this));
541
+ }
395
542
 
396
- std::vector<float> binaries(M * M * K * K); // [M, M, K, K]
397
- compute_binary_terms(binaries.data());
543
+ // precompute binary terms for all chunks
544
+ icm_encoder->set_binary_term();
398
545
 
399
546
  const size_t n_chunks = (n + chunk_size - 1) / chunk_size;
400
547
  for (size_t i = 0; i < n_chunks; i++) {
@@ -410,21 +557,20 @@ void LocalSearchQuantizer::icm_encode(
410
557
 
411
558
  const float* xi = x + i * chunk_size * d;
412
559
  int32_t* codesi = codes + i * chunk_size * M;
413
- icm_encode_partial(i, xi, codesi, ni, binaries.data(), ils_iters, gen);
560
+ icm_encoder->verbose = (verbose && i == 0);
561
+ icm_encoder->encode(codesi, xi, gen, ni, ils_iters);
414
562
  }
415
-
416
- lsq_timer.end("icm_encode");
417
563
  }
418
564
 
419
- void LocalSearchQuantizer::icm_encode_partial(
420
- size_t index,
421
- const float* x,
565
+ void LocalSearchQuantizer::icm_encode_impl(
422
566
  int32_t* codes,
423
- size_t n,
567
+ const float* x,
424
568
  const float* binaries,
569
+ std::mt19937& gen,
570
+ size_t n,
425
571
  size_t ils_iters,
426
- std::mt19937& gen) const {
427
- std::vector<float> unaries(n * M * K); // [n, M, K]
572
+ bool verbose) const {
573
+ std::vector<float> unaries(n * M * K); // [M, n, K]
428
574
  compute_unary_terms(x, unaries.data(), n);
429
575
 
430
576
  std::vector<int32_t> best_codes;
@@ -438,9 +584,7 @@ void LocalSearchQuantizer::icm_encode_partial(
438
584
  // add perturbation to codes
439
585
  perturb_codes(codes, n, gen);
440
586
 
441
- for (size_t iter2 = 0; iter2 < icm_iters; iter2++) {
442
- icm_encode_step(unaries.data(), binaries, codes, n);
443
- }
587
+ icm_encode_step(codes, unaries.data(), binaries, n, icm_iters);
444
588
 
445
589
  std::vector<float> icm_objs(n, 0.0f);
446
590
  evaluate(codes, x, n, icm_objs.data());
@@ -463,7 +607,7 @@ void LocalSearchQuantizer::icm_encode_partial(
463
607
 
464
608
  memcpy(codes, best_codes.data(), sizeof(int32_t) * n * M);
465
609
 
466
- if (verbose && index == 0) {
610
+ if (verbose) {
467
611
  printf("\tils_iter %zd: obj = %lf, n_betters/n = %zd/%zd\n",
468
612
  iter1,
469
613
  mean_obj,
@@ -474,61 +618,67 @@ void LocalSearchQuantizer::icm_encode_partial(
474
618
  }
475
619
 
476
620
  void LocalSearchQuantizer::icm_encode_step(
621
+ int32_t* codes,
477
622
  const float* unaries,
478
623
  const float* binaries,
479
- int32_t* codes,
480
- size_t n) const {
481
- // condition on the m-th subcode
482
- for (size_t m = 0; m < M; m++) {
483
- std::vector<float> objs(n * K);
484
- #pragma omp parallel for
485
- for (int64_t i = 0; i < n; i++) {
486
- auto u = unaries + i * (M * K) + m * K;
487
- memcpy(objs.data() + i * K, u, sizeof(float) * K);
488
- }
624
+ size_t n,
625
+ size_t n_iters) const {
626
+ FAISS_THROW_IF_NOT(M != 0 && K != 0);
627
+ FAISS_THROW_IF_NOT(binaries != nullptr);
489
628
 
490
- // compute objective function by adding unary
491
- // and binary terms together
492
- for (size_t other_m = 0; other_m < M; other_m++) {
493
- if (other_m == m) {
494
- continue;
629
+ for (size_t iter = 0; iter < n_iters; iter++) {
630
+ // condition on the m-th subcode
631
+ for (size_t m = 0; m < M; m++) {
632
+ std::vector<float> objs(n * K);
633
+ #pragma omp parallel for
634
+ for (int64_t i = 0; i < n; i++) {
635
+ auto u = unaries + m * n * K + i * K;
636
+ memcpy(objs.data() + i * K, u, sizeof(float) * K);
495
637
  }
496
638
 
639
+ // compute objective function by adding unary
640
+ // and binary terms together
641
+ for (size_t other_m = 0; other_m < M; other_m++) {
642
+ if (other_m == m) {
643
+ continue;
644
+ }
645
+
497
646
  #pragma omp parallel for
498
- for (int64_t i = 0; i < n; i++) {
499
- for (int32_t code = 0; code < K; code++) {
500
- int32_t code2 = codes[i * M + other_m];
501
- size_t binary_idx =
502
- m * M * K * K + other_m * K * K + code * K + code2;
503
- // binaries[m, other_m, code, code2]
504
- objs[i * K + code] += binaries[binary_idx];
647
+ for (int64_t i = 0; i < n; i++) {
648
+ for (int32_t code = 0; code < K; code++) {
649
+ int32_t code2 = codes[i * M + other_m];
650
+ size_t binary_idx = m * M * K * K + other_m * K * K +
651
+ code * K + code2;
652
+ // binaries[m, other_m, code, code2]
653
+ objs[i * K + code] += binaries[binary_idx];
654
+ }
505
655
  }
506
656
  }
507
- }
508
657
 
509
- // find the optimal value of the m-th subcode
658
+ // find the optimal value of the m-th subcode
510
659
  #pragma omp parallel for
511
- for (int64_t i = 0; i < n; i++) {
512
- float best_obj = HUGE_VALF;
513
- int32_t best_code = 0;
514
- for (size_t code = 0; code < K; code++) {
515
- float obj = objs[i * K + code];
516
- if (obj < best_obj) {
517
- best_obj = obj;
518
- best_code = code;
660
+ for (int64_t i = 0; i < n; i++) {
661
+ float best_obj = HUGE_VALF;
662
+ int32_t best_code = 0;
663
+ for (size_t code = 0; code < K; code++) {
664
+ float obj = objs[i * K + code];
665
+ if (obj < best_obj) {
666
+ best_obj = obj;
667
+ best_code = code;
668
+ }
519
669
  }
670
+ codes[i * M + m] = best_code;
520
671
  }
521
- codes[i * M + m] = best_code;
522
- }
523
672
 
524
- } // loop M
673
+ } // loop M
674
+ }
525
675
  }
526
676
 
527
677
  void LocalSearchQuantizer::perturb_codes(
528
678
  int32_t* codes,
529
679
  size_t n,
530
680
  std::mt19937& gen) const {
531
- lsq_timer.start("perturb_codes");
681
+ LSQTimerScope scope(&lsq_timer, "perturb_codes");
532
682
 
533
683
  std::uniform_int_distribution<size_t> m_distrib(0, M - 1);
534
684
  std::uniform_int_distribution<int32_t> k_distrib(0, K - 1);
@@ -539,12 +689,10 @@ void LocalSearchQuantizer::perturb_codes(
539
689
  codes[i * M + m] = k_distrib(gen);
540
690
  }
541
691
  }
542
-
543
- lsq_timer.end("perturb_codes");
544
692
  }
545
693
 
546
694
  void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
547
- lsq_timer.start("compute_binary_terms");
695
+ LSQTimerScope scope(&lsq_timer, "compute_binary_terms");
548
696
 
549
697
  #pragma omp parallel for
550
698
  for (int64_t m12 = 0; m12 < M * M; m12++) {
@@ -562,52 +710,53 @@ void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
562
710
  }
563
711
  }
564
712
  }
565
-
566
- lsq_timer.end("compute_binary_terms");
567
713
  }
568
714
 
569
715
  void LocalSearchQuantizer::compute_unary_terms(
570
716
  const float* x,
571
- float* unaries,
717
+ float* unaries, // [M, n, K]
572
718
  size_t n) const {
573
- lsq_timer.start("compute_unary_terms");
719
+ LSQTimerScope scope(&lsq_timer, "compute_unary_terms");
574
720
 
575
- // compute x * codebooks^T
721
+ // compute x * codebook^T for each codebook
576
722
  //
577
723
  // NOTE: LAPACK use column major order
578
724
  // out = alpha * op(A) * op(B) + beta * C
579
- FINTEGER nrows_A = M * K;
580
- FINTEGER ncols_A = d;
581
-
582
- FINTEGER nrows_B = d;
583
- FINTEGER ncols_B = n;
584
-
585
- float alpha = -2.0f;
586
- float beta = 0.0f;
587
- sgemm_("Transposed",
588
- "Not Transposed",
589
- &nrows_A, // nrows of op(A)
590
- &ncols_B, // ncols of op(B)
591
- &ncols_A, // ncols of op(A)
592
- &alpha,
593
- codebooks.data(),
594
- &ncols_A, // nrows of A
595
- x,
596
- &nrows_B, // nrows of B
597
- &beta,
598
- unaries,
599
- &nrows_A); // nrows of output
725
+
726
+ for (size_t m = 0; m < M; m++) {
727
+ FINTEGER nrows_A = K;
728
+ FINTEGER ncols_A = d;
729
+
730
+ FINTEGER nrows_B = d;
731
+ FINTEGER ncols_B = n;
732
+
733
+ float alpha = -2.0f;
734
+ float beta = 0.0f;
735
+ sgemm_("Transposed",
736
+ "Not Transposed",
737
+ &nrows_A, // nrows of op(A)
738
+ &ncols_B, // ncols of op(B)
739
+ &ncols_A, // ncols of op(A)
740
+ &alpha,
741
+ codebooks.data() + m * K * d,
742
+ &ncols_A, // nrows of A
743
+ x,
744
+ &nrows_B, // nrows of B
745
+ &beta,
746
+ unaries + m * n * K,
747
+ &nrows_A); // nrows of output
748
+ }
600
749
 
601
750
  std::vector<float> norms(M * K);
602
751
  fvec_norms_L2sqr(norms.data(), codebooks.data(), d, M * K);
603
752
 
604
753
  #pragma omp parallel for
605
754
  for (int64_t i = 0; i < n; i++) {
606
- float* u = unaries + i * (M * K);
607
- fvec_add(M * K, u, norms.data(), u);
755
+ for (size_t m = 0; m < M; m++) {
756
+ float* u = unaries + m * n * K + i * K;
757
+ fvec_add(K, u, norms.data() + m * K, u);
758
+ }
608
759
  }
609
-
610
- lsq_timer.end("compute_unary_terms");
611
760
  }
612
761
 
613
762
  float LocalSearchQuantizer::evaluate(
@@ -615,7 +764,7 @@ float LocalSearchQuantizer::evaluate(
615
764
  const float* x,
616
765
  size_t n,
617
766
  float* objs) const {
618
- lsq_timer.start("evaluate");
767
+ LSQTimerScope scope(&lsq_timer, "evaluate");
619
768
 
620
769
  // decode
621
770
  std::vector<float> decoded_x(n * d, 0.0f);
@@ -631,7 +780,7 @@ float LocalSearchQuantizer::evaluate(
631
780
  fvec_add(d, decoded_i, c, decoded_i);
632
781
  }
633
782
 
634
- float err = fvec_L2sqr(x + i * d, decoded_i, d);
783
+ float err = faiss::fvec_L2sqr(x + i * d, decoded_i, d);
635
784
  obj += err;
636
785
 
637
786
  if (objs) {
@@ -639,34 +788,68 @@ float LocalSearchQuantizer::evaluate(
639
788
  }
640
789
  }
641
790
 
642
- lsq_timer.end("evaluate");
643
-
644
791
  obj = obj / n;
645
792
  return obj;
646
793
  }
647
794
 
648
- double LSQTimer::get(const std::string& name) {
649
- return duration[name];
795
+ namespace lsq {
796
+
797
+ IcmEncoder::IcmEncoder(const LocalSearchQuantizer* lsq)
798
+ : verbose(false), lsq(lsq) {}
799
+
800
+ void IcmEncoder::set_binary_term() {
801
+ auto M = lsq->M;
802
+ auto K = lsq->K;
803
+ binaries.resize(M * M * K * K);
804
+ lsq->compute_binary_terms(binaries.data());
650
805
  }
651
806
 
652
- void LSQTimer::start(const std::string& name) {
653
- FAISS_THROW_IF_NOT_MSG(!started[name], " timer is already running");
654
- started[name] = true;
655
- t0[name] = getmillisecs();
807
+ void IcmEncoder::encode(
808
+ int32_t* codes,
809
+ const float* x,
810
+ std::mt19937& gen,
811
+ size_t n,
812
+ size_t ils_iters) const {
813
+ lsq->icm_encode_impl(codes, x, binaries.data(), gen, n, ils_iters, verbose);
656
814
  }
657
815
 
658
- void LSQTimer::end(const std::string& name) {
659
- FAISS_THROW_IF_NOT_MSG(started[name], " timer is not running");
660
- double t1 = getmillisecs();
661
- double sec = (t1 - t0[name]) / 1000;
662
- duration[name] += sec;
663
- started[name] = false;
816
+ double LSQTimer::get(const std::string& name) {
817
+ if (t.count(name) == 0) {
818
+ return 0.0;
819
+ } else {
820
+ return t[name];
821
+ }
822
+ }
823
+
824
+ void LSQTimer::add(const std::string& name, double delta) {
825
+ if (t.count(name) == 0) {
826
+ t[name] = delta;
827
+ } else {
828
+ t[name] += delta;
829
+ }
664
830
  }
665
831
 
666
832
  void LSQTimer::reset() {
667
- duration.clear();
668
- t0.clear();
669
- started.clear();
833
+ t.clear();
834
+ }
835
+
836
+ LSQTimerScope::LSQTimerScope(LSQTimer* timer, std::string name)
837
+ : timer(timer), name(name), finished(false) {
838
+ t0 = getmillisecs();
670
839
  }
671
840
 
841
+ void LSQTimerScope::finish() {
842
+ if (!finished) {
843
+ auto delta = getmillisecs() - t0;
844
+ timer->add(name, delta);
845
+ finished = true;
846
+ }
847
+ }
848
+
849
+ LSQTimerScope::~LSQTimerScope() {
850
+ finish();
851
+ }
852
+
853
+ } // namespace lsq
854
+
672
855
  } // namespace faiss