faiss 0.3.1 → 0.3.2

Sign up to get free protection for your applications and to get access to all the features.
Files changed (119) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.h +1 -1
  5. data/vendor/faiss/faiss/Clustering.cpp +35 -4
  6. data/vendor/faiss/faiss/Clustering.h +10 -1
  7. data/vendor/faiss/faiss/IVFlib.cpp +4 -1
  8. data/vendor/faiss/faiss/Index.h +21 -6
  9. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  10. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -1
  11. data/vendor/faiss/faiss/IndexFastScan.cpp +22 -4
  12. data/vendor/faiss/faiss/IndexFlat.cpp +11 -7
  13. data/vendor/faiss/faiss/IndexFlatCodes.cpp +159 -5
  14. data/vendor/faiss/faiss/IndexFlatCodes.h +20 -3
  15. data/vendor/faiss/faiss/IndexHNSW.cpp +143 -90
  16. data/vendor/faiss/faiss/IndexHNSW.h +52 -3
  17. data/vendor/faiss/faiss/IndexIVF.cpp +3 -3
  18. data/vendor/faiss/faiss/IndexIVF.h +9 -1
  19. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +15 -0
  20. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -0
  21. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +130 -57
  22. data/vendor/faiss/faiss/IndexIVFFastScan.h +14 -7
  23. data/vendor/faiss/faiss/IndexIVFPQ.cpp +1 -3
  24. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +21 -2
  25. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  26. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  27. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -29
  28. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  29. data/vendor/faiss/faiss/IndexNSG.h +1 -1
  30. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  31. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  32. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  33. data/vendor/faiss/faiss/IndexRefine.cpp +5 -5
  34. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +3 -1
  35. data/vendor/faiss/faiss/MetricType.h +7 -2
  36. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  37. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  38. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  39. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  40. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +36 -4
  41. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +6 -0
  42. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  43. data/vendor/faiss/faiss/gpu/GpuIndex.h +2 -8
  44. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  45. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +6 -0
  46. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +2 -0
  47. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +25 -0
  48. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  49. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +6 -0
  50. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  51. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +65 -0
  52. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  53. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  54. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  55. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  56. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +25 -0
  57. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +9 -1
  58. data/vendor/faiss/faiss/impl/DistanceComputer.h +46 -0
  59. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  60. data/vendor/faiss/faiss/impl/HNSW.cpp +358 -190
  61. data/vendor/faiss/faiss/impl/HNSW.h +43 -22
  62. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +8 -8
  63. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  64. data/vendor/faiss/faiss/impl/NNDescent.cpp +13 -8
  65. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  66. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +1 -0
  67. data/vendor/faiss/faiss/impl/ProductQuantizer.h +5 -1
  68. data/vendor/faiss/faiss/impl/ResultHandler.h +151 -32
  69. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +719 -102
  70. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -0
  71. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +5 -0
  72. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  73. data/vendor/faiss/faiss/impl/index_read.cpp +29 -15
  74. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  75. data/vendor/faiss/faiss/impl/index_write.cpp +28 -10
  76. data/vendor/faiss/faiss/impl/io.cpp +13 -5
  77. data/vendor/faiss/faiss/impl/io.h +4 -4
  78. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  79. data/vendor/faiss/faiss/impl/platform_macros.h +22 -0
  80. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +11 -0
  81. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +1 -1
  82. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +448 -1
  83. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +5 -5
  84. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  85. data/vendor/faiss/faiss/impl/simd_result_handlers.h +143 -59
  86. data/vendor/faiss/faiss/index_factory.cpp +31 -13
  87. data/vendor/faiss/faiss/index_io.h +12 -5
  88. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  89. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  90. data/vendor/faiss/faiss/invlists/DirectMap.cpp +9 -1
  91. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +55 -17
  92. data/vendor/faiss/faiss/invlists/InvertedLists.h +18 -9
  93. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +21 -6
  94. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  95. data/vendor/faiss/faiss/python/python_callbacks.cpp +3 -3
  96. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  97. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  98. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  99. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  100. data/vendor/faiss/faiss/utils/distances.cpp +58 -88
  101. data/vendor/faiss/faiss/utils/distances.h +5 -5
  102. data/vendor/faiss/faiss/utils/distances_simd.cpp +997 -9
  103. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  104. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  105. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  106. data/vendor/faiss/faiss/utils/hamming.cpp +1 -1
  107. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +4 -1
  108. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +2 -1
  109. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  110. data/vendor/faiss/faiss/utils/random.h +25 -0
  111. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  112. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  113. data/vendor/faiss/faiss/utils/simdlib_neon.h +5 -2
  114. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  115. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  116. data/vendor/faiss/faiss/utils/utils.cpp +10 -3
  117. data/vendor/faiss/faiss/utils/utils.h +3 -0
  118. metadata +16 -4
  119. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -15,6 +15,7 @@
15
15
  #include <faiss/utils/simdlib.h>
16
16
 
17
17
  #include <faiss/impl/FaissAssert.h>
18
+ #include <faiss/impl/IDSelector.h>
18
19
  #include <faiss/impl/ResultHandler.h>
19
20
  #include <faiss/impl/platform_macros.h>
20
21
  #include <faiss/utils/AlignedTable.h>
@@ -137,6 +138,7 @@ struct FixedStorageHandler : SIMDResultHandler {
137
138
  }
138
139
  }
139
140
  }
141
+
140
142
  virtual ~FixedStorageHandler() {}
141
143
  };
142
144
 
@@ -150,8 +152,10 @@ struct ResultHandlerCompare : SIMDResultHandlerToFloat {
150
152
  int64_t i0 = 0; // query origin
151
153
  int64_t j0 = 0; // db origin
152
154
 
153
- ResultHandlerCompare(size_t nq, size_t ntotal)
154
- : SIMDResultHandlerToFloat(nq, ntotal) {
155
+ const IDSelector* sel;
156
+
157
+ ResultHandlerCompare(size_t nq, size_t ntotal, const IDSelector* sel_in)
158
+ : SIMDResultHandlerToFloat(nq, ntotal), sel{sel_in} {
155
159
  this->is_CMax = C::is_max;
156
160
  this->sizeof_ids = sizeof(typename C::TI);
157
161
  this->with_fields = with_id_map;
@@ -232,9 +236,14 @@ struct SingleResultHandler : ResultHandlerCompare<C, with_id_map> {
232
236
  float* dis;
233
237
  int64_t* ids;
234
238
 
235
- SingleResultHandler(size_t nq, size_t ntotal, float* dis, int64_t* ids)
236
- : RHC(nq, ntotal), idis(nq), dis(dis), ids(ids) {
237
- for (int i = 0; i < nq; i++) {
239
+ SingleResultHandler(
240
+ size_t nq,
241
+ size_t ntotal,
242
+ float* dis,
243
+ int64_t* ids,
244
+ const IDSelector* sel_in)
245
+ : RHC(nq, ntotal, sel_in), idis(nq), dis(dis), ids(ids) {
246
+ for (size_t i = 0; i < nq; i++) {
238
247
  ids[i] = -1;
239
248
  idis[i] = C::neutral();
240
249
  }
@@ -256,20 +265,36 @@ struct SingleResultHandler : ResultHandlerCompare<C, with_id_map> {
256
265
  d0.store(d32tab);
257
266
  d1.store(d32tab + 16);
258
267
 
259
- while (lt_mask) {
260
- // find first non-zero
261
- int j = __builtin_ctz(lt_mask);
262
- lt_mask -= 1 << j;
263
- T d = d32tab[j];
264
- if (C::cmp(idis[q], d)) {
265
- idis[q] = d;
266
- ids[q] = this->adjust_id(b, j);
268
+ if (this->sel != nullptr) {
269
+ while (lt_mask) {
270
+ // find first non-zero
271
+ int j = __builtin_ctz(lt_mask);
272
+ auto real_idx = this->adjust_id(b, j);
273
+ lt_mask -= 1 << j;
274
+ if (this->sel->is_member(real_idx)) {
275
+ T d = d32tab[j];
276
+ if (C::cmp(idis[q], d)) {
277
+ idis[q] = d;
278
+ ids[q] = real_idx;
279
+ }
280
+ }
281
+ }
282
+ } else {
283
+ while (lt_mask) {
284
+ // find first non-zero
285
+ int j = __builtin_ctz(lt_mask);
286
+ lt_mask -= 1 << j;
287
+ T d = d32tab[j];
288
+ if (C::cmp(idis[q], d)) {
289
+ idis[q] = d;
290
+ ids[q] = this->adjust_id(b, j);
291
+ }
267
292
  }
268
293
  }
269
294
  }
270
295
 
271
296
  void end() {
272
- for (int q = 0; q < this->nq; q++) {
297
+ for (size_t q = 0; q < this->nq; q++) {
273
298
  if (!normalizers) {
274
299
  dis[q] = idis[q];
275
300
  } else {
@@ -296,8 +321,14 @@ struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
296
321
 
297
322
  int64_t k; // number of results to keep
298
323
 
299
- HeapHandler(size_t nq, size_t ntotal, int64_t k, float* dis, int64_t* ids)
300
- : RHC(nq, ntotal),
324
+ HeapHandler(
325
+ size_t nq,
326
+ size_t ntotal,
327
+ int64_t k,
328
+ float* dis,
329
+ int64_t* ids,
330
+ const IDSelector* sel_in)
331
+ : RHC(nq, ntotal, sel_in),
301
332
  idis(nq * k),
302
333
  iids(nq * k),
303
334
  dis(dis),
@@ -330,21 +361,36 @@ struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
330
361
  d0.store(d32tab);
331
362
  d1.store(d32tab + 16);
332
363
 
333
- while (lt_mask) {
334
- // find first non-zero
335
- int j = __builtin_ctz(lt_mask);
336
- lt_mask -= 1 << j;
337
- T dis = d32tab[j];
338
- if (C::cmp(heap_dis[0], dis)) {
339
- int64_t idx = this->adjust_id(b, j);
340
- heap_pop<C>(k, heap_dis, heap_ids);
341
- heap_push<C>(k, heap_dis, heap_ids, dis, idx);
364
+ if (this->sel != nullptr) {
365
+ while (lt_mask) {
366
+ // find first non-zero
367
+ int j = __builtin_ctz(lt_mask);
368
+ auto real_idx = this->adjust_id(b, j);
369
+ lt_mask -= 1 << j;
370
+ if (this->sel->is_member(real_idx)) {
371
+ T dis = d32tab[j];
372
+ if (C::cmp(heap_dis[0], dis)) {
373
+ heap_replace_top<C>(
374
+ k, heap_dis, heap_ids, dis, real_idx);
375
+ }
376
+ }
377
+ }
378
+ } else {
379
+ while (lt_mask) {
380
+ // find first non-zero
381
+ int j = __builtin_ctz(lt_mask);
382
+ lt_mask -= 1 << j;
383
+ T dis = d32tab[j];
384
+ if (C::cmp(heap_dis[0], dis)) {
385
+ int64_t idx = this->adjust_id(b, j);
386
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, idx);
387
+ }
342
388
  }
343
389
  }
344
390
  }
345
391
 
346
392
  void end() override {
347
- for (int q = 0; q < this->nq; q++) {
393
+ for (size_t q = 0; q < this->nq; q++) {
348
394
  T* heap_dis_in = idis.data() + q * k;
349
395
  TI* heap_ids_in = iids.data() + q * k;
350
396
  heap_reorder<C>(k, heap_dis_in, heap_ids_in);
@@ -393,8 +439,12 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
393
439
  size_t k,
394
440
  size_t cap,
395
441
  float* dis,
396
- int64_t* ids)
397
- : RHC(nq, ntotal), capacity((cap + 15) & ~15), dis(dis), ids(ids) {
442
+ int64_t* ids,
443
+ const IDSelector* sel_in)
444
+ : RHC(nq, ntotal, sel_in),
445
+ capacity((cap + 15) & ~15),
446
+ dis(dis),
447
+ ids(ids) {
398
448
  assert(capacity % 16 == 0);
399
449
  all_ids.resize(nq * capacity);
400
450
  all_vals.resize(nq * capacity);
@@ -423,12 +473,25 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
423
473
  d0.store(d32tab);
424
474
  d1.store(d32tab + 16);
425
475
 
426
- while (lt_mask) {
427
- // find first non-zero
428
- int j = __builtin_ctz(lt_mask);
429
- lt_mask -= 1 << j;
430
- T dis = d32tab[j];
431
- res.add(dis, this->adjust_id(b, j));
476
+ if (this->sel != nullptr) {
477
+ while (lt_mask) {
478
+ // find first non-zero
479
+ int j = __builtin_ctz(lt_mask);
480
+ auto real_idx = this->adjust_id(b, j);
481
+ lt_mask -= 1 << j;
482
+ if (this->sel->is_member(real_idx)) {
483
+ T dis = d32tab[j];
484
+ res.add(dis, real_idx);
485
+ }
486
+ }
487
+ } else {
488
+ while (lt_mask) {
489
+ // find first non-zero
490
+ int j = __builtin_ctz(lt_mask);
491
+ lt_mask -= 1 << j;
492
+ T dis = d32tab[j];
493
+ res.add(dis, this->adjust_id(b, j));
494
+ }
432
495
  }
433
496
  }
434
497
 
@@ -439,7 +502,7 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
439
502
  CMin<float, int64_t>>::type;
440
503
 
441
504
  std::vector<int> perm(reservoirs[0].n);
442
- for (int q = 0; q < reservoirs.size(); q++) {
505
+ for (size_t q = 0; q < reservoirs.size(); q++) {
443
506
  ReservoirTopN<C>& res = reservoirs[q];
444
507
  size_t n = res.n;
445
508
 
@@ -454,14 +517,14 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
454
517
  one_a = 1 / normalizers[2 * q];
455
518
  b = normalizers[2 * q + 1];
456
519
  }
457
- for (int i = 0; i < res.i; i++) {
520
+ for (size_t i = 0; i < res.i; i++) {
458
521
  perm[i] = i;
459
522
  }
460
523
  // indirect sort of result arrays
461
524
  std::sort(perm.begin(), perm.begin() + res.i, [&res](int i, int j) {
462
525
  return C::cmp(res.vals[j], res.vals[i]);
463
526
  });
464
- for (int i = 0; i < res.i; i++) {
527
+ for (size_t i = 0; i < res.i; i++) {
465
528
  heap_dis[i] = res.vals[perm[i]] * one_a + b;
466
529
  heap_ids[i] = res.ids[perm[i]];
467
530
  }
@@ -472,7 +535,7 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
472
535
  }
473
536
  };
474
537
 
475
- /** Result hanlder for range search. The difficulty is that the range distances
538
+ /** Result handler for range search. The difficulty is that the range distances
476
539
  * have to be scaled using the scaler.
477
540
  */
478
541
 
@@ -499,13 +562,17 @@ struct RangeHandler : ResultHandlerCompare<C, with_id_map> {
499
562
  };
500
563
  std::vector<Triplet> triplets;
501
564
 
502
- RangeHandler(RangeSearchResult& rres, float radius, size_t ntotal)
503
- : RHC(rres.nq, ntotal), rres(rres), radius(radius) {
565
+ RangeHandler(
566
+ RangeSearchResult& rres,
567
+ float radius,
568
+ size_t ntotal,
569
+ const IDSelector* sel_in)
570
+ : RHC(rres.nq, ntotal, sel_in), rres(rres), radius(radius) {
504
571
  thresholds.resize(nq);
505
572
  n_per_query.resize(nq + 1);
506
573
  }
507
574
 
508
- virtual void begin(const float* norms) {
575
+ virtual void begin(const float* norms) override {
509
576
  normalizers = norms;
510
577
  for (int q = 0; q < nq; ++q) {
511
578
  thresholds[q] =
@@ -528,13 +595,28 @@ struct RangeHandler : ResultHandlerCompare<C, with_id_map> {
528
595
  d0.store(d32tab);
529
596
  d1.store(d32tab + 16);
530
597
 
531
- while (lt_mask) {
532
- // find first non-zero
533
- int j = __builtin_ctz(lt_mask);
534
- lt_mask -= 1 << j;
535
- T dis = d32tab[j];
536
- n_per_query[q]++;
537
- triplets.push_back({idx_t(q + q0), this->adjust_id(b, j), dis});
598
+ if (this->sel != nullptr) {
599
+ while (lt_mask) {
600
+ // find first non-zero
601
+ int j = __builtin_ctz(lt_mask);
602
+ lt_mask -= 1 << j;
603
+
604
+ auto real_idx = this->adjust_id(b, j);
605
+ if (this->sel->is_member(real_idx)) {
606
+ T dis = d32tab[j];
607
+ n_per_query[q]++;
608
+ triplets.push_back({idx_t(q + q0), real_idx, dis});
609
+ }
610
+ }
611
+ } else {
612
+ while (lt_mask) {
613
+ // find first non-zero
614
+ int j = __builtin_ctz(lt_mask);
615
+ lt_mask -= 1 << j;
616
+ T dis = d32tab[j];
617
+ n_per_query[q]++;
618
+ triplets.push_back({idx_t(q + q0), this->adjust_id(b, j), dis});
619
+ }
538
620
  }
539
621
  }
540
622
 
@@ -578,8 +660,9 @@ struct PartialRangeHandler : RangeHandler<C, with_id_map> {
578
660
  float radius,
579
661
  size_t ntotal,
580
662
  size_t q0,
581
- size_t q1)
582
- : RangeHandler<C, with_id_map>(*pres.res, radius, ntotal),
663
+ size_t q1,
664
+ const IDSelector* sel_in)
665
+ : RangeHandler<C, with_id_map>(*pres.res, radius, ntotal, sel_in),
583
666
  pres(pres) {
584
667
  nq = q1 - q0;
585
668
  this->q0 = q0;
@@ -630,7 +713,7 @@ struct PartialRangeHandler : RangeHandler<C, with_id_map> {
630
713
  */
631
714
 
632
715
  template <class C, bool W, class Consumer, class... Types>
633
- void dispatch_SIMDResultHanlder_fixedCW(
716
+ void dispatch_SIMDResultHandler_fixedCW(
634
717
  SIMDResultHandler& res,
635
718
  Consumer& consumer,
636
719
  Types... args) {
@@ -650,19 +733,19 @@ void dispatch_SIMDResultHanlder_fixedCW(
650
733
  }
651
734
 
652
735
  template <class C, class Consumer, class... Types>
653
- void dispatch_SIMDResultHanlder_fixedC(
736
+ void dispatch_SIMDResultHandler_fixedC(
654
737
  SIMDResultHandler& res,
655
738
  Consumer& consumer,
656
739
  Types... args) {
657
740
  if (res.with_fields) {
658
- dispatch_SIMDResultHanlder_fixedCW<C, true>(res, consumer, args...);
741
+ dispatch_SIMDResultHandler_fixedCW<C, true>(res, consumer, args...);
659
742
  } else {
660
- dispatch_SIMDResultHanlder_fixedCW<C, false>(res, consumer, args...);
743
+ dispatch_SIMDResultHandler_fixedCW<C, false>(res, consumer, args...);
661
744
  }
662
745
  }
663
746
 
664
747
  template <class Consumer, class... Types>
665
- void dispatch_SIMDResultHanlder(
748
+ void dispatch_SIMDResultHandler(
666
749
  SIMDResultHandler& res,
667
750
  Consumer& consumer,
668
751
  Types... args) {
@@ -680,24 +763,25 @@ void dispatch_SIMDResultHanlder(
680
763
  }
681
764
  } else if (res.sizeof_ids == sizeof(int)) {
682
765
  if (res.is_CMax) {
683
- dispatch_SIMDResultHanlder_fixedC<CMax<uint16_t, int>>(
766
+ dispatch_SIMDResultHandler_fixedC<CMax<uint16_t, int>>(
684
767
  res, consumer, args...);
685
768
  } else {
686
- dispatch_SIMDResultHanlder_fixedC<CMin<uint16_t, int>>(
769
+ dispatch_SIMDResultHandler_fixedC<CMin<uint16_t, int>>(
687
770
  res, consumer, args...);
688
771
  }
689
772
  } else if (res.sizeof_ids == sizeof(int64_t)) {
690
773
  if (res.is_CMax) {
691
- dispatch_SIMDResultHanlder_fixedC<CMax<uint16_t, int64_t>>(
774
+ dispatch_SIMDResultHandler_fixedC<CMax<uint16_t, int64_t>>(
692
775
  res, consumer, args...);
693
776
  } else {
694
- dispatch_SIMDResultHanlder_fixedC<CMin<uint16_t, int64_t>>(
777
+ dispatch_SIMDResultHandler_fixedC<CMin<uint16_t, int64_t>>(
695
778
  res, consumer, args...);
696
779
  }
697
780
  } else {
698
781
  FAISS_THROW_FMT("Unknown id size %d", res.sizeof_ids);
699
782
  }
700
783
  }
784
+
701
785
  } // namespace simd_result_handlers
702
786
 
703
787
  } // namespace faiss
@@ -140,8 +140,12 @@ std::map<std::string, ScalarQuantizer::QuantizerType> sq_types = {
140
140
  {"SQ4", ScalarQuantizer::QT_4bit},
141
141
  {"SQ6", ScalarQuantizer::QT_6bit},
142
142
  {"SQfp16", ScalarQuantizer::QT_fp16},
143
+ {"SQbf16", ScalarQuantizer::QT_bf16},
144
+ {"SQ8_direct_signed", ScalarQuantizer::QT_8bit_direct_signed},
145
+ {"SQ8_direct", ScalarQuantizer::QT_8bit_direct},
143
146
  };
144
- const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16)";
147
+ const std::string sq_pattern =
148
+ "(SQ4|SQ8|SQ6|SQfp16|SQbf16|SQ8_direct_signed|SQ8_direct)";
145
149
 
146
150
  std::map<std::string, AdditiveQuantizer::Search_type_t> aq_search_type = {
147
151
  {"_Nfloat", AdditiveQuantizer::ST_norm_float},
@@ -222,6 +226,19 @@ VectorTransform* parse_VectorTransform(const std::string& description, int d) {
222
226
  * Parse IndexIVF
223
227
  */
224
228
 
229
/** Parse an inverted-list count as it appears in a factory string.
 *
 * Accepted form: a decimal number optionally followed by exactly one
 * binary-scale suffix, 'k' (x1024) or 'M' (x1024*1024), e.g. "100",
 * "4k", "2M".
 *
 * @param s  textual count (taken by value because the suffix is stripped
 *           in place)
 * @return   the count with the suffix multiplier applied
 * @throws std::invalid_argument (raised by std::stoull) when no digits
 *         remain, e.g. for "" or a bare "k"
 * @throws std::out_of_range (raised by std::stoull) when the digits do
 *         not fit in unsigned long long
 */
size_t parse_nlist(std::string s) {
    size_t multiplier = 1;
    // Guard the back() access: calling it on an empty string is undefined
    // behaviour. At most one suffix is consumed.
    if (!s.empty()) {
        switch (s.back()) {
            case 'k':
                multiplier = 1024;
                s.pop_back();
                break;
            case 'M':
                multiplier = 1024 * 1024;
                s.pop_back();
                break;
            default:
                break;
        }
    }
    // stoull instead of stoi: the result is a size_t, so counts above
    // INT_MAX must not be rejected. An empty remainder makes stoull throw
    // std::invalid_argument.
    return static_cast<size_t>(std::stoull(s)) * multiplier;
}
241
+
225
242
  // parsing guard + function
226
243
  Index* parse_coarse_quantizer(
227
244
  const std::string& description,
@@ -236,8 +253,8 @@ Index* parse_coarse_quantizer(
236
253
  };
237
254
  use_2layer = false;
238
255
 
239
- if (match("IVF([0-9]+)")) {
240
- nlist = std::stoi(sm[1].str());
256
+ if (match("IVF([0-9]+[kM]?)")) {
257
+ nlist = parse_nlist(sm[1].str());
241
258
  return new IndexFlat(d, mt);
242
259
  }
243
260
  if (match("IMI2x([0-9]+)")) {
@@ -248,18 +265,18 @@ Index* parse_coarse_quantizer(
248
265
  nlist = (size_t)1 << (2 * nbit);
249
266
  return new MultiIndexQuantizer(d, 2, nbit);
250
267
  }
251
- if (match("IVF([0-9]+)_HNSW([0-9]*)")) {
252
- nlist = std::stoi(sm[1].str());
268
+ if (match("IVF([0-9]+[kM]?)_HNSW([0-9]*)")) {
269
+ nlist = parse_nlist(sm[1].str());
253
270
  int hnsw_M = sm[2].length() > 0 ? std::stoi(sm[2]) : 32;
254
271
  return new IndexHNSWFlat(d, hnsw_M, mt);
255
272
  }
256
- if (match("IVF([0-9]+)_NSG([0-9]+)")) {
257
- nlist = std::stoi(sm[1].str());
273
+ if (match("IVF([0-9]+[kM]?)_NSG([0-9]+)")) {
274
+ nlist = parse_nlist(sm[1].str());
258
275
  int R = std::stoi(sm[2]);
259
276
  return new IndexNSGFlat(d, R, mt);
260
277
  }
261
- if (match("IVF([0-9]+)\\(Index([0-9])\\)")) {
262
- nlist = std::stoi(sm[1].str());
278
+ if (match("IVF([0-9]+[kM]?)\\(Index([0-9])\\)")) {
279
+ nlist = parse_nlist(sm[1].str());
263
280
  int no = std::stoi(sm[2].str());
264
281
  FAISS_ASSERT(no >= 0 && no < parenthesis_indexes.size());
265
282
  return parenthesis_indexes[no].release();
@@ -526,11 +543,12 @@ Index* parse_other_indexes(
526
543
  }
527
544
 
528
545
  // IndexLSH
529
- if (match("LSH(r?)(t?)")) {
530
- bool rotate_data = sm[1].length() > 0;
531
- bool train_thresholds = sm[2].length() > 0;
546
+ if (match("LSH([0-9]*)(r?)(t?)")) {
547
+ int nbits = sm[1].length() > 0 ? std::stoi(sm[1].str()) : d;
548
+ bool rotate_data = sm[2].length() > 0;
549
+ bool train_thresholds = sm[3].length() > 0;
532
550
  FAISS_THROW_IF_NOT(metric == METRIC_L2);
533
- return new IndexLSH(d, d, rotate_data, train_thresholds);
551
+ return new IndexLSH(d, nbits, rotate_data, train_thresholds);
534
552
  }
535
553
 
536
554
  // IndexLattice
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  // I/O code for indexes
11
9
 
12
10
  #ifndef FAISS_INDEX_IO_H
@@ -35,9 +33,12 @@ struct IOReader;
35
33
  struct IOWriter;
36
34
  struct InvertedLists;
37
35
 
38
- void write_index(const Index* idx, const char* fname);
39
- void write_index(const Index* idx, FILE* f);
40
- void write_index(const Index* idx, IOWriter* writer);
36
+ /// skip the storage for graph-based indexes
37
+ const int IO_FLAG_SKIP_STORAGE = 1;
38
+
39
+ void write_index(const Index* idx, const char* fname, int io_flags = 0);
40
+ void write_index(const Index* idx, FILE* f, int io_flags = 0);
41
+ void write_index(const Index* idx, IOWriter* writer, int io_flags = 0);
41
42
 
42
43
  void write_index_binary(const IndexBinary* idx, const char* fname);
43
44
  void write_index_binary(const IndexBinary* idx, FILE* f);
@@ -52,6 +53,12 @@ const int IO_FLAG_ONDISK_SAME_DIR = 4;
52
53
  const int IO_FLAG_SKIP_IVF_DATA = 8;
53
54
  // don't initialize precomputed table after loading
54
55
  const int IO_FLAG_SKIP_PRECOMPUTE_TABLE = 16;
56
+ // don't compute the sdc table for PQ-based indices
57
+ // this will prevent distances from being computed
58
+ // between elements in the index. For indices like HNSWPQ,
59
+ // this will prevent graph building because sdc
60
+ // computations are required to construct the graph
61
+ const int IO_FLAG_PQ_SKIP_SDC_TABLE = 32;
55
62
  // try to memmap data (useful to load an ArrayInvertedLists as an
56
63
  // OnDiskInvertedLists)
57
64
  const int IO_FLAG_MMAP = IO_FLAG_SKIP_IVF_DATA | 0x646f0000;
@@ -9,6 +9,7 @@
9
9
 
10
10
  #include <faiss/impl/CodePacker.h>
11
11
  #include <faiss/impl/FaissAssert.h>
12
+ #include <faiss/impl/IDSelector.h>
12
13
 
13
14
  #include <faiss/impl/io.h>
14
15
  #include <faiss/impl/io_macros.h>
@@ -54,7 +55,9 @@ size_t BlockInvertedLists::add_entries(
54
55
  codes[list_no].resize(n_block * block_size);
55
56
  if (o % block_size == 0) {
56
57
  // copy whole blocks
57
- memcpy(&codes[list_no][o * code_size], code, n_block * block_size);
58
+ memcpy(&codes[list_no][o * packer->code_size],
59
+ code,
60
+ n_block * block_size);
58
61
  } else {
59
62
  FAISS_THROW_IF_NOT_MSG(packer, "missing code packer");
60
63
  std::vector<uint8_t> buffer(packer->code_size);
@@ -76,6 +79,29 @@ const uint8_t* BlockInvertedLists::get_codes(size_t list_no) const {
76
79
  return codes[list_no].get();
77
80
  }
78
81
 
82
+ size_t BlockInvertedLists::remove_ids(const IDSelector& sel) {
83
+ idx_t nremove = 0;
84
+ #pragma omp parallel for
85
+ for (idx_t i = 0; i < nlist; i++) {
86
+ std::vector<uint8_t> buffer(packer->code_size);
87
+ idx_t l = ids[i].size(), j = 0;
88
+ while (j < l) {
89
+ if (sel.is_member(ids[i][j])) {
90
+ l--;
91
+ ids[i][j] = ids[i][l];
92
+ packer->unpack_1(codes[i].data(), l, buffer.data());
93
+ packer->pack_1(buffer.data(), j, codes[i].data());
94
+ } else {
95
+ j++;
96
+ }
97
+ }
98
+ resize(i, l);
99
+ nremove += ids[i].size() - l;
100
+ }
101
+
102
+ return nremove;
103
+ }
104
+
79
105
  const idx_t* BlockInvertedLists::get_ids(size_t list_no) const {
80
106
  assert(list_no < nlist);
81
107
  return ids[list_no].data();
@@ -101,13 +127,7 @@ void BlockInvertedLists::update_entries(
101
127
  size_t,
102
128
  const idx_t*,
103
129
  const uint8_t*) {
104
- FAISS_THROW_MSG("not impemented");
105
- /*
106
- assert (list_no < nlist);
107
- assert (n_entry + offset <= ids[list_no].size());
108
- memcpy (&ids[list_no][offset], ids_in, sizeof(ids_in[0]) * n_entry);
109
- memcpy (&codes[list_no][offset * code_size], codes_in, code_size * n_entry);
110
- */
130
+ FAISS_THROW_MSG("not implemented");
111
131
  }
112
132
 
113
133
  BlockInvertedLists::~BlockInvertedLists() {
@@ -15,6 +15,7 @@
15
15
  namespace faiss {
16
16
 
17
17
  struct CodePacker;
18
+ struct IDSelector;
18
19
 
19
20
  /** Inverted Lists that are organized by blocks.
20
21
  *
@@ -47,6 +48,8 @@ struct BlockInvertedLists : InvertedLists {
47
48
  size_t list_size(size_t list_no) const override;
48
49
  const uint8_t* get_codes(size_t list_no) const override;
49
50
  const idx_t* get_ids(size_t list_no) const override;
51
+ /// remove ids from the InvertedLists
52
+ size_t remove_ids(const IDSelector& sel);
50
53
 
51
54
  // works only on empty BlockInvertedLists
52
55
  // the codes should be of size ceil(n_entry / n_per_block) * block_size
@@ -15,6 +15,7 @@
15
15
  #include <faiss/impl/AuxIndexStructures.h>
16
16
  #include <faiss/impl/FaissAssert.h>
17
17
  #include <faiss/impl/IDSelector.h>
18
+ #include <faiss/invlists/BlockInvertedLists.h>
18
19
 
19
20
  namespace faiss {
20
21
 
@@ -148,8 +149,12 @@ size_t DirectMap::remove_ids(const IDSelector& sel, InvertedLists* invlists) {
148
149
  std::vector<idx_t> toremove(nlist);
149
150
 
150
151
  size_t nremove = 0;
151
-
152
+ BlockInvertedLists* block_invlists =
153
+ dynamic_cast<BlockInvertedLists*>(invlists);
152
154
  if (type == NoMap) {
155
+ if (block_invlists != nullptr) {
156
+ return block_invlists->remove_ids(sel);
157
+ }
153
158
  // exhaustive scan of IVF
154
159
  #pragma omp parallel for
155
160
  for (idx_t i = 0; i < nlist; i++) {
@@ -178,6 +183,9 @@ size_t DirectMap::remove_ids(const IDSelector& sel, InvertedLists* invlists) {
178
183
  }
179
184
  }
180
185
  } else if (type == Hashtable) {
186
+ FAISS_THROW_IF_MSG(
187
+ block_invlists,
188
+ "remove with hashtable is not supported with BlockInvertedLists");
181
189
  const IDSelectorArray* sela =
182
190
  dynamic_cast<const IDSelectorArray*>(&sel);
183
191
  FAISS_THROW_IF_NOT_MSG(