faiss 0.2.3 → 0.2.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (63) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  6. data/vendor/faiss/faiss/Clustering.h +14 -0
  7. data/vendor/faiss/faiss/Index.h +1 -1
  8. data/vendor/faiss/faiss/Index2Layer.cpp +19 -92
  9. data/vendor/faiss/faiss/Index2Layer.h +2 -16
  10. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  11. data/vendor/faiss/faiss/{IndexResidual.h → IndexAdditiveQuantizer.h} +101 -58
  12. data/vendor/faiss/faiss/IndexFlat.cpp +22 -52
  13. data/vendor/faiss/faiss/IndexFlat.h +9 -15
  14. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  15. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  16. data/vendor/faiss/faiss/IndexIVF.cpp +79 -7
  17. data/vendor/faiss/faiss/IndexIVF.h +25 -7
  18. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  19. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  20. data/vendor/faiss/faiss/IndexIVFFlat.cpp +9 -12
  21. data/vendor/faiss/faiss/IndexIVFPQ.cpp +5 -4
  22. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  23. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +60 -39
  24. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +21 -6
  25. data/vendor/faiss/faiss/IndexLSH.cpp +4 -30
  26. data/vendor/faiss/faiss/IndexLSH.h +2 -15
  27. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -2
  28. data/vendor/faiss/faiss/IndexNSG.cpp +0 -2
  29. data/vendor/faiss/faiss/IndexPQ.cpp +2 -51
  30. data/vendor/faiss/faiss/IndexPQ.h +2 -17
  31. data/vendor/faiss/faiss/IndexRefine.cpp +28 -0
  32. data/vendor/faiss/faiss/IndexRefine.h +10 -0
  33. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -28
  34. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -16
  35. data/vendor/faiss/faiss/VectorTransform.cpp +2 -1
  36. data/vendor/faiss/faiss/VectorTransform.h +3 -0
  37. data/vendor/faiss/faiss/clone_index.cpp +3 -2
  38. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -2
  39. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  40. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +257 -24
  41. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +69 -9
  42. data/vendor/faiss/faiss/impl/HNSW.cpp +10 -5
  43. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +393 -210
  44. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +100 -28
  45. data/vendor/faiss/faiss/impl/NSG.cpp +0 -3
  46. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  47. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +357 -47
  48. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +65 -7
  49. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +12 -19
  50. data/vendor/faiss/faiss/impl/index_read.cpp +102 -19
  51. data/vendor/faiss/faiss/impl/index_write.cpp +66 -16
  52. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  53. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  54. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  55. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  56. data/vendor/faiss/faiss/index_factory.cpp +585 -414
  57. data/vendor/faiss/faiss/index_factory.h +3 -0
  58. data/vendor/faiss/faiss/utils/distances.cpp +4 -2
  59. data/vendor/faiss/faiss/utils/distances.h +36 -3
  60. data/vendor/faiss/faiss/utils/distances_simd.cpp +50 -0
  61. data/vendor/faiss/faiss/utils/utils.h +1 -1
  62. metadata +12 -5
  63. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
@@ -25,9 +25,11 @@
25
25
  #include <faiss/invlists/InvertedListsIOHook.h>
26
26
 
27
27
  #include <faiss/Index2Layer.h>
28
+ #include <faiss/IndexAdditiveQuantizer.h>
28
29
  #include <faiss/IndexFlat.h>
29
30
  #include <faiss/IndexHNSW.h>
30
31
  #include <faiss/IndexIVF.h>
32
+ #include <faiss/IndexIVFAdditiveQuantizer.h>
31
33
  #include <faiss/IndexIVFFlat.h>
32
34
  #include <faiss/IndexIVFPQ.h>
33
35
  #include <faiss/IndexIVFPQFastScan.h>
@@ -40,7 +42,6 @@
40
42
  #include <faiss/IndexPQFastScan.h>
41
43
  #include <faiss/IndexPreTransform.h>
42
44
  #include <faiss/IndexRefine.h>
43
- #include <faiss/IndexResidual.h>
44
45
  #include <faiss/IndexScalarQuantizer.h>
45
46
  #include <faiss/MetaIndexes.h>
46
47
  #include <faiss/VectorTransform.h>
@@ -77,16 +78,22 @@ VectorTransform* read_VectorTransform(IOReader* f) {
77
78
  VectorTransform* vt = nullptr;
78
79
 
79
80
  if (h == fourcc("rrot") || h == fourcc("PCAm") || h == fourcc("LTra") ||
80
- h == fourcc("PcAm") || h == fourcc("Viqm")) {
81
+ h == fourcc("PcAm") || h == fourcc("Viqm") || h == fourcc("Pcam")) {
81
82
  LinearTransform* lt = nullptr;
82
83
  if (h == fourcc("rrot")) {
83
84
  lt = new RandomRotationMatrix();
84
- } else if (h == fourcc("PCAm") || h == fourcc("PcAm")) {
85
+ } else if (
86
+ h == fourcc("PCAm") || h == fourcc("PcAm") ||
87
+ h == fourcc("Pcam")) {
85
88
  PCAMatrix* pca = new PCAMatrix();
86
89
  READ1(pca->eigen_power);
90
+ if (h == fourcc("Pcam")) {
91
+ READ1(pca->epsilon);
92
+ }
87
93
  READ1(pca->random_rotation);
88
- if (h == fourcc("PcAm"))
94
+ if (h != fourcc("PCAm")) {
89
95
  READ1(pca->balanced_bins);
96
+ }
90
97
  READVECTOR(pca->mean);
91
98
  READVECTOR(pca->eigenvalues);
92
99
  READVECTOR(pca->PCAMat);
@@ -139,9 +146,10 @@ VectorTransform* read_VectorTransform(IOReader* f) {
139
146
  vt = itqt;
140
147
  } else {
141
148
  FAISS_THROW_FMT(
142
- "fourcc %ud (\"%s\") not recognized",
149
+ "fourcc %ud (\"%s\") not recognized in %s",
143
150
  h,
144
- fourcc_inv_printable(h).c_str());
151
+ fourcc_inv_printable(h).c_str(),
152
+ f->name.c_str());
145
153
  }
146
154
  READ1(vt->d_in);
147
155
  READ1(vt->d_out);
@@ -239,15 +247,58 @@ static void read_ProductQuantizer(ProductQuantizer* pq, IOReader* f) {
239
247
  READVECTOR(pq->centroids);
240
248
  }
241
249
 
242
- static void read_ResidualQuantizer(ResidualQuantizer* rq, IOReader* f) {
250
+ static void read_ResidualQuantizer_old(ResidualQuantizer* rq, IOReader* f) {
243
251
  READ1(rq->d);
244
252
  READ1(rq->M);
245
253
  READVECTOR(rq->nbits);
246
- rq->set_derived_values();
247
254
  READ1(rq->is_trained);
248
255
  READ1(rq->train_type);
249
256
  READ1(rq->max_beam_size);
250
257
  READVECTOR(rq->codebooks);
258
+ READ1(rq->search_type);
259
+ READ1(rq->norm_min);
260
+ READ1(rq->norm_max);
261
+ rq->set_derived_values();
262
+ }
263
+
264
+ static void read_AdditiveQuantizer(AdditiveQuantizer* aq, IOReader* f) {
265
+ READ1(aq->d);
266
+ READ1(aq->M);
267
+ READVECTOR(aq->nbits);
268
+ READ1(aq->is_trained);
269
+ READVECTOR(aq->codebooks);
270
+ READ1(aq->search_type);
271
+ READ1(aq->norm_min);
272
+ READ1(aq->norm_max);
273
+ if (aq->search_type == AdditiveQuantizer::ST_norm_cqint8 ||
274
+ aq->search_type == AdditiveQuantizer::ST_norm_cqint4) {
275
+ READXBVECTOR(aq->qnorm.codes);
276
+ }
277
+ aq->set_derived_values();
278
+ }
279
+
280
+ static void read_ResidualQuantizer(ResidualQuantizer* rq, IOReader* f) {
281
+ read_AdditiveQuantizer(rq, f);
282
+ READ1(rq->train_type);
283
+ READ1(rq->max_beam_size);
284
+ if (!(rq->train_type & ResidualQuantizer::Skip_codebook_tables)) {
285
+ rq->compute_codebook_tables();
286
+ }
287
+ }
288
+
289
+ static void read_LocalSearchQuantizer(LocalSearchQuantizer* lsq, IOReader* f) {
290
+ read_AdditiveQuantizer(lsq, f);
291
+ READ1(lsq->K);
292
+ READ1(lsq->train_iters);
293
+ READ1(lsq->encode_ils_iters);
294
+ READ1(lsq->train_ils_iters);
295
+ READ1(lsq->icm_iters);
296
+ READ1(lsq->p);
297
+ READ1(lsq->lambd);
298
+ READ1(lsq->chunk_size);
299
+ READ1(lsq->random_seed);
300
+ READ1(lsq->nperts);
301
+ READ1(lsq->update_codebooks_with_double);
251
302
  }
252
303
 
253
304
  static void read_ScalarQuantizer(ScalarQuantizer* ivsc, IOReader* f) {
@@ -422,8 +473,10 @@ Index* read_index(IOReader* f, int io_flags) {
422
473
  idxf = new IndexFlat();
423
474
  }
424
475
  read_index_header(idxf, f);
425
- READVECTOR(idxf->xb);
426
- FAISS_THROW_IF_NOT(idxf->xb.size() == idxf->ntotal * idxf->d);
476
+ idxf->code_size = idxf->d * sizeof(float);
477
+ READXBVECTOR(idxf->codes);
478
+ FAISS_THROW_IF_NOT(
479
+ idxf->codes.size() == idxf->ntotal * idxf->code_size);
427
480
  // leak!
428
481
  idx = idxf;
429
482
  } else if (h == fourcc("IxHE") || h == fourcc("IxHe")) {
@@ -433,7 +486,9 @@ Index* read_index(IOReader* f, int io_flags) {
433
486
  READ1(idxl->rotate_data);
434
487
  READ1(idxl->train_thresholds);
435
488
  READVECTOR(idxl->thresholds);
436
- READ1(idxl->bytes_per_vec);
489
+ int code_size_i;
490
+ READ1(code_size_i);
491
+ idxl->code_size = code_size_i;
437
492
  if (h == fourcc("IxHE")) {
438
493
  FAISS_THROW_IF_NOT_FMT(
439
494
  idxl->nbits % 64 == 0,
@@ -441,7 +496,7 @@ Index* read_index(IOReader* f, int io_flags) {
441
496
  "nbits multiple of 64 (got %d)",
442
497
  (int)idxl->nbits);
443
498
  // leak
444
- idxl->bytes_per_vec *= 8;
499
+ idxl->code_size *= 8;
445
500
  }
446
501
  {
447
502
  RandomRotationMatrix* rrot = dynamic_cast<RandomRotationMatrix*>(
@@ -454,7 +509,7 @@ Index* read_index(IOReader* f, int io_flags) {
454
509
  FAISS_THROW_IF_NOT(
455
510
  idxl->rrot.d_in == idxl->d && idxl->rrot.d_out == idxl->nbits);
456
511
  FAISS_THROW_IF_NOT(
457
- idxl->codes.size() == idxl->ntotal * idxl->bytes_per_vec);
512
+ idxl->codes.size() == idxl->ntotal * idxl->code_size);
458
513
  idx = idxl;
459
514
  } else if (
460
515
  h == fourcc("IxPQ") || h == fourcc("IxPo") || h == fourcc("IxPq")) {
@@ -462,6 +517,7 @@ Index* read_index(IOReader* f, int io_flags) {
462
517
  IndexPQ* idxp = new IndexPQ();
463
518
  read_index_header(idxp, f);
464
519
  read_ProductQuantizer(&idxp->pq, f);
520
+ idxp->code_size = idxp->pq.code_size;
465
521
  READVECTOR(idxp->codes);
466
522
  if (h == fourcc("IxPo") || h == fourcc("IxPq")) {
467
523
  READ1(idxp->search_type);
@@ -475,13 +531,21 @@ Index* read_index(IOReader* f, int io_flags) {
475
531
  idxp->metric_type = METRIC_L2;
476
532
  }
477
533
  idx = idxp;
478
- } else if (h == fourcc("IxRQ")) {
479
- IndexResidual* idxr = new IndexResidual();
534
+ } else if (h == fourcc("IxRQ") || h == fourcc("IxRq")) {
535
+ IndexResidualQuantizer* idxr = new IndexResidualQuantizer();
480
536
  read_index_header(idxr, f);
481
- read_ResidualQuantizer(&idxr->rq, f);
482
- READ1(idxr->search_type);
483
- READ1(idxr->norm_min);
484
- READ1(idxr->norm_max);
537
+ if (h == fourcc("IxRQ")) {
538
+ read_ResidualQuantizer_old(&idxr->rq, f);
539
+ } else {
540
+ read_ResidualQuantizer(&idxr->rq, f);
541
+ }
542
+ READ1(idxr->code_size);
543
+ READVECTOR(idxr->codes);
544
+ idx = idxr;
545
+ } else if (h == fourcc("IxLS")) {
546
+ auto idxr = new IndexLocalSearchQuantizer();
547
+ read_index_header(idxr, f);
548
+ read_LocalSearchQuantizer(&idxr->lsq, f);
485
549
  READ1(idxr->code_size);
486
550
  READVECTOR(idxr->codes);
487
551
  idx = idxr;
@@ -571,6 +635,25 @@ Index* read_index(IOReader* f, int io_flags) {
571
635
  }
572
636
  read_InvertedLists(ivsc, f, io_flags);
573
637
  idx = ivsc;
638
+ } else if (h == fourcc("IwLS") || h == fourcc("IwRQ")) {
639
+ bool is_LSQ = h == fourcc("IwLS");
640
+ IndexIVFAdditiveQuantizer* iva;
641
+ if (is_LSQ) {
642
+ iva = new IndexIVFLocalSearchQuantizer();
643
+ } else {
644
+ iva = new IndexIVFResidualQuantizer();
645
+ }
646
+ read_ivf_header(iva, f);
647
+ READ1(iva->code_size);
648
+ if (is_LSQ) {
649
+ read_LocalSearchQuantizer((LocalSearchQuantizer*)iva->aq, f);
650
+ } else {
651
+ read_ResidualQuantizer((ResidualQuantizer*)iva->aq, f);
652
+ }
653
+ READ1(iva->by_residual);
654
+ READ1(iva->use_precomputed_table);
655
+ read_InvertedLists(iva, f, io_flags);
656
+ idx = iva;
574
657
  } else if (h == fourcc("IwSh")) {
575
658
  IndexIVFSpectralHash* ivsp = new IndexIVFSpectralHash();
576
659
  read_ivf_header(ivsp, f);
@@ -26,9 +26,11 @@
26
26
  #include <faiss/utils/hamming.h>
27
27
 
28
28
  #include <faiss/Index2Layer.h>
29
+ #include <faiss/IndexAdditiveQuantizer.h>
29
30
  #include <faiss/IndexFlat.h>
30
31
  #include <faiss/IndexHNSW.h>
31
32
  #include <faiss/IndexIVF.h>
33
+ #include <faiss/IndexIVFAdditiveQuantizer.h>
32
34
  #include <faiss/IndexIVFFlat.h>
33
35
  #include <faiss/IndexIVFPQ.h>
34
36
  #include <faiss/IndexIVFPQFastScan.h>
@@ -41,7 +43,6 @@
41
43
  #include <faiss/IndexPQFastScan.h>
42
44
  #include <faiss/IndexPreTransform.h>
43
45
  #include <faiss/IndexRefine.h>
44
- #include <faiss/IndexResidual.h>
45
46
  #include <faiss/IndexScalarQuantizer.h>
46
47
  #include <faiss/MetaIndexes.h>
47
48
  #include <faiss/VectorTransform.h>
@@ -95,9 +96,10 @@ void write_VectorTransform(const VectorTransform* vt, IOWriter* f) {
95
96
  uint32_t h = fourcc("rrot");
96
97
  WRITE1(h);
97
98
  } else if (const PCAMatrix* pca = dynamic_cast<const PCAMatrix*>(lt)) {
98
- uint32_t h = fourcc("PcAm");
99
+ uint32_t h = fourcc("Pcam");
99
100
  WRITE1(h);
100
101
  WRITE1(pca->eigen_power);
102
+ WRITE1(pca->epsilon);
101
103
  WRITE1(pca->random_rotation);
102
104
  WRITE1(pca->balanced_bins);
103
105
  WRITEVECTOR(pca->mean);
@@ -158,14 +160,42 @@ void write_ProductQuantizer(const ProductQuantizer* pq, IOWriter* f) {
158
160
  WRITEVECTOR(pq->centroids);
159
161
  }
160
162
 
161
- void write_ResidualQuantizer(const ResidualQuantizer* rq, IOWriter* f) {
162
- WRITE1(rq->d);
163
- WRITE1(rq->M);
164
- WRITEVECTOR(rq->nbits);
165
- WRITE1(rq->is_trained);
163
+ static void write_AdditiveQuantizer(const AdditiveQuantizer* aq, IOWriter* f) {
164
+ WRITE1(aq->d);
165
+ WRITE1(aq->M);
166
+ WRITEVECTOR(aq->nbits);
167
+ WRITE1(aq->is_trained);
168
+ WRITEVECTOR(aq->codebooks);
169
+ WRITE1(aq->search_type);
170
+ WRITE1(aq->norm_min);
171
+ WRITE1(aq->norm_max);
172
+ if (aq->search_type == AdditiveQuantizer::ST_norm_cqint8 ||
173
+ aq->search_type == AdditiveQuantizer::ST_norm_cqint4) {
174
+ WRITEXBVECTOR(aq->qnorm.codes);
175
+ }
176
+ }
177
+
178
+ static void write_ResidualQuantizer(const ResidualQuantizer* rq, IOWriter* f) {
179
+ write_AdditiveQuantizer(rq, f);
166
180
  WRITE1(rq->train_type);
167
181
  WRITE1(rq->max_beam_size);
168
- WRITEVECTOR(rq->codebooks);
182
+ }
183
+
184
+ static void write_LocalSearchQuantizer(
185
+ const LocalSearchQuantizer* lsq,
186
+ IOWriter* f) {
187
+ write_AdditiveQuantizer(lsq, f);
188
+ WRITE1(lsq->K);
189
+ WRITE1(lsq->train_iters);
190
+ WRITE1(lsq->encode_ils_iters);
191
+ WRITE1(lsq->train_ils_iters);
192
+ WRITE1(lsq->icm_iters);
193
+ WRITE1(lsq->p);
194
+ WRITE1(lsq->lambd);
195
+ WRITE1(lsq->chunk_size);
196
+ WRITE1(lsq->random_seed);
197
+ WRITE1(lsq->nperts);
198
+ WRITE1(lsq->update_codebooks_with_double);
169
199
  }
170
200
 
171
201
  static void write_ScalarQuantizer(const ScalarQuantizer* ivsc, IOWriter* f) {
@@ -315,7 +345,7 @@ void write_index(const Index* idx, IOWriter* f) {
315
345
  : "IxFl");
316
346
  WRITE1(h);
317
347
  write_index_header(idx, f);
318
- WRITEVECTOR(idxf->xb);
348
+ WRITEXBVECTOR(idxf->codes);
319
349
  } else if (const IndexLSH* idxl = dynamic_cast<const IndexLSH*>(idx)) {
320
350
  uint32_t h = fourcc("IxHe");
321
351
  WRITE1(h);
@@ -324,7 +354,8 @@ void write_index(const Index* idx, IOWriter* f) {
324
354
  WRITE1(idxl->rotate_data);
325
355
  WRITE1(idxl->train_thresholds);
326
356
  WRITEVECTOR(idxl->thresholds);
327
- WRITE1(idxl->bytes_per_vec);
357
+ int code_size_i = idxl->code_size;
358
+ WRITE1(code_size_i);
328
359
  write_VectorTransform(&idxl->rrot, f);
329
360
  WRITEVECTOR(idxl->codes);
330
361
  } else if (const IndexPQ* idxp = dynamic_cast<const IndexPQ*>(idx)) {
@@ -338,15 +369,20 @@ void write_index(const Index* idx, IOWriter* f) {
338
369
  WRITE1(idxp->encode_signs);
339
370
  WRITE1(idxp->polysemous_ht);
340
371
  } else if (
341
- const IndexResidual* idxr =
342
- dynamic_cast<const IndexResidual*>(idx)) {
343
- uint32_t h = fourcc("IxRQ");
372
+ const IndexResidualQuantizer* idxr =
373
+ dynamic_cast<const IndexResidualQuantizer*>(idx)) {
374
+ uint32_t h = fourcc("IxRq");
344
375
  WRITE1(h);
345
376
  write_index_header(idx, f);
346
377
  write_ResidualQuantizer(&idxr->rq, f);
347
- WRITE1(idxr->search_type);
348
- WRITE1(idxr->norm_min);
349
- WRITE1(idxr->norm_max);
378
+ WRITE1(idxr->code_size);
379
+ WRITEVECTOR(idxr->codes);
380
+ } else if (
381
+ auto* idxr = dynamic_cast<const IndexLocalSearchQuantizer*>(idx)) {
382
+ uint32_t h = fourcc("IxLS");
383
+ WRITE1(h);
384
+ write_index_header(idx, f);
385
+ write_LocalSearchQuantizer(&idxr->lsq, f);
350
386
  WRITE1(idxr->code_size);
351
387
  WRITEVECTOR(idxr->codes);
352
388
  } else if (
@@ -421,6 +457,20 @@ void write_index(const Index* idx, IOWriter* f) {
421
457
  WRITE1(ivsc->code_size);
422
458
  WRITE1(ivsc->by_residual);
423
459
  write_InvertedLists(ivsc->invlists, f);
460
+ } else if (auto iva = dynamic_cast<const IndexIVFAdditiveQuantizer*>(idx)) {
461
+ bool is_LSQ = dynamic_cast<const IndexIVFLocalSearchQuantizer*>(iva);
462
+ uint32_t h = fourcc(is_LSQ ? "IwLS" : "IwRQ");
463
+ WRITE1(h);
464
+ write_ivf_header(iva, f);
465
+ WRITE1(iva->code_size);
466
+ if (is_LSQ) {
467
+ write_LocalSearchQuantizer((LocalSearchQuantizer*)iva->aq, f);
468
+ } else {
469
+ write_ResidualQuantizer((ResidualQuantizer*)iva->aq, f);
470
+ }
471
+ WRITE1(iva->by_residual);
472
+ WRITE1(iva->use_precomputed_table);
473
+ write_InvertedLists(iva->invlists, f);
424
474
  } else if (
425
475
  const IndexIVFSpectralHash* ivsp =
426
476
  dynamic_cast<const IndexIVFSpectralHash*>(idx)) {
@@ -240,7 +240,7 @@ uint32_t fourcc(const std::string& sx) {
240
240
 
241
241
  void fourcc_inv(uint32_t x, char str[5]) {
242
242
  *(uint32_t*)str = x;
243
- str[5] = 0;
243
+ str[4] = 0;
244
244
  }
245
245
 
246
246
  std::string fourcc_inv(uint32_t x) {
@@ -66,3 +66,23 @@
66
66
  WRITEANDCHECK(&size, 1); \
67
67
  WRITEANDCHECK((vec).data(), size); \
68
68
  }
69
+
70
+ // read/write xb vector for backwards compatibility of IndexFlat
71
+
72
// Write a byte vector in the legacy IndexFlat "xb" layout: first the number
// of 4-byte units, then the raw bytes. The byte count must be a multiple of 4.
#define WRITEXBVECTOR(vec)                                 \
    {                                                      \
        FAISS_THROW_IF_NOT((vec).size() % 4 == 0);         \
        size_t xb_nunits = (vec).size() / 4;               \
        WRITEANDCHECK(&xb_nunits, 1);                      \
        WRITEANDCHECK((vec).data(), xb_nunits * 4);        \
    }
79
+
80
// Read a byte vector stored in the legacy IndexFlat "xb" layout: the on-disk
// length counts 4-byte units, so it is scaled by 4 to obtain the byte count.
// size_t is unsigned, so only the upper bound (a sanity limit against corrupt
// headers) needs checking; a `size >= 0` test would always be true.
#define READXBVECTOR(vec)                                      \
    {                                                          \
        size_t size;                                           \
        READANDCHECK(&size, 1);                                \
        FAISS_THROW_IF_NOT(size < (uint64_t{1} << 40));        \
        size *= 4;                                             \
        (vec).resize(size);                                    \
        READANDCHECK((vec).data(), size);                      \
    }
@@ -0,0 +1,301 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <algorithm>
9
+ #include <cstdint>
10
+ #include <cstring>
11
+ #include <functional>
12
+ #include <numeric>
13
+ #include <string>
14
+ #include <unordered_map>
15
+ #include <vector>
16
+
17
+ #include <faiss/Index.h>
18
+ #include <faiss/impl/FaissAssert.h>
19
+ #include <faiss/impl/kmeans1d.h>
20
+
21
+ namespace faiss {
22
+
23
+ using idx_t = Index::idx_t;
24
+ using LookUpFunc = std::function<float(idx_t, idx_t)>;
25
+
26
+ void reduce(
27
+ const std::vector<idx_t>& rows,
28
+ const std::vector<idx_t>& input_cols,
29
+ const LookUpFunc& lookup,
30
+ std::vector<idx_t>& output_cols) {
31
+ for (idx_t col : input_cols) {
32
+ while (!output_cols.empty()) {
33
+ idx_t row = rows[output_cols.size() - 1];
34
+ float a = lookup(row, col);
35
+ float b = lookup(row, output_cols.back());
36
+ if (a >= b) { // defeated
37
+ break;
38
+ }
39
+ output_cols.pop_back();
40
+ }
41
+ if (output_cols.size() < rows.size()) {
42
+ output_cols.push_back(col);
43
+ }
44
+ }
45
+ }
46
+
47
+ void interpolate(
48
+ const std::vector<idx_t>& rows,
49
+ const std::vector<idx_t>& cols,
50
+ const LookUpFunc& lookup,
51
+ idx_t* argmins) {
52
+ std::unordered_map<idx_t, idx_t> idx_to_col;
53
+ for (idx_t idx = 0; idx < cols.size(); ++idx) {
54
+ idx_to_col[cols[idx]] = idx;
55
+ }
56
+
57
+ idx_t start = 0;
58
+ for (idx_t r = 0; r < rows.size(); r += 2) {
59
+ idx_t row = rows[r];
60
+ idx_t end = cols.size() - 1;
61
+ if (r < rows.size() - 1) {
62
+ idx_t idx = argmins[rows[r + 1]];
63
+ end = idx_to_col[idx];
64
+ }
65
+ idx_t argmin = cols[start];
66
+ float min = lookup(row, argmin);
67
+ for (idx_t c = start + 1; c <= end; c++) {
68
+ float value = lookup(row, cols[c]);
69
+ if (value < min) {
70
+ argmin = cols[c];
71
+ min = value;
72
+ }
73
+ }
74
+ argmins[row] = argmin;
75
+ start = end;
76
+ }
77
+ }
78
+
79
+ /** SMAWK algo. Find the row minima of a monotone matrix.
80
+ *
81
+ * References:
82
+ * 1. http://web.cs.unlv.edu/larmore/Courses/CSC477/monge.pdf
83
+ * 2. https://gist.github.com/dstein64/8e94a6a25efc1335657e910ff525f405
84
+ * 3. https://github.com/dstein64/kmeans1d
85
+ */
86
+ void smawk_impl(
87
+ const std::vector<idx_t>& rows,
88
+ const std::vector<idx_t>& input_cols,
89
+ const LookUpFunc& lookup,
90
+ idx_t* argmins) {
91
+ if (rows.size() == 0) {
92
+ return;
93
+ }
94
+
95
+ /**********************************
96
+ * REDUCE
97
+ **********************************/
98
+ auto ptr = &input_cols;
99
+ std::vector<idx_t> survived_cols; // survived columns
100
+ if (rows.size() < input_cols.size()) {
101
+ reduce(rows, input_cols, lookup, survived_cols);
102
+ ptr = &survived_cols;
103
+ }
104
+ auto& cols = *ptr; // avoid memory copy
105
+
106
+ /**********************************
107
+ * INTERPOLATE
108
+ **********************************/
109
+
110
+ // call recursively on odd-indexed rows
111
+ std::vector<idx_t> odd_rows;
112
+ for (idx_t i = 1; i < rows.size(); i += 2) {
113
+ odd_rows.push_back(rows[i]);
114
+ }
115
+ smawk_impl(odd_rows, cols, lookup, argmins);
116
+
117
+ // interpolate the even-indexed rows
118
+ interpolate(rows, cols, lookup, argmins);
119
+ }
120
+
121
+ void smawk(
122
+ const idx_t nrows,
123
+ const idx_t ncols,
124
+ const LookUpFunc& lookup,
125
+ idx_t* argmins) {
126
+ std::vector<idx_t> rows(nrows);
127
+ std::vector<idx_t> cols(ncols);
128
+ std::iota(std::begin(rows), std::end(rows), 0);
129
+ std::iota(std::begin(cols), std::end(cols), 0);
130
+
131
+ smawk_impl(rows, cols, lookup, argmins);
132
+ }
133
+
134
+ void smawk(
135
+ const idx_t nrows,
136
+ const idx_t ncols,
137
+ const float* x,
138
+ idx_t* argmins) {
139
+ auto lookup = [&x, &ncols](idx_t i, idx_t j) { return x[i * ncols + j]; };
140
+ smawk(nrows, ncols, lookup, argmins);
141
+ }
142
+
143
+ namespace {
144
+
145
+ class CostCalculator {
146
+ // The reuslt would be inaccurate if we use float
147
+ std::vector<double> cumsum;
148
+ std::vector<double> cumsum2;
149
+
150
+ public:
151
+ CostCalculator(const std::vector<float>& vec, idx_t n) {
152
+ cumsum.push_back(0.0);
153
+ cumsum2.push_back(0.0);
154
+ for (idx_t i = 0; i < n; ++i) {
155
+ float x = vec[i];
156
+ cumsum.push_back(x + cumsum[i]);
157
+ cumsum2.push_back(x * x + cumsum2[i]);
158
+ }
159
+ }
160
+
161
+ float operator()(idx_t i, idx_t j) {
162
+ if (j < i) {
163
+ return 0.0f;
164
+ }
165
+ auto mu = (cumsum[j + 1] - cumsum[i]) / (j - i + 1);
166
+ auto result = cumsum2[j + 1] - cumsum2[i];
167
+ result += (j - i + 1) * (mu * mu);
168
+ result -= (2 * mu) * (cumsum[j + 1] - cumsum[i]);
169
+ return float(result);
170
+ }
171
+ };
172
+
173
+ template <class T>
174
+ class Matrix {
175
+ std::vector<T> data;
176
+ idx_t nrows;
177
+ idx_t ncols;
178
+
179
+ public:
180
+ Matrix(idx_t nrows, idx_t ncols) {
181
+ this->nrows = nrows;
182
+ this->ncols = ncols;
183
+ data.resize(nrows * ncols);
184
+ }
185
+
186
+ inline T& at(idx_t i, idx_t j) {
187
+ return data[i * ncols + j];
188
+ }
189
+ };
190
+
191
+ } // anonymous namespace
192
+
193
+ double kmeans1d(const float* x, size_t n, size_t nclusters, float* centroids) {
194
+ FAISS_THROW_IF_NOT(n >= nclusters);
195
+
196
+ // corner case
197
+ if (n == nclusters) {
198
+ memcpy(centroids, x, n * sizeof(*x));
199
+ return 0.0f;
200
+ }
201
+
202
+ /***************************************************
203
+ * sort in ascending order, O(NlogN) in time
204
+ ***************************************************/
205
+ std::vector<float> arr(x, x + n);
206
+ std::sort(arr.begin(), arr.end());
207
+
208
+ /***************************************************
209
+ dynamic programming algorithm
210
+
211
+ Reference: https://arxiv.org/abs/1701.07204
212
+ -------------------------------
213
+
214
+ Assume x is already sorted in ascending order.
215
+
216
+ N: number of points
217
+ K: number of clusters
218
+
219
+ CC(i, j): the cost of grouping xi,...,xj into one cluster
220
+ D[k][m]: the cost of optimally clustering x1,...,xm into k clusters
221
+ T[k][m]: the start index of the k-th cluster
222
+
223
+ The DP process is as follow:
224
+ D[k][m] = min_i D[k − 1][i − 1] + CC(i, m)
225
+ T[k][m] = argmin_i D[k − 1][i − 1] + CC(i, m)
226
+
227
+ This could be solved in O(KN^2) time and O(KN) space.
228
+
229
+ To further reduce the time complexity, we use SMAWK algo to
230
+ solve the argmin problem as follow:
231
+
232
+ For each k:
233
+ C[m][i] = D[k − 1][i − 1] + CC(i, m)
234
+
235
+ Here C is a n x n totally monotone matrix.
236
+ We could find the row minima by SMAWK in O(N) time.
237
+
238
+ Now the time complexity is reduced from O(kN^2) to O(KN).
239
+ ****************************************************/
240
+
241
+ CostCalculator CC(arr, n);
242
+ Matrix<float> D(nclusters, n);
243
+ Matrix<idx_t> T(nclusters, n);
244
+
245
+ for (idx_t m = 0; m < n; m++) {
246
+ D.at(0, m) = CC(0, m);
247
+ T.at(0, m) = 0;
248
+ }
249
+
250
+ std::vector<idx_t> indices(nclusters, 0);
251
+
252
+ for (idx_t k = 1; k < nclusters; ++k) {
253
+ // we define C here
254
+ auto C = [&D, &CC, &k](idx_t m, idx_t i) {
255
+ if (i == 0) {
256
+ return CC(i, m);
257
+ }
258
+ idx_t col = std::min(m, i - 1);
259
+ return D.at(k - 1, col) + CC(i, m);
260
+ };
261
+
262
+ std::vector<idx_t> argmins(n); // argmin of each row
263
+ smawk(n, n, C, argmins.data());
264
+ for (idx_t m = 0; m < argmins.size(); m++) {
265
+ idx_t idx = argmins[m];
266
+ D.at(k, m) = C(m, idx);
267
+ T.at(k, m) = idx;
268
+ }
269
+ }
270
+
271
+ /***************************************************
272
+ compute centroids by backtracking
273
+
274
+ T[K - 1][T[K][N] - 1] T[K][N] N
275
+ --------------|------------------------|-----------|
276
+ | cluster K - 1 | cluster K |
277
+
278
+ ****************************************************/
279
+
280
+ // for imbalance factor
281
+ double tot = 0.0, uf = 0.0;
282
+
283
+ idx_t end = n;
284
+ for (idx_t k = nclusters - 1; k >= 0; k--) {
285
+ idx_t start = T.at(k, end - 1);
286
+ float sum = std::accumulate(&arr[start], &arr[end], 0.0f);
287
+ idx_t size = end - start;
288
+ FAISS_THROW_IF_NOT_FMT(
289
+ size > 0, "Cluster %d: size %d", int(k), int(size));
290
+ centroids[k] = sum / size;
291
+ end = start;
292
+
293
+ tot += size;
294
+ uf += size * double(size);
295
+ }
296
+
297
+ uf = uf * nclusters / (tot * tot);
298
+ return uf;
299
+ }
300
+
301
+ } // namespace faiss