faiss 0.3.1 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119):
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.h +1 -1
  5. data/vendor/faiss/faiss/Clustering.cpp +35 -4
  6. data/vendor/faiss/faiss/Clustering.h +10 -1
  7. data/vendor/faiss/faiss/IVFlib.cpp +4 -1
  8. data/vendor/faiss/faiss/Index.h +21 -6
  9. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  10. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -1
  11. data/vendor/faiss/faiss/IndexFastScan.cpp +22 -4
  12. data/vendor/faiss/faiss/IndexFlat.cpp +11 -7
  13. data/vendor/faiss/faiss/IndexFlatCodes.cpp +159 -5
  14. data/vendor/faiss/faiss/IndexFlatCodes.h +20 -3
  15. data/vendor/faiss/faiss/IndexHNSW.cpp +143 -90
  16. data/vendor/faiss/faiss/IndexHNSW.h +52 -3
  17. data/vendor/faiss/faiss/IndexIVF.cpp +3 -3
  18. data/vendor/faiss/faiss/IndexIVF.h +9 -1
  19. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +15 -0
  20. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -0
  21. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +130 -57
  22. data/vendor/faiss/faiss/IndexIVFFastScan.h +14 -7
  23. data/vendor/faiss/faiss/IndexIVFPQ.cpp +1 -3
  24. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +21 -2
  25. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  26. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  27. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -29
  28. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  29. data/vendor/faiss/faiss/IndexNSG.h +1 -1
  30. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  31. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  32. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  33. data/vendor/faiss/faiss/IndexRefine.cpp +5 -5
  34. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +3 -1
  35. data/vendor/faiss/faiss/MetricType.h +7 -2
  36. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  37. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  38. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  39. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  40. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +36 -4
  41. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +6 -0
  42. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  43. data/vendor/faiss/faiss/gpu/GpuIndex.h +2 -8
  44. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  45. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +6 -0
  46. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +2 -0
  47. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +25 -0
  48. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  49. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +6 -0
  50. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  51. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +65 -0
  52. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  53. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  54. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  55. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  56. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +25 -0
  57. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +9 -1
  58. data/vendor/faiss/faiss/impl/DistanceComputer.h +46 -0
  59. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  60. data/vendor/faiss/faiss/impl/HNSW.cpp +358 -190
  61. data/vendor/faiss/faiss/impl/HNSW.h +43 -22
  62. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +8 -8
  63. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  64. data/vendor/faiss/faiss/impl/NNDescent.cpp +13 -8
  65. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  66. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +1 -0
  67. data/vendor/faiss/faiss/impl/ProductQuantizer.h +5 -1
  68. data/vendor/faiss/faiss/impl/ResultHandler.h +151 -32
  69. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +719 -102
  70. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -0
  71. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +5 -0
  72. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  73. data/vendor/faiss/faiss/impl/index_read.cpp +29 -15
  74. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  75. data/vendor/faiss/faiss/impl/index_write.cpp +28 -10
  76. data/vendor/faiss/faiss/impl/io.cpp +13 -5
  77. data/vendor/faiss/faiss/impl/io.h +4 -4
  78. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  79. data/vendor/faiss/faiss/impl/platform_macros.h +22 -0
  80. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +11 -0
  81. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +1 -1
  82. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +448 -1
  83. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +5 -5
  84. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  85. data/vendor/faiss/faiss/impl/simd_result_handlers.h +143 -59
  86. data/vendor/faiss/faiss/index_factory.cpp +31 -13
  87. data/vendor/faiss/faiss/index_io.h +12 -5
  88. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  89. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  90. data/vendor/faiss/faiss/invlists/DirectMap.cpp +9 -1
  91. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +55 -17
  92. data/vendor/faiss/faiss/invlists/InvertedLists.h +18 -9
  93. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +21 -6
  94. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  95. data/vendor/faiss/faiss/python/python_callbacks.cpp +3 -3
  96. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  97. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  98. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  99. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  100. data/vendor/faiss/faiss/utils/distances.cpp +58 -88
  101. data/vendor/faiss/faiss/utils/distances.h +5 -5
  102. data/vendor/faiss/faiss/utils/distances_simd.cpp +997 -9
  103. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  104. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  105. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  106. data/vendor/faiss/faiss/utils/hamming.cpp +1 -1
  107. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +4 -1
  108. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +2 -1
  109. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  110. data/vendor/faiss/faiss/utils/random.h +25 -0
  111. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  112. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  113. data/vendor/faiss/faiss/utils/simdlib_neon.h +5 -2
  114. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  115. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  116. data/vendor/faiss/faiss/utils/utils.cpp +10 -3
  117. data/vendor/faiss/faiss/utils/utils.h +3 -0
  118. metadata +16 -4
  119. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -15,6 +15,7 @@
15
15
  #include <faiss/utils/simdlib.h>
16
16
 
17
17
  #include <faiss/impl/FaissAssert.h>
18
+ #include <faiss/impl/IDSelector.h>
18
19
  #include <faiss/impl/ResultHandler.h>
19
20
  #include <faiss/impl/platform_macros.h>
20
21
  #include <faiss/utils/AlignedTable.h>
@@ -137,6 +138,7 @@ struct FixedStorageHandler : SIMDResultHandler {
137
138
  }
138
139
  }
139
140
  }
141
+
140
142
  virtual ~FixedStorageHandler() {}
141
143
  };
142
144
 
@@ -150,8 +152,10 @@ struct ResultHandlerCompare : SIMDResultHandlerToFloat {
150
152
  int64_t i0 = 0; // query origin
151
153
  int64_t j0 = 0; // db origin
152
154
 
153
- ResultHandlerCompare(size_t nq, size_t ntotal)
154
- : SIMDResultHandlerToFloat(nq, ntotal) {
155
+ const IDSelector* sel;
156
+
157
+ ResultHandlerCompare(size_t nq, size_t ntotal, const IDSelector* sel_in)
158
+ : SIMDResultHandlerToFloat(nq, ntotal), sel{sel_in} {
155
159
  this->is_CMax = C::is_max;
156
160
  this->sizeof_ids = sizeof(typename C::TI);
157
161
  this->with_fields = with_id_map;
@@ -232,9 +236,14 @@ struct SingleResultHandler : ResultHandlerCompare<C, with_id_map> {
232
236
  float* dis;
233
237
  int64_t* ids;
234
238
 
235
- SingleResultHandler(size_t nq, size_t ntotal, float* dis, int64_t* ids)
236
- : RHC(nq, ntotal), idis(nq), dis(dis), ids(ids) {
237
- for (int i = 0; i < nq; i++) {
239
+ SingleResultHandler(
240
+ size_t nq,
241
+ size_t ntotal,
242
+ float* dis,
243
+ int64_t* ids,
244
+ const IDSelector* sel_in)
245
+ : RHC(nq, ntotal, sel_in), idis(nq), dis(dis), ids(ids) {
246
+ for (size_t i = 0; i < nq; i++) {
238
247
  ids[i] = -1;
239
248
  idis[i] = C::neutral();
240
249
  }
@@ -256,20 +265,36 @@ struct SingleResultHandler : ResultHandlerCompare<C, with_id_map> {
256
265
  d0.store(d32tab);
257
266
  d1.store(d32tab + 16);
258
267
 
259
- while (lt_mask) {
260
- // find first non-zero
261
- int j = __builtin_ctz(lt_mask);
262
- lt_mask -= 1 << j;
263
- T d = d32tab[j];
264
- if (C::cmp(idis[q], d)) {
265
- idis[q] = d;
266
- ids[q] = this->adjust_id(b, j);
268
+ if (this->sel != nullptr) {
269
+ while (lt_mask) {
270
+ // find first non-zero
271
+ int j = __builtin_ctz(lt_mask);
272
+ auto real_idx = this->adjust_id(b, j);
273
+ lt_mask -= 1 << j;
274
+ if (this->sel->is_member(real_idx)) {
275
+ T d = d32tab[j];
276
+ if (C::cmp(idis[q], d)) {
277
+ idis[q] = d;
278
+ ids[q] = real_idx;
279
+ }
280
+ }
281
+ }
282
+ } else {
283
+ while (lt_mask) {
284
+ // find first non-zero
285
+ int j = __builtin_ctz(lt_mask);
286
+ lt_mask -= 1 << j;
287
+ T d = d32tab[j];
288
+ if (C::cmp(idis[q], d)) {
289
+ idis[q] = d;
290
+ ids[q] = this->adjust_id(b, j);
291
+ }
267
292
  }
268
293
  }
269
294
  }
270
295
 
271
296
  void end() {
272
- for (int q = 0; q < this->nq; q++) {
297
+ for (size_t q = 0; q < this->nq; q++) {
273
298
  if (!normalizers) {
274
299
  dis[q] = idis[q];
275
300
  } else {
@@ -296,8 +321,14 @@ struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
296
321
 
297
322
  int64_t k; // number of results to keep
298
323
 
299
- HeapHandler(size_t nq, size_t ntotal, int64_t k, float* dis, int64_t* ids)
300
- : RHC(nq, ntotal),
324
+ HeapHandler(
325
+ size_t nq,
326
+ size_t ntotal,
327
+ int64_t k,
328
+ float* dis,
329
+ int64_t* ids,
330
+ const IDSelector* sel_in)
331
+ : RHC(nq, ntotal, sel_in),
301
332
  idis(nq * k),
302
333
  iids(nq * k),
303
334
  dis(dis),
@@ -330,21 +361,36 @@ struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
330
361
  d0.store(d32tab);
331
362
  d1.store(d32tab + 16);
332
363
 
333
- while (lt_mask) {
334
- // find first non-zero
335
- int j = __builtin_ctz(lt_mask);
336
- lt_mask -= 1 << j;
337
- T dis = d32tab[j];
338
- if (C::cmp(heap_dis[0], dis)) {
339
- int64_t idx = this->adjust_id(b, j);
340
- heap_pop<C>(k, heap_dis, heap_ids);
341
- heap_push<C>(k, heap_dis, heap_ids, dis, idx);
364
+ if (this->sel != nullptr) {
365
+ while (lt_mask) {
366
+ // find first non-zero
367
+ int j = __builtin_ctz(lt_mask);
368
+ auto real_idx = this->adjust_id(b, j);
369
+ lt_mask -= 1 << j;
370
+ if (this->sel->is_member(real_idx)) {
371
+ T dis = d32tab[j];
372
+ if (C::cmp(heap_dis[0], dis)) {
373
+ heap_replace_top<C>(
374
+ k, heap_dis, heap_ids, dis, real_idx);
375
+ }
376
+ }
377
+ }
378
+ } else {
379
+ while (lt_mask) {
380
+ // find first non-zero
381
+ int j = __builtin_ctz(lt_mask);
382
+ lt_mask -= 1 << j;
383
+ T dis = d32tab[j];
384
+ if (C::cmp(heap_dis[0], dis)) {
385
+ int64_t idx = this->adjust_id(b, j);
386
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, idx);
387
+ }
342
388
  }
343
389
  }
344
390
  }
345
391
 
346
392
  void end() override {
347
- for (int q = 0; q < this->nq; q++) {
393
+ for (size_t q = 0; q < this->nq; q++) {
348
394
  T* heap_dis_in = idis.data() + q * k;
349
395
  TI* heap_ids_in = iids.data() + q * k;
350
396
  heap_reorder<C>(k, heap_dis_in, heap_ids_in);
@@ -393,8 +439,12 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
393
439
  size_t k,
394
440
  size_t cap,
395
441
  float* dis,
396
- int64_t* ids)
397
- : RHC(nq, ntotal), capacity((cap + 15) & ~15), dis(dis), ids(ids) {
442
+ int64_t* ids,
443
+ const IDSelector* sel_in)
444
+ : RHC(nq, ntotal, sel_in),
445
+ capacity((cap + 15) & ~15),
446
+ dis(dis),
447
+ ids(ids) {
398
448
  assert(capacity % 16 == 0);
399
449
  all_ids.resize(nq * capacity);
400
450
  all_vals.resize(nq * capacity);
@@ -423,12 +473,25 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
423
473
  d0.store(d32tab);
424
474
  d1.store(d32tab + 16);
425
475
 
426
- while (lt_mask) {
427
- // find first non-zero
428
- int j = __builtin_ctz(lt_mask);
429
- lt_mask -= 1 << j;
430
- T dis = d32tab[j];
431
- res.add(dis, this->adjust_id(b, j));
476
+ if (this->sel != nullptr) {
477
+ while (lt_mask) {
478
+ // find first non-zero
479
+ int j = __builtin_ctz(lt_mask);
480
+ auto real_idx = this->adjust_id(b, j);
481
+ lt_mask -= 1 << j;
482
+ if (this->sel->is_member(real_idx)) {
483
+ T dis = d32tab[j];
484
+ res.add(dis, real_idx);
485
+ }
486
+ }
487
+ } else {
488
+ while (lt_mask) {
489
+ // find first non-zero
490
+ int j = __builtin_ctz(lt_mask);
491
+ lt_mask -= 1 << j;
492
+ T dis = d32tab[j];
493
+ res.add(dis, this->adjust_id(b, j));
494
+ }
432
495
  }
433
496
  }
434
497
 
@@ -439,7 +502,7 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
439
502
  CMin<float, int64_t>>::type;
440
503
 
441
504
  std::vector<int> perm(reservoirs[0].n);
442
- for (int q = 0; q < reservoirs.size(); q++) {
505
+ for (size_t q = 0; q < reservoirs.size(); q++) {
443
506
  ReservoirTopN<C>& res = reservoirs[q];
444
507
  size_t n = res.n;
445
508
 
@@ -454,14 +517,14 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
454
517
  one_a = 1 / normalizers[2 * q];
455
518
  b = normalizers[2 * q + 1];
456
519
  }
457
- for (int i = 0; i < res.i; i++) {
520
+ for (size_t i = 0; i < res.i; i++) {
458
521
  perm[i] = i;
459
522
  }
460
523
  // indirect sort of result arrays
461
524
  std::sort(perm.begin(), perm.begin() + res.i, [&res](int i, int j) {
462
525
  return C::cmp(res.vals[j], res.vals[i]);
463
526
  });
464
- for (int i = 0; i < res.i; i++) {
527
+ for (size_t i = 0; i < res.i; i++) {
465
528
  heap_dis[i] = res.vals[perm[i]] * one_a + b;
466
529
  heap_ids[i] = res.ids[perm[i]];
467
530
  }
@@ -472,7 +535,7 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
472
535
  }
473
536
  };
474
537
 
475
- /** Result hanlder for range search. The difficulty is that the range distances
538
+ /** Result handler for range search. The difficulty is that the range distances
476
539
  * have to be scaled using the scaler.
477
540
  */
478
541
 
@@ -499,13 +562,17 @@ struct RangeHandler : ResultHandlerCompare<C, with_id_map> {
499
562
  };
500
563
  std::vector<Triplet> triplets;
501
564
 
502
- RangeHandler(RangeSearchResult& rres, float radius, size_t ntotal)
503
- : RHC(rres.nq, ntotal), rres(rres), radius(radius) {
565
+ RangeHandler(
566
+ RangeSearchResult& rres,
567
+ float radius,
568
+ size_t ntotal,
569
+ const IDSelector* sel_in)
570
+ : RHC(rres.nq, ntotal, sel_in), rres(rres), radius(radius) {
504
571
  thresholds.resize(nq);
505
572
  n_per_query.resize(nq + 1);
506
573
  }
507
574
 
508
- virtual void begin(const float* norms) {
575
+ virtual void begin(const float* norms) override {
509
576
  normalizers = norms;
510
577
  for (int q = 0; q < nq; ++q) {
511
578
  thresholds[q] =
@@ -528,13 +595,28 @@ struct RangeHandler : ResultHandlerCompare<C, with_id_map> {
528
595
  d0.store(d32tab);
529
596
  d1.store(d32tab + 16);
530
597
 
531
- while (lt_mask) {
532
- // find first non-zero
533
- int j = __builtin_ctz(lt_mask);
534
- lt_mask -= 1 << j;
535
- T dis = d32tab[j];
536
- n_per_query[q]++;
537
- triplets.push_back({idx_t(q + q0), this->adjust_id(b, j), dis});
598
+ if (this->sel != nullptr) {
599
+ while (lt_mask) {
600
+ // find first non-zero
601
+ int j = __builtin_ctz(lt_mask);
602
+ lt_mask -= 1 << j;
603
+
604
+ auto real_idx = this->adjust_id(b, j);
605
+ if (this->sel->is_member(real_idx)) {
606
+ T dis = d32tab[j];
607
+ n_per_query[q]++;
608
+ triplets.push_back({idx_t(q + q0), real_idx, dis});
609
+ }
610
+ }
611
+ } else {
612
+ while (lt_mask) {
613
+ // find first non-zero
614
+ int j = __builtin_ctz(lt_mask);
615
+ lt_mask -= 1 << j;
616
+ T dis = d32tab[j];
617
+ n_per_query[q]++;
618
+ triplets.push_back({idx_t(q + q0), this->adjust_id(b, j), dis});
619
+ }
538
620
  }
539
621
  }
540
622
 
@@ -578,8 +660,9 @@ struct PartialRangeHandler : RangeHandler<C, with_id_map> {
578
660
  float radius,
579
661
  size_t ntotal,
580
662
  size_t q0,
581
- size_t q1)
582
- : RangeHandler<C, with_id_map>(*pres.res, radius, ntotal),
663
+ size_t q1,
664
+ const IDSelector* sel_in)
665
+ : RangeHandler<C, with_id_map>(*pres.res, radius, ntotal, sel_in),
583
666
  pres(pres) {
584
667
  nq = q1 - q0;
585
668
  this->q0 = q0;
@@ -630,7 +713,7 @@ struct PartialRangeHandler : RangeHandler<C, with_id_map> {
630
713
  */
631
714
 
632
715
  template <class C, bool W, class Consumer, class... Types>
633
- void dispatch_SIMDResultHanlder_fixedCW(
716
+ void dispatch_SIMDResultHandler_fixedCW(
634
717
  SIMDResultHandler& res,
635
718
  Consumer& consumer,
636
719
  Types... args) {
@@ -650,19 +733,19 @@ void dispatch_SIMDResultHanlder_fixedCW(
650
733
  }
651
734
 
652
735
  template <class C, class Consumer, class... Types>
653
- void dispatch_SIMDResultHanlder_fixedC(
736
+ void dispatch_SIMDResultHandler_fixedC(
654
737
  SIMDResultHandler& res,
655
738
  Consumer& consumer,
656
739
  Types... args) {
657
740
  if (res.with_fields) {
658
- dispatch_SIMDResultHanlder_fixedCW<C, true>(res, consumer, args...);
741
+ dispatch_SIMDResultHandler_fixedCW<C, true>(res, consumer, args...);
659
742
  } else {
660
- dispatch_SIMDResultHanlder_fixedCW<C, false>(res, consumer, args...);
743
+ dispatch_SIMDResultHandler_fixedCW<C, false>(res, consumer, args...);
661
744
  }
662
745
  }
663
746
 
664
747
  template <class Consumer, class... Types>
665
- void dispatch_SIMDResultHanlder(
748
+ void dispatch_SIMDResultHandler(
666
749
  SIMDResultHandler& res,
667
750
  Consumer& consumer,
668
751
  Types... args) {
@@ -680,24 +763,25 @@ void dispatch_SIMDResultHanlder(
680
763
  }
681
764
  } else if (res.sizeof_ids == sizeof(int)) {
682
765
  if (res.is_CMax) {
683
- dispatch_SIMDResultHanlder_fixedC<CMax<uint16_t, int>>(
766
+ dispatch_SIMDResultHandler_fixedC<CMax<uint16_t, int>>(
684
767
  res, consumer, args...);
685
768
  } else {
686
- dispatch_SIMDResultHanlder_fixedC<CMin<uint16_t, int>>(
769
+ dispatch_SIMDResultHandler_fixedC<CMin<uint16_t, int>>(
687
770
  res, consumer, args...);
688
771
  }
689
772
  } else if (res.sizeof_ids == sizeof(int64_t)) {
690
773
  if (res.is_CMax) {
691
- dispatch_SIMDResultHanlder_fixedC<CMax<uint16_t, int64_t>>(
774
+ dispatch_SIMDResultHandler_fixedC<CMax<uint16_t, int64_t>>(
692
775
  res, consumer, args...);
693
776
  } else {
694
- dispatch_SIMDResultHanlder_fixedC<CMin<uint16_t, int64_t>>(
777
+ dispatch_SIMDResultHandler_fixedC<CMin<uint16_t, int64_t>>(
695
778
  res, consumer, args...);
696
779
  }
697
780
  } else {
698
781
  FAISS_THROW_FMT("Unknown id size %d", res.sizeof_ids);
699
782
  }
700
783
  }
784
+
701
785
  } // namespace simd_result_handlers
702
786
 
703
787
  } // namespace faiss
@@ -140,8 +140,12 @@ std::map<std::string, ScalarQuantizer::QuantizerType> sq_types = {
140
140
  {"SQ4", ScalarQuantizer::QT_4bit},
141
141
  {"SQ6", ScalarQuantizer::QT_6bit},
142
142
  {"SQfp16", ScalarQuantizer::QT_fp16},
143
+ {"SQbf16", ScalarQuantizer::QT_bf16},
144
+ {"SQ8_direct_signed", ScalarQuantizer::QT_8bit_direct_signed},
145
+ {"SQ8_direct", ScalarQuantizer::QT_8bit_direct},
143
146
  };
144
- const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16)";
147
+ const std::string sq_pattern =
148
+ "(SQ4|SQ8|SQ6|SQfp16|SQbf16|SQ8_direct_signed|SQ8_direct)";
145
149
 
146
150
  std::map<std::string, AdditiveQuantizer::Search_type_t> aq_search_type = {
147
151
  {"_Nfloat", AdditiveQuantizer::ST_norm_float},
@@ -222,6 +226,19 @@ VectorTransform* parse_VectorTransform(const std::string& description, int d) {
222
226
  * Parse IndexIVF
223
227
  */
224
228
 
229
+ size_t parse_nlist(std::string s) {
230
+ size_t multiplier = 1;
231
+ if (s.back() == 'k') {
232
+ s.pop_back();
233
+ multiplier = 1024;
234
+ }
235
+ if (s.back() == 'M') {
236
+ s.pop_back();
237
+ multiplier = 1024 * 1024;
238
+ }
239
+ return std::stoi(s) * multiplier;
240
+ }
241
+
225
242
  // parsing guard + function
226
243
  Index* parse_coarse_quantizer(
227
244
  const std::string& description,
@@ -236,8 +253,8 @@ Index* parse_coarse_quantizer(
236
253
  };
237
254
  use_2layer = false;
238
255
 
239
- if (match("IVF([0-9]+)")) {
240
- nlist = std::stoi(sm[1].str());
256
+ if (match("IVF([0-9]+[kM]?)")) {
257
+ nlist = parse_nlist(sm[1].str());
241
258
  return new IndexFlat(d, mt);
242
259
  }
243
260
  if (match("IMI2x([0-9]+)")) {
@@ -248,18 +265,18 @@ Index* parse_coarse_quantizer(
248
265
  nlist = (size_t)1 << (2 * nbit);
249
266
  return new MultiIndexQuantizer(d, 2, nbit);
250
267
  }
251
- if (match("IVF([0-9]+)_HNSW([0-9]*)")) {
252
- nlist = std::stoi(sm[1].str());
268
+ if (match("IVF([0-9]+[kM]?)_HNSW([0-9]*)")) {
269
+ nlist = parse_nlist(sm[1].str());
253
270
  int hnsw_M = sm[2].length() > 0 ? std::stoi(sm[2]) : 32;
254
271
  return new IndexHNSWFlat(d, hnsw_M, mt);
255
272
  }
256
- if (match("IVF([0-9]+)_NSG([0-9]+)")) {
257
- nlist = std::stoi(sm[1].str());
273
+ if (match("IVF([0-9]+[kM]?)_NSG([0-9]+)")) {
274
+ nlist = parse_nlist(sm[1].str());
258
275
  int R = std::stoi(sm[2]);
259
276
  return new IndexNSGFlat(d, R, mt);
260
277
  }
261
- if (match("IVF([0-9]+)\\(Index([0-9])\\)")) {
262
- nlist = std::stoi(sm[1].str());
278
+ if (match("IVF([0-9]+[kM]?)\\(Index([0-9])\\)")) {
279
+ nlist = parse_nlist(sm[1].str());
263
280
  int no = std::stoi(sm[2].str());
264
281
  FAISS_ASSERT(no >= 0 && no < parenthesis_indexes.size());
265
282
  return parenthesis_indexes[no].release();
@@ -526,11 +543,12 @@ Index* parse_other_indexes(
526
543
  }
527
544
 
528
545
  // IndexLSH
529
- if (match("LSH(r?)(t?)")) {
530
- bool rotate_data = sm[1].length() > 0;
531
- bool train_thresholds = sm[2].length() > 0;
546
+ if (match("LSH([0-9]*)(r?)(t?)")) {
547
+ int nbits = sm[1].length() > 0 ? std::stoi(sm[1].str()) : d;
548
+ bool rotate_data = sm[2].length() > 0;
549
+ bool train_thresholds = sm[3].length() > 0;
532
550
  FAISS_THROW_IF_NOT(metric == METRIC_L2);
533
- return new IndexLSH(d, d, rotate_data, train_thresholds);
551
+ return new IndexLSH(d, nbits, rotate_data, train_thresholds);
534
552
  }
535
553
 
536
554
  // IndexLattice
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  // I/O code for indexes
11
9
 
12
10
  #ifndef FAISS_INDEX_IO_H
@@ -35,9 +33,12 @@ struct IOReader;
35
33
  struct IOWriter;
36
34
  struct InvertedLists;
37
35
 
38
- void write_index(const Index* idx, const char* fname);
39
- void write_index(const Index* idx, FILE* f);
40
- void write_index(const Index* idx, IOWriter* writer);
36
+ /// skip the storage for graph-based indexes
37
+ const int IO_FLAG_SKIP_STORAGE = 1;
38
+
39
+ void write_index(const Index* idx, const char* fname, int io_flags = 0);
40
+ void write_index(const Index* idx, FILE* f, int io_flags = 0);
41
+ void write_index(const Index* idx, IOWriter* writer, int io_flags = 0);
41
42
 
42
43
  void write_index_binary(const IndexBinary* idx, const char* fname);
43
44
  void write_index_binary(const IndexBinary* idx, FILE* f);
@@ -52,6 +53,12 @@ const int IO_FLAG_ONDISK_SAME_DIR = 4;
52
53
  const int IO_FLAG_SKIP_IVF_DATA = 8;
53
54
  // don't initialize precomputed table after loading
54
55
  const int IO_FLAG_SKIP_PRECOMPUTE_TABLE = 16;
56
+ // don't compute the sdc table for PQ-based indices
57
+ // this will prevent distances from being computed
58
+ // between elements in the index. For indices like HNSWPQ,
59
+ // this will prevent graph building because sdc
60
+ // computations are required to construct the graph
61
+ const int IO_FLAG_PQ_SKIP_SDC_TABLE = 32;
55
62
  // try to memmap data (useful to load an ArrayInvertedLists as an
56
63
  // OnDiskInvertedLists)
57
64
  const int IO_FLAG_MMAP = IO_FLAG_SKIP_IVF_DATA | 0x646f0000;
@@ -9,6 +9,7 @@
9
9
 
10
10
  #include <faiss/impl/CodePacker.h>
11
11
  #include <faiss/impl/FaissAssert.h>
12
+ #include <faiss/impl/IDSelector.h>
12
13
 
13
14
  #include <faiss/impl/io.h>
14
15
  #include <faiss/impl/io_macros.h>
@@ -54,7 +55,9 @@ size_t BlockInvertedLists::add_entries(
54
55
  codes[list_no].resize(n_block * block_size);
55
56
  if (o % block_size == 0) {
56
57
  // copy whole blocks
57
- memcpy(&codes[list_no][o * code_size], code, n_block * block_size);
58
+ memcpy(&codes[list_no][o * packer->code_size],
59
+ code,
60
+ n_block * block_size);
58
61
  } else {
59
62
  FAISS_THROW_IF_NOT_MSG(packer, "missing code packer");
60
63
  std::vector<uint8_t> buffer(packer->code_size);
@@ -76,6 +79,29 @@ const uint8_t* BlockInvertedLists::get_codes(size_t list_no) const {
76
79
  return codes[list_no].get();
77
80
  }
78
81
 
82
+ size_t BlockInvertedLists::remove_ids(const IDSelector& sel) {
83
+ idx_t nremove = 0;
84
+ #pragma omp parallel for
85
+ for (idx_t i = 0; i < nlist; i++) {
86
+ std::vector<uint8_t> buffer(packer->code_size);
87
+ idx_t l = ids[i].size(), j = 0;
88
+ while (j < l) {
89
+ if (sel.is_member(ids[i][j])) {
90
+ l--;
91
+ ids[i][j] = ids[i][l];
92
+ packer->unpack_1(codes[i].data(), l, buffer.data());
93
+ packer->pack_1(buffer.data(), j, codes[i].data());
94
+ } else {
95
+ j++;
96
+ }
97
+ }
98
+ resize(i, l);
99
+ nremove += ids[i].size() - l;
100
+ }
101
+
102
+ return nremove;
103
+ }
104
+
79
105
  const idx_t* BlockInvertedLists::get_ids(size_t list_no) const {
80
106
  assert(list_no < nlist);
81
107
  return ids[list_no].data();
@@ -101,13 +127,7 @@ void BlockInvertedLists::update_entries(
101
127
  size_t,
102
128
  const idx_t*,
103
129
  const uint8_t*) {
104
- FAISS_THROW_MSG("not impemented");
105
- /*
106
- assert (list_no < nlist);
107
- assert (n_entry + offset <= ids[list_no].size());
108
- memcpy (&ids[list_no][offset], ids_in, sizeof(ids_in[0]) * n_entry);
109
- memcpy (&codes[list_no][offset * code_size], codes_in, code_size * n_entry);
110
- */
130
+ FAISS_THROW_MSG("not implemented");
111
131
  }
112
132
 
113
133
  BlockInvertedLists::~BlockInvertedLists() {
@@ -15,6 +15,7 @@
15
15
  namespace faiss {
16
16
 
17
17
  struct CodePacker;
18
+ struct IDSelector;
18
19
 
19
20
  /** Inverted Lists that are organized by blocks.
20
21
  *
@@ -47,6 +48,8 @@ struct BlockInvertedLists : InvertedLists {
47
48
  size_t list_size(size_t list_no) const override;
48
49
  const uint8_t* get_codes(size_t list_no) const override;
49
50
  const idx_t* get_ids(size_t list_no) const override;
51
+ /// remove ids from the InvertedLists
52
+ size_t remove_ids(const IDSelector& sel);
50
53
 
51
54
  // works only on empty BlockInvertedLists
52
55
  // the codes should be of size ceil(n_entry / n_per_block) * block_size
@@ -15,6 +15,7 @@
15
15
  #include <faiss/impl/AuxIndexStructures.h>
16
16
  #include <faiss/impl/FaissAssert.h>
17
17
  #include <faiss/impl/IDSelector.h>
18
+ #include <faiss/invlists/BlockInvertedLists.h>
18
19
 
19
20
  namespace faiss {
20
21
 
@@ -148,8 +149,12 @@ size_t DirectMap::remove_ids(const IDSelector& sel, InvertedLists* invlists) {
148
149
  std::vector<idx_t> toremove(nlist);
149
150
 
150
151
  size_t nremove = 0;
151
-
152
+ BlockInvertedLists* block_invlists =
153
+ dynamic_cast<BlockInvertedLists*>(invlists);
152
154
  if (type == NoMap) {
155
+ if (block_invlists != nullptr) {
156
+ return block_invlists->remove_ids(sel);
157
+ }
153
158
  // exhaustive scan of IVF
154
159
  #pragma omp parallel for
155
160
  for (idx_t i = 0; i < nlist; i++) {
@@ -178,6 +183,9 @@ size_t DirectMap::remove_ids(const IDSelector& sel, InvertedLists* invlists) {
178
183
  }
179
184
  }
180
185
  } else if (type == Hashtable) {
186
+ FAISS_THROW_IF_MSG(
187
+ block_invlists,
188
+ "remove with hashtable is not supported with BlockInvertedLists");
181
189
  const IDSelectorArray* sela =
182
190
  dynamic_cast<const IDSelectorArray*>(&sel);
183
191
  FAISS_THROW_IF_NOT_MSG(