faiss 0.2.3 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  6. data/vendor/faiss/faiss/Clustering.h +14 -0
  7. data/vendor/faiss/faiss/Index.h +1 -1
  8. data/vendor/faiss/faiss/Index2Layer.cpp +19 -92
  9. data/vendor/faiss/faiss/Index2Layer.h +2 -16
  10. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  11. data/vendor/faiss/faiss/{IndexResidual.h → IndexAdditiveQuantizer.h} +101 -58
  12. data/vendor/faiss/faiss/IndexFlat.cpp +22 -52
  13. data/vendor/faiss/faiss/IndexFlat.h +9 -15
  14. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  15. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  16. data/vendor/faiss/faiss/IndexIVF.cpp +79 -7
  17. data/vendor/faiss/faiss/IndexIVF.h +25 -7
  18. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  19. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  20. data/vendor/faiss/faiss/IndexIVFFlat.cpp +9 -12
  21. data/vendor/faiss/faiss/IndexIVFPQ.cpp +5 -4
  22. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  23. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +60 -39
  24. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +21 -6
  25. data/vendor/faiss/faiss/IndexLSH.cpp +4 -30
  26. data/vendor/faiss/faiss/IndexLSH.h +2 -15
  27. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -2
  28. data/vendor/faiss/faiss/IndexNSG.cpp +0 -2
  29. data/vendor/faiss/faiss/IndexPQ.cpp +2 -51
  30. data/vendor/faiss/faiss/IndexPQ.h +2 -17
  31. data/vendor/faiss/faiss/IndexRefine.cpp +28 -0
  32. data/vendor/faiss/faiss/IndexRefine.h +10 -0
  33. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -28
  34. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -16
  35. data/vendor/faiss/faiss/VectorTransform.cpp +2 -1
  36. data/vendor/faiss/faiss/VectorTransform.h +3 -0
  37. data/vendor/faiss/faiss/clone_index.cpp +3 -2
  38. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -2
  39. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  40. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +257 -24
  41. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +69 -9
  42. data/vendor/faiss/faiss/impl/HNSW.cpp +10 -5
  43. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +393 -210
  44. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +100 -28
  45. data/vendor/faiss/faiss/impl/NSG.cpp +0 -3
  46. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  47. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +357 -47
  48. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +65 -7
  49. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +12 -19
  50. data/vendor/faiss/faiss/impl/index_read.cpp +102 -19
  51. data/vendor/faiss/faiss/impl/index_write.cpp +66 -16
  52. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  53. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  54. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  55. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  56. data/vendor/faiss/faiss/index_factory.cpp +585 -414
  57. data/vendor/faiss/faiss/index_factory.h +3 -0
  58. data/vendor/faiss/faiss/utils/distances.cpp +4 -2
  59. data/vendor/faiss/faiss/utils/distances.h +36 -3
  60. data/vendor/faiss/faiss/utils/distances_simd.cpp +50 -0
  61. data/vendor/faiss/faiss/utils/utils.h +1 -1
  62. metadata +12 -5
  63. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
@@ -5,9 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
- #include <faiss/impl/FaissAssert.h>
11
8
  #include <faiss/impl/LocalSearchQuantizer.h>
12
9
 
13
10
  #include <cstddef>
@@ -18,6 +15,9 @@
18
15
 
19
16
  #include <algorithm>
20
17
 
18
+ #include <faiss/Clustering.h>
19
+ #include <faiss/impl/AuxIndexStructures.h>
20
+ #include <faiss/impl/FaissAssert.h>
21
21
  #include <faiss/utils/distances.h>
22
22
  #include <faiss/utils/hamming.h> // BitstringWriter
23
23
  #include <faiss/utils/utils.h>
@@ -42,18 +42,6 @@ void sgetri_(
42
42
  FINTEGER* lwork,
43
43
  FINTEGER* info);
44
44
 
45
- // solves a system of linear equations
46
- void sgetrs_(
47
- const char* trans,
48
- FINTEGER* n,
49
- FINTEGER* nrhs,
50
- float* A,
51
- FINTEGER* lda,
52
- FINTEGER* ipiv,
53
- float* b,
54
- FINTEGER* ldb,
55
- FINTEGER* info);
56
-
57
45
  // general matrix multiplication
58
46
  int sgemm_(
59
47
  const char* transa,
@@ -69,26 +57,73 @@ int sgemm_(
69
57
  float* beta,
70
58
  float* c,
71
59
  FINTEGER* ldc);
60
+
61
+ // LU decomoposition of a general matrix
62
+ void dgetrf_(
63
+ FINTEGER* m,
64
+ FINTEGER* n,
65
+ double* a,
66
+ FINTEGER* lda,
67
+ FINTEGER* ipiv,
68
+ FINTEGER* info);
69
+
70
+ // generate inverse of a matrix given its LU decomposition
71
+ void dgetri_(
72
+ FINTEGER* n,
73
+ double* a,
74
+ FINTEGER* lda,
75
+ FINTEGER* ipiv,
76
+ double* work,
77
+ FINTEGER* lwork,
78
+ FINTEGER* info);
79
+
80
+ // general matrix multiplication
81
+ int dgemm_(
82
+ const char* transa,
83
+ const char* transb,
84
+ FINTEGER* m,
85
+ FINTEGER* n,
86
+ FINTEGER* k,
87
+ const double* alpha,
88
+ const double* a,
89
+ FINTEGER* lda,
90
+ const double* b,
91
+ FINTEGER* ldb,
92
+ double* beta,
93
+ double* c,
94
+ FINTEGER* ldc);
72
95
  }
73
96
 
74
97
  namespace {
75
98
 
99
+ void fmat_inverse(float* a, int n) {
100
+ int info;
101
+ int lwork = n * n;
102
+ std::vector<int> ipiv(n);
103
+ std::vector<float> workspace(lwork);
104
+
105
+ sgetrf_(&n, &n, a, &n, ipiv.data(), &info);
106
+ FAISS_THROW_IF_NOT(info == 0);
107
+ sgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
108
+ FAISS_THROW_IF_NOT(info == 0);
109
+ }
110
+
76
111
  // c and a and b can overlap
77
- void fvec_add(size_t d, const float* a, const float* b, float* c) {
112
+ void dfvec_add(size_t d, const double* a, const float* b, double* c) {
78
113
  for (size_t i = 0; i < d; i++) {
79
114
  c[i] = a[i] + b[i];
80
115
  }
81
116
  }
82
117
 
83
- void fmat_inverse(float* a, int n) {
118
+ void dmat_inverse(double* a, int n) {
84
119
  int info;
85
120
  int lwork = n * n;
86
121
  std::vector<int> ipiv(n);
87
- std::vector<float> workspace(lwork);
122
+ std::vector<double> workspace(lwork);
88
123
 
89
- sgetrf_(&n, &n, a, &n, ipiv.data(), &info);
124
+ dgetrf_(&n, &n, a, &n, ipiv.data(), &info);
90
125
  FAISS_THROW_IF_NOT(info == 0);
91
- sgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
126
+ dgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
92
127
  FAISS_THROW_IF_NOT(info == 0);
93
128
  }
94
129
 
@@ -107,18 +142,15 @@ void random_int32(
107
142
 
108
143
  namespace faiss {
109
144
 
110
- LSQTimer lsq_timer;
111
-
112
- LocalSearchQuantizer::LocalSearchQuantizer(size_t d, size_t M, size_t nbits) {
113
- FAISS_THROW_IF_NOT((M * nbits) % 8 == 0);
114
-
115
- this->d = d;
116
- this->M = M;
117
- this->nbits = std::vector<size_t>(M, nbits);
118
-
119
- // set derived values
120
- set_derived_values();
145
+ lsq::LSQTimer lsq_timer;
146
+ using lsq::LSQTimerScope;
121
147
 
148
+ LocalSearchQuantizer::LocalSearchQuantizer(
149
+ size_t d,
150
+ size_t M,
151
+ size_t nbits,
152
+ Search_type_t search_type)
153
+ : AdditiveQuantizer(d, std::vector<size_t>(M, nbits), search_type) {
122
154
  is_trained = false;
123
155
  verbose = false;
124
156
 
@@ -138,15 +170,23 @@ LocalSearchQuantizer::LocalSearchQuantizer(size_t d, size_t M, size_t nbits) {
138
170
 
139
171
  random_seed = 0x12345;
140
172
  std::srand(random_seed);
173
+
174
+ icm_encoder_factory = nullptr;
141
175
  }
142
176
 
177
+ LocalSearchQuantizer::~LocalSearchQuantizer() {
178
+ delete icm_encoder_factory;
179
+ }
180
+
181
+ LocalSearchQuantizer::LocalSearchQuantizer() : LocalSearchQuantizer(0, 0, 0) {}
182
+
143
183
  void LocalSearchQuantizer::train(size_t n, const float* x) {
144
184
  FAISS_THROW_IF_NOT(K == (1 << nbits[0]));
145
185
  FAISS_THROW_IF_NOT(nperts <= M);
146
186
 
147
187
  lsq_timer.reset();
188
+ LSQTimerScope scope(&lsq_timer, "train");
148
189
  if (verbose) {
149
- lsq_timer.start("train");
150
190
  printf("Training LSQ, with %zd subcodes on %zd %zdD vectors\n",
151
191
  M,
152
192
  n,
@@ -209,7 +249,7 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
209
249
  }
210
250
 
211
251
  // refine codes
212
- icm_encode(x, codes.data(), n, train_ils_iters, gen);
252
+ icm_encode(codes.data(), x, n, train_ils_iters, gen);
213
253
 
214
254
  if (verbose) {
215
255
  float obj = evaluate(codes.data(), x, n);
@@ -217,25 +257,52 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
217
257
  }
218
258
  }
219
259
 
260
+ is_trained = true;
261
+ {
262
+ std::vector<float> x_recons(n * d);
263
+ std::vector<float> norms(n);
264
+ decode_unpacked(codes.data(), x_recons.data(), n);
265
+ fvec_norms_L2sqr(norms.data(), x_recons.data(), d, n);
266
+
267
+ norm_min = HUGE_VALF;
268
+ norm_max = -HUGE_VALF;
269
+ for (idx_t i = 0; i < n; i++) {
270
+ if (norms[i] < norm_min) {
271
+ norm_min = norms[i];
272
+ }
273
+ if (norms[i] > norm_max) {
274
+ norm_max = norms[i];
275
+ }
276
+ }
277
+
278
+ if (search_type == ST_norm_cqint8 || search_type == ST_norm_cqint4) {
279
+ size_t k = (1 << 8);
280
+ if (search_type == ST_norm_cqint4) {
281
+ k = (1 << 4);
282
+ }
283
+ Clustering1D clus(k);
284
+ clus.train_exact(n, norms.data());
285
+ qnorm.add(clus.k, clus.centroids.data());
286
+ }
287
+ }
288
+
220
289
  if (verbose) {
221
- lsq_timer.end("train");
222
290
  float obj = evaluate(codes.data(), x, n);
291
+ scope.finish();
223
292
  printf("After training: obj = %lf\n", obj);
224
293
 
225
294
  printf("Time statistic:\n");
226
- for (const auto& it : lsq_timer.duration) {
227
- printf("\t%s time: %lf s\n", it.first.data(), it.second);
295
+ for (const auto& it : lsq_timer.t) {
296
+ printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
228
297
  }
229
298
  }
230
-
231
- is_trained = true;
232
299
  }
233
300
 
234
301
  void LocalSearchQuantizer::perturb_codebooks(
235
302
  float T,
236
303
  const std::vector<float>& stddev,
237
304
  std::mt19937& gen) {
238
- lsq_timer.start("perturb_codebooks");
305
+ LSQTimerScope scope(&lsq_timer, "perturb_codebooks");
239
306
 
240
307
  std::vector<std::normal_distribution<float>> distribs;
241
308
  for (size_t i = 0; i < d; i++) {
@@ -249,8 +316,6 @@ void LocalSearchQuantizer::perturb_codebooks(
249
316
  }
250
317
  }
251
318
  }
252
-
253
- lsq_timer.end("perturb_codebooks");
254
319
  }
255
320
 
256
321
  void LocalSearchQuantizer::compute_codes(
@@ -258,23 +323,26 @@ void LocalSearchQuantizer::compute_codes(
258
323
  uint8_t* codes_out,
259
324
  size_t n) const {
260
325
  FAISS_THROW_IF_NOT_MSG(is_trained, "LSQ is not trained yet.");
326
+
327
+ lsq_timer.reset();
328
+ LSQTimerScope scope(&lsq_timer, "encode");
261
329
  if (verbose) {
262
- lsq_timer.reset();
263
330
  printf("Encoding %zd vectors...\n", n);
264
- lsq_timer.start("encode");
265
331
  }
266
332
 
267
333
  std::vector<int32_t> codes(n * M);
268
334
  std::mt19937 gen(random_seed);
269
335
  random_int32(codes, 0, K - 1, gen);
270
336
 
271
- icm_encode(x, codes.data(), n, encode_ils_iters, gen);
337
+ icm_encode(codes.data(), x, n, encode_ils_iters, gen);
272
338
  pack_codes(n, codes.data(), codes_out);
273
339
 
274
340
  if (verbose) {
275
- lsq_timer.end("encode");
276
- double t = lsq_timer.get("encode");
277
- printf("Time to encode %zd vectors: %lf s\n", n, t);
341
+ scope.finish();
342
+ printf("Time statistic:\n");
343
+ for (const auto& it : lsq_timer.t) {
344
+ printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
345
+ }
278
346
  }
279
347
  }
280
348
 
@@ -298,73 +366,144 @@ void LocalSearchQuantizer::update_codebooks(
298
366
  const float* x,
299
367
  const int32_t* codes,
300
368
  size_t n) {
301
- lsq_timer.start("update_codebooks");
369
+ LSQTimerScope scope(&lsq_timer, "update_codebooks");
370
+
371
+ if (!update_codebooks_with_double) {
372
+ // allocate memory
373
+ // bb = B'B, bx = BX
374
+ std::vector<float> bb(M * K * M * K, 0.0f); // [M * K, M * K]
375
+ std::vector<float> bx(M * K * d, 0.0f); // [M * K, d]
376
+
377
+ // compute B'B
378
+ for (size_t i = 0; i < n; i++) {
379
+ for (size_t m = 0; m < M; m++) {
380
+ int32_t code1 = codes[i * M + m];
381
+ int32_t idx1 = m * K + code1;
382
+ bb[idx1 * M * K + idx1] += 1;
383
+
384
+ for (size_t m2 = m + 1; m2 < M; m2++) {
385
+ int32_t code2 = codes[i * M + m2];
386
+ int32_t idx2 = m2 * K + code2;
387
+ bb[idx1 * M * K + idx2] += 1;
388
+ bb[idx2 * M * K + idx1] += 1;
389
+ }
390
+ }
391
+ }
302
392
 
303
- // allocate memory
304
- // bb = B'B, bx = BX
305
- std::vector<float> bb(M * K * M * K, 0.0f); // [M * K, M * K]
306
- std::vector<float> bx(M * K * d, 0.0f); // [M * K, d]
393
+ // add a regularization term to B'B
394
+ for (int64_t i = 0; i < M * K; i++) {
395
+ bb[i * (M * K) + i] += lambd;
396
+ }
307
397
 
308
- // compute B'B
309
- for (size_t i = 0; i < n; i++) {
310
- for (size_t m = 0; m < M; m++) {
311
- int32_t code1 = codes[i * M + m];
312
- int32_t idx1 = m * K + code1;
313
- bb[idx1 * M * K + idx1] += 1;
314
-
315
- for (size_t m2 = m + 1; m2 < M; m2++) {
316
- int32_t code2 = codes[i * M + m2];
317
- int32_t idx2 = m2 * K + code2;
318
- bb[idx1 * M * K + idx2] += 1;
319
- bb[idx2 * M * K + idx1] += 1;
398
+ // compute (B'B)^(-1)
399
+ fmat_inverse(bb.data(), M * K); // [M*K, M*K]
400
+
401
+ // compute BX
402
+ for (size_t i = 0; i < n; i++) {
403
+ for (size_t m = 0; m < M; m++) {
404
+ int32_t code = codes[i * M + m];
405
+ float* data = bx.data() + (m * K + code) * d;
406
+ fvec_add(d, data, x + i * d, data);
320
407
  }
321
408
  }
322
- }
323
409
 
324
- // add a regularization term to B'B
325
- for (int64_t i = 0; i < M * K; i++) {
326
- bb[i * (M * K) + i] += lambd;
327
- }
410
+ // compute C = (B'B)^(-1) @ BX
411
+ //
412
+ // NOTE: LAPACK use column major order
413
+ // out = alpha * op(A) * op(B) + beta * C
414
+ FINTEGER nrows_A = d;
415
+ FINTEGER ncols_A = M * K;
416
+
417
+ FINTEGER nrows_B = M * K;
418
+ FINTEGER ncols_B = M * K;
419
+
420
+ float alpha = 1.0f;
421
+ float beta = 0.0f;
422
+ sgemm_("Not Transposed",
423
+ "Not Transposed",
424
+ &nrows_A, // nrows of op(A)
425
+ &ncols_B, // ncols of op(B)
426
+ &ncols_A, // ncols of op(A)
427
+ &alpha,
428
+ bx.data(),
429
+ &nrows_A, // nrows of A
430
+ bb.data(),
431
+ &nrows_B, // nrows of B
432
+ &beta,
433
+ codebooks.data(),
434
+ &nrows_A); // nrows of output
435
+
436
+ } else {
437
+ // allocate memory
438
+ // bb = B'B, bx = BX
439
+ std::vector<double> bb(M * K * M * K, 0.0f); // [M * K, M * K]
440
+ std::vector<double> bx(M * K * d, 0.0f); // [M * K, d]
441
+
442
+ // compute B'B
443
+ for (size_t i = 0; i < n; i++) {
444
+ for (size_t m = 0; m < M; m++) {
445
+ int32_t code1 = codes[i * M + m];
446
+ int32_t idx1 = m * K + code1;
447
+ bb[idx1 * M * K + idx1] += 1;
448
+
449
+ for (size_t m2 = m + 1; m2 < M; m2++) {
450
+ int32_t code2 = codes[i * M + m2];
451
+ int32_t idx2 = m2 * K + code2;
452
+ bb[idx1 * M * K + idx2] += 1;
453
+ bb[idx2 * M * K + idx1] += 1;
454
+ }
455
+ }
456
+ }
328
457
 
329
- // compute (B'B)^(-1)
330
- fmat_inverse(bb.data(), M * K); // [M*K, M*K]
458
+ // add a regularization term to B'B
459
+ for (int64_t i = 0; i < M * K; i++) {
460
+ bb[i * (M * K) + i] += lambd;
461
+ }
331
462
 
332
- // compute BX
333
- for (size_t i = 0; i < n; i++) {
334
- for (size_t m = 0; m < M; m++) {
335
- int32_t code = codes[i * M + m];
336
- float* data = bx.data() + (m * K + code) * d;
337
- fvec_add(d, data, x + i * d, data);
463
+ // compute (B'B)^(-1)
464
+ dmat_inverse(bb.data(), M * K); // [M*K, M*K]
465
+
466
+ // compute BX
467
+ for (size_t i = 0; i < n; i++) {
468
+ for (size_t m = 0; m < M; m++) {
469
+ int32_t code = codes[i * M + m];
470
+ double* data = bx.data() + (m * K + code) * d;
471
+ dfvec_add(d, data, x + i * d, data);
472
+ }
338
473
  }
339
- }
340
474
 
341
- // compute C = (B'B)^(-1) @ BX
342
- //
343
- // NOTE: LAPACK use column major order
344
- // out = alpha * op(A) * op(B) + beta * C
345
- FINTEGER nrows_A = d;
346
- FINTEGER ncols_A = M * K;
347
-
348
- FINTEGER nrows_B = M * K;
349
- FINTEGER ncols_B = M * K;
350
-
351
- float alpha = 1.0f;
352
- float beta = 0.0f;
353
- sgemm_("Not Transposed",
354
- "Not Transposed",
355
- &nrows_A, // nrows of op(A)
356
- &ncols_B, // ncols of op(B)
357
- &ncols_A, // ncols of op(A)
358
- &alpha,
359
- bx.data(),
360
- &nrows_A, // nrows of A
361
- bb.data(),
362
- &nrows_B, // nrows of B
363
- &beta,
364
- codebooks.data(),
365
- &nrows_A); // nrows of output
366
-
367
- lsq_timer.end("update_codebooks");
475
+ // compute C = (B'B)^(-1) @ BX
476
+ //
477
+ // NOTE: LAPACK use column major order
478
+ // out = alpha * op(A) * op(B) + beta * C
479
+ FINTEGER nrows_A = d;
480
+ FINTEGER ncols_A = M * K;
481
+
482
+ FINTEGER nrows_B = M * K;
483
+ FINTEGER ncols_B = M * K;
484
+
485
+ std::vector<double> d_codebooks(M * K * d);
486
+
487
+ double alpha = 1.0f;
488
+ double beta = 0.0f;
489
+ dgemm_("Not Transposed",
490
+ "Not Transposed",
491
+ &nrows_A, // nrows of op(A)
492
+ &ncols_B, // ncols of op(B)
493
+ &ncols_A, // ncols of op(A)
494
+ &alpha,
495
+ bx.data(),
496
+ &nrows_A, // nrows of A
497
+ bb.data(),
498
+ &nrows_B, // nrows of B
499
+ &beta,
500
+ d_codebooks.data(),
501
+ &nrows_A); // nrows of output
502
+
503
+ for (size_t i = 0; i < M * K * d; i++) {
504
+ codebooks[i] = (float)d_codebooks[i];
505
+ }
506
+ }
368
507
  }
369
508
 
370
509
  /** encode using iterative conditional mode
@@ -386,15 +525,23 @@ void LocalSearchQuantizer::update_codebooks(
386
525
  * These two terms can be precomputed and store in a look up table.
387
526
  */
388
527
  void LocalSearchQuantizer::icm_encode(
389
- const float* x,
390
528
  int32_t* codes,
529
+ const float* x,
391
530
  size_t n,
392
531
  size_t ils_iters,
393
532
  std::mt19937& gen) const {
394
- lsq_timer.start("icm_encode");
533
+ LSQTimerScope scope(&lsq_timer, "icm_encode");
534
+
535
+ auto factory = icm_encoder_factory;
536
+ std::unique_ptr<lsq::IcmEncoder> icm_encoder;
537
+ if (factory == nullptr) {
538
+ icm_encoder.reset(lsq::IcmEncoderFactory().get(this));
539
+ } else {
540
+ icm_encoder.reset(factory->get(this));
541
+ }
395
542
 
396
- std::vector<float> binaries(M * M * K * K); // [M, M, K, K]
397
- compute_binary_terms(binaries.data());
543
+ // precompute binary terms for all chunks
544
+ icm_encoder->set_binary_term();
398
545
 
399
546
  const size_t n_chunks = (n + chunk_size - 1) / chunk_size;
400
547
  for (size_t i = 0; i < n_chunks; i++) {
@@ -410,21 +557,20 @@ void LocalSearchQuantizer::icm_encode(
410
557
 
411
558
  const float* xi = x + i * chunk_size * d;
412
559
  int32_t* codesi = codes + i * chunk_size * M;
413
- icm_encode_partial(i, xi, codesi, ni, binaries.data(), ils_iters, gen);
560
+ icm_encoder->verbose = (verbose && i == 0);
561
+ icm_encoder->encode(codesi, xi, gen, ni, ils_iters);
414
562
  }
415
-
416
- lsq_timer.end("icm_encode");
417
563
  }
418
564
 
419
- void LocalSearchQuantizer::icm_encode_partial(
420
- size_t index,
421
- const float* x,
565
+ void LocalSearchQuantizer::icm_encode_impl(
422
566
  int32_t* codes,
423
- size_t n,
567
+ const float* x,
424
568
  const float* binaries,
569
+ std::mt19937& gen,
570
+ size_t n,
425
571
  size_t ils_iters,
426
- std::mt19937& gen) const {
427
- std::vector<float> unaries(n * M * K); // [n, M, K]
572
+ bool verbose) const {
573
+ std::vector<float> unaries(n * M * K); // [M, n, K]
428
574
  compute_unary_terms(x, unaries.data(), n);
429
575
 
430
576
  std::vector<int32_t> best_codes;
@@ -438,9 +584,7 @@ void LocalSearchQuantizer::icm_encode_partial(
438
584
  // add perturbation to codes
439
585
  perturb_codes(codes, n, gen);
440
586
 
441
- for (size_t iter2 = 0; iter2 < icm_iters; iter2++) {
442
- icm_encode_step(unaries.data(), binaries, codes, n);
443
- }
587
+ icm_encode_step(codes, unaries.data(), binaries, n, icm_iters);
444
588
 
445
589
  std::vector<float> icm_objs(n, 0.0f);
446
590
  evaluate(codes, x, n, icm_objs.data());
@@ -463,7 +607,7 @@ void LocalSearchQuantizer::icm_encode_partial(
463
607
 
464
608
  memcpy(codes, best_codes.data(), sizeof(int32_t) * n * M);
465
609
 
466
- if (verbose && index == 0) {
610
+ if (verbose) {
467
611
  printf("\tils_iter %zd: obj = %lf, n_betters/n = %zd/%zd\n",
468
612
  iter1,
469
613
  mean_obj,
@@ -474,61 +618,67 @@ void LocalSearchQuantizer::icm_encode_partial(
474
618
  }
475
619
 
476
620
  void LocalSearchQuantizer::icm_encode_step(
621
+ int32_t* codes,
477
622
  const float* unaries,
478
623
  const float* binaries,
479
- int32_t* codes,
480
- size_t n) const {
481
- // condition on the m-th subcode
482
- for (size_t m = 0; m < M; m++) {
483
- std::vector<float> objs(n * K);
484
- #pragma omp parallel for
485
- for (int64_t i = 0; i < n; i++) {
486
- auto u = unaries + i * (M * K) + m * K;
487
- memcpy(objs.data() + i * K, u, sizeof(float) * K);
488
- }
624
+ size_t n,
625
+ size_t n_iters) const {
626
+ FAISS_THROW_IF_NOT(M != 0 && K != 0);
627
+ FAISS_THROW_IF_NOT(binaries != nullptr);
489
628
 
490
- // compute objective function by adding unary
491
- // and binary terms together
492
- for (size_t other_m = 0; other_m < M; other_m++) {
493
- if (other_m == m) {
494
- continue;
629
+ for (size_t iter = 0; iter < n_iters; iter++) {
630
+ // condition on the m-th subcode
631
+ for (size_t m = 0; m < M; m++) {
632
+ std::vector<float> objs(n * K);
633
+ #pragma omp parallel for
634
+ for (int64_t i = 0; i < n; i++) {
635
+ auto u = unaries + m * n * K + i * K;
636
+ memcpy(objs.data() + i * K, u, sizeof(float) * K);
495
637
  }
496
638
 
639
+ // compute objective function by adding unary
640
+ // and binary terms together
641
+ for (size_t other_m = 0; other_m < M; other_m++) {
642
+ if (other_m == m) {
643
+ continue;
644
+ }
645
+
497
646
  #pragma omp parallel for
498
- for (int64_t i = 0; i < n; i++) {
499
- for (int32_t code = 0; code < K; code++) {
500
- int32_t code2 = codes[i * M + other_m];
501
- size_t binary_idx =
502
- m * M * K * K + other_m * K * K + code * K + code2;
503
- // binaries[m, other_m, code, code2]
504
- objs[i * K + code] += binaries[binary_idx];
647
+ for (int64_t i = 0; i < n; i++) {
648
+ for (int32_t code = 0; code < K; code++) {
649
+ int32_t code2 = codes[i * M + other_m];
650
+ size_t binary_idx = m * M * K * K + other_m * K * K +
651
+ code * K + code2;
652
+ // binaries[m, other_m, code, code2]
653
+ objs[i * K + code] += binaries[binary_idx];
654
+ }
505
655
  }
506
656
  }
507
- }
508
657
 
509
- // find the optimal value of the m-th subcode
658
+ // find the optimal value of the m-th subcode
510
659
  #pragma omp parallel for
511
- for (int64_t i = 0; i < n; i++) {
512
- float best_obj = HUGE_VALF;
513
- int32_t best_code = 0;
514
- for (size_t code = 0; code < K; code++) {
515
- float obj = objs[i * K + code];
516
- if (obj < best_obj) {
517
- best_obj = obj;
518
- best_code = code;
660
+ for (int64_t i = 0; i < n; i++) {
661
+ float best_obj = HUGE_VALF;
662
+ int32_t best_code = 0;
663
+ for (size_t code = 0; code < K; code++) {
664
+ float obj = objs[i * K + code];
665
+ if (obj < best_obj) {
666
+ best_obj = obj;
667
+ best_code = code;
668
+ }
519
669
  }
670
+ codes[i * M + m] = best_code;
520
671
  }
521
- codes[i * M + m] = best_code;
522
- }
523
672
 
524
- } // loop M
673
+ } // loop M
674
+ }
525
675
  }
526
676
 
527
677
  void LocalSearchQuantizer::perturb_codes(
528
678
  int32_t* codes,
529
679
  size_t n,
530
680
  std::mt19937& gen) const {
531
- lsq_timer.start("perturb_codes");
681
+ LSQTimerScope scope(&lsq_timer, "perturb_codes");
532
682
 
533
683
  std::uniform_int_distribution<size_t> m_distrib(0, M - 1);
534
684
  std::uniform_int_distribution<int32_t> k_distrib(0, K - 1);
@@ -539,12 +689,10 @@ void LocalSearchQuantizer::perturb_codes(
539
689
  codes[i * M + m] = k_distrib(gen);
540
690
  }
541
691
  }
542
-
543
- lsq_timer.end("perturb_codes");
544
692
  }
545
693
 
546
694
  void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
547
- lsq_timer.start("compute_binary_terms");
695
+ LSQTimerScope scope(&lsq_timer, "compute_binary_terms");
548
696
 
549
697
  #pragma omp parallel for
550
698
  for (int64_t m12 = 0; m12 < M * M; m12++) {
@@ -562,52 +710,53 @@ void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
562
710
  }
563
711
  }
564
712
  }
565
-
566
- lsq_timer.end("compute_binary_terms");
567
713
  }
568
714
 
569
715
  void LocalSearchQuantizer::compute_unary_terms(
570
716
  const float* x,
571
- float* unaries,
717
+ float* unaries, // [M, n, K]
572
718
  size_t n) const {
573
- lsq_timer.start("compute_unary_terms");
719
+ LSQTimerScope scope(&lsq_timer, "compute_unary_terms");
574
720
 
575
- // compute x * codebooks^T
721
+ // compute x * codebook^T for each codebook
576
722
  //
577
723
  // NOTE: LAPACK use column major order
578
724
  // out = alpha * op(A) * op(B) + beta * C
579
- FINTEGER nrows_A = M * K;
580
- FINTEGER ncols_A = d;
581
-
582
- FINTEGER nrows_B = d;
583
- FINTEGER ncols_B = n;
584
-
585
- float alpha = -2.0f;
586
- float beta = 0.0f;
587
- sgemm_("Transposed",
588
- "Not Transposed",
589
- &nrows_A, // nrows of op(A)
590
- &ncols_B, // ncols of op(B)
591
- &ncols_A, // ncols of op(A)
592
- &alpha,
593
- codebooks.data(),
594
- &ncols_A, // nrows of A
595
- x,
596
- &nrows_B, // nrows of B
597
- &beta,
598
- unaries,
599
- &nrows_A); // nrows of output
725
+
726
+ for (size_t m = 0; m < M; m++) {
727
+ FINTEGER nrows_A = K;
728
+ FINTEGER ncols_A = d;
729
+
730
+ FINTEGER nrows_B = d;
731
+ FINTEGER ncols_B = n;
732
+
733
+ float alpha = -2.0f;
734
+ float beta = 0.0f;
735
+ sgemm_("Transposed",
736
+ "Not Transposed",
737
+ &nrows_A, // nrows of op(A)
738
+ &ncols_B, // ncols of op(B)
739
+ &ncols_A, // ncols of op(A)
740
+ &alpha,
741
+ codebooks.data() + m * K * d,
742
+ &ncols_A, // nrows of A
743
+ x,
744
+ &nrows_B, // nrows of B
745
+ &beta,
746
+ unaries + m * n * K,
747
+ &nrows_A); // nrows of output
748
+ }
600
749
 
601
750
  std::vector<float> norms(M * K);
602
751
  fvec_norms_L2sqr(norms.data(), codebooks.data(), d, M * K);
603
752
 
604
753
  #pragma omp parallel for
605
754
  for (int64_t i = 0; i < n; i++) {
606
- float* u = unaries + i * (M * K);
607
- fvec_add(M * K, u, norms.data(), u);
755
+ for (size_t m = 0; m < M; m++) {
756
+ float* u = unaries + m * n * K + i * K;
757
+ fvec_add(K, u, norms.data() + m * K, u);
758
+ }
608
759
  }
609
-
610
- lsq_timer.end("compute_unary_terms");
611
760
  }
612
761
 
613
762
  float LocalSearchQuantizer::evaluate(
@@ -615,7 +764,7 @@ float LocalSearchQuantizer::evaluate(
615
764
  const float* x,
616
765
  size_t n,
617
766
  float* objs) const {
618
- lsq_timer.start("evaluate");
767
+ LSQTimerScope scope(&lsq_timer, "evaluate");
619
768
 
620
769
  // decode
621
770
  std::vector<float> decoded_x(n * d, 0.0f);
@@ -631,7 +780,7 @@ float LocalSearchQuantizer::evaluate(
631
780
  fvec_add(d, decoded_i, c, decoded_i);
632
781
  }
633
782
 
634
- float err = fvec_L2sqr(x + i * d, decoded_i, d);
783
+ float err = faiss::fvec_L2sqr(x + i * d, decoded_i, d);
635
784
  obj += err;
636
785
 
637
786
  if (objs) {
@@ -639,34 +788,68 @@ float LocalSearchQuantizer::evaluate(
639
788
  }
640
789
  }
641
790
 
642
- lsq_timer.end("evaluate");
643
-
644
791
  obj = obj / n;
645
792
  return obj;
646
793
  }
647
794
 
648
- double LSQTimer::get(const std::string& name) {
649
- return duration[name];
795
+ namespace lsq {
796
+
797
+ IcmEncoder::IcmEncoder(const LocalSearchQuantizer* lsq)
798
+ : verbose(false), lsq(lsq) {}
799
+
800
+ void IcmEncoder::set_binary_term() {
801
+ auto M = lsq->M;
802
+ auto K = lsq->K;
803
+ binaries.resize(M * M * K * K);
804
+ lsq->compute_binary_terms(binaries.data());
650
805
  }
651
806
 
652
- void LSQTimer::start(const std::string& name) {
653
- FAISS_THROW_IF_NOT_MSG(!started[name], " timer is already running");
654
- started[name] = true;
655
- t0[name] = getmillisecs();
807
+ void IcmEncoder::encode(
808
+ int32_t* codes,
809
+ const float* x,
810
+ std::mt19937& gen,
811
+ size_t n,
812
+ size_t ils_iters) const {
813
+ lsq->icm_encode_impl(codes, x, binaries.data(), gen, n, ils_iters, verbose);
656
814
  }
657
815
 
658
- void LSQTimer::end(const std::string& name) {
659
- FAISS_THROW_IF_NOT_MSG(started[name], " timer is not running");
660
- double t1 = getmillisecs();
661
- double sec = (t1 - t0[name]) / 1000;
662
- duration[name] += sec;
663
- started[name] = false;
816
+ double LSQTimer::get(const std::string& name) {
817
+ if (t.count(name) == 0) {
818
+ return 0.0;
819
+ } else {
820
+ return t[name];
821
+ }
822
+ }
823
+
824
+ void LSQTimer::add(const std::string& name, double delta) {
825
+ if (t.count(name) == 0) {
826
+ t[name] = delta;
827
+ } else {
828
+ t[name] += delta;
829
+ }
664
830
  }
665
831
 
666
832
  void LSQTimer::reset() {
667
- duration.clear();
668
- t0.clear();
669
- started.clear();
833
+ t.clear();
834
+ }
835
+
836
+ LSQTimerScope::LSQTimerScope(LSQTimer* timer, std::string name)
837
+ : timer(timer), name(name), finished(false) {
838
+ t0 = getmillisecs();
670
839
  }
671
840
 
841
+ void LSQTimerScope::finish() {
842
+ if (!finished) {
843
+ auto delta = getmillisecs() - t0;
844
+ timer->add(name, delta);
845
+ finished = true;
846
+ }
847
+ }
848
+
849
+ LSQTimerScope::~LSQTimerScope() {
850
+ finish();
851
+ }
852
+
853
+ } // namespace lsq
854
+
672
855
  } // namespace faiss