faiss 0.5.0 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +2 -0
  4. data/ext/faiss/index.cpp +8 -0
  5. data/lib/faiss/version.rb +1 -1
  6. data/vendor/faiss/faiss/IVFlib.cpp +25 -49
  7. data/vendor/faiss/faiss/Index.cpp +11 -0
  8. data/vendor/faiss/faiss/Index.h +24 -1
  9. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
  10. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  11. data/vendor/faiss/faiss/IndexFastScan.cpp +1 -1
  12. data/vendor/faiss/faiss/IndexFastScan.h +3 -8
  13. data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
  14. data/vendor/faiss/faiss/IndexFlat.h +80 -0
  15. data/vendor/faiss/faiss/IndexHNSW.cpp +90 -1
  16. data/vendor/faiss/faiss/IndexHNSW.h +57 -1
  17. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +34 -149
  18. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +86 -2
  19. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +3 -1
  20. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +293 -115
  21. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +52 -16
  22. data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
  23. data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
  24. data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
  25. data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -16
  26. data/vendor/faiss/faiss/IndexRaBitQ.h +5 -1
  27. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +238 -93
  28. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +35 -9
  29. data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
  30. data/vendor/faiss/faiss/IndexRefine.h +17 -0
  31. data/vendor/faiss/faiss/clone_index.cpp +2 -0
  32. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
  33. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +1 -1
  34. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
  35. data/vendor/faiss/faiss/impl/DistanceComputer.h +74 -3
  36. data/vendor/faiss/faiss/impl/HNSW.cpp +294 -15
  37. data/vendor/faiss/faiss/impl/HNSW.h +31 -2
  38. data/vendor/faiss/faiss/impl/IDSelector.h +3 -3
  39. data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
  40. data/vendor/faiss/faiss/impl/Panorama.h +204 -0
  41. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
  42. data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
  43. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +54 -6
  44. data/vendor/faiss/faiss/impl/RaBitQUtils.h +183 -6
  45. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +269 -84
  46. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +71 -4
  47. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
  48. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
  49. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +6 -9
  50. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -3
  51. data/vendor/faiss/faiss/impl/index_read.cpp +156 -12
  52. data/vendor/faiss/faiss/impl/index_write.cpp +142 -19
  53. data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
  54. data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
  55. data/vendor/faiss/faiss/impl/svs_io.h +67 -0
  56. data/vendor/faiss/faiss/index_factory.cpp +182 -15
  57. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
  58. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  59. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +18 -109
  60. data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -18
  61. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  62. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  63. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
  64. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
  65. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
  66. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
  67. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
  68. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
  69. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
  70. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
  71. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
  72. data/vendor/faiss/faiss/utils/distances.cpp +0 -3
  73. data/vendor/faiss/faiss/utils/utils.cpp +4 -0
  74. metadata +18 -1
@@ -0,0 +1,67 @@
+ /*
+  * Portions Copyright (c) Meta Platforms, Inc. and affiliates.
+  *
+  * This source code is licensed under the MIT license found in the
+  * LICENSE file in the root directory of this source tree.
+  */
+
+ /*
+  * Portions Copyright 2025 Intel Corporation
+  *
+  * Licensed under the Apache License, Version 2.0 (the "License");
+  * you may not use this file except in compliance with the License.
+  * You may obtain a copy of the License at
+  *
+  *     http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+
+ #pragma once
+
+ #include <iostream>
+ #include <streambuf>
+ #include <vector>
+
+ #include <faiss/impl/io.h>
+
+ namespace faiss {
+ namespace svs_io {
+
+ // Bridges IOWriter to std::ostream for streaming serialization.
+ // No buffering concerns since consumer is expected to write everything
+ // he receives.
+ struct WriterStreambuf : std::streambuf {
+     IOWriter* w;
+     explicit WriterStreambuf(IOWriter* w_);
+     ~WriterStreambuf() override;
+
+   protected:
+     std::streamsize xsputn(const char* s, std::streamsize n) override;
+     int overflow(int ch) override;
+ };
+
+ // Bridges IOReader to std::istream for streaming deserialization.
+ // Uses minimal buffering (single byte) to avoid over-reading from IOReader,
+ // which would advance its position beyond what the stream consumer actually
+ // read. This ensures subsequent direct reads from IOReader continue at the
+ // correct position. Bulk reads via xsgetn() forward directly to IOReader
+ // without intermediate buffering.
+ struct ReaderStreambuf : std::streambuf {
+     IOReader* r;
+     char single_char_buffer; // Single-byte buffer for underflow() operations
+
+     explicit ReaderStreambuf(IOReader* rr);
+     ~ReaderStreambuf() override;
+
+   protected:
+     int_type underflow() override;
+     std::streamsize xsgetn(char* s, std::streamsize n) override;
+ };
+
+ } // namespace svs_io
+ } // namespace faiss
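Note: the header above only declares the two adapters; their definitions land in the accompanying impl/svs_io.cpp (+86 lines in this release). As a rough illustration of the comments, here is a hypothetical usage sketch (not code from the package), assuming the adapters behave like ordinary std::streambuf implementations:

    // Hypothetical usage sketch: serialize through a std::ostream that
    // forwards every byte to a faiss::IOWriter, as described above.
    #include <ostream>

    #include <faiss/impl/io.h>     // faiss::IOWriter
    #include <faiss/impl/svs_io.h> // faiss::svs_io::WriterStreambuf

    void write_header(faiss::IOWriter* w) {
        faiss::svs_io::WriterStreambuf buf(w); // bridge IOWriter -> streambuf
        std::ostream os(&buf);                 // standard ostream on top of it
        os << "svs-section\n";                 // stream-style serialization
        os.flush();                            // bytes go straight to w
    }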
@@ -52,6 +52,13 @@
  #include <faiss/IndexBinaryHNSW.h>
  #include <faiss/IndexBinaryHash.h>
  #include <faiss/IndexBinaryIVF.h>
+
+ #ifdef FAISS_ENABLE_SVS
+ #include <faiss/svs/IndexSVSFlat.h>
+ #include <faiss/svs/IndexSVSVamana.h>
+ #include <faiss/svs/IndexSVSVamanaLVQ.h>
+ #include <faiss/svs/IndexSVSVamanaLeanVec.h>
+ #endif
  #include <faiss/IndexIDMap.h>
  #include <algorithm>
  #include <cctype>
@@ -193,8 +200,6 @@ std::vector<size_t> aq_parse_nbits(std::string stok) {
      return nbits;
  }

- const std::string rabitq_pattern = "(RaBitQ)";
-
  /***************************************************************
   * Parse VectorTransform
   */
@@ -457,12 +462,21 @@ IndexIVF* parse_IndexIVF(
          }
          return index_ivf;
      }
-     if (match(rabitq_pattern)) {
-         return new IndexIVFRaBitQ(get_q(), d, nlist, mt, own_il);
-     }
-     if (match("RaBitQfs(_[0-9]+)?")) {
-         int bbs = mres_to_int(sm[1], 32, 1);
-         return new IndexIVFRaBitQFastScan(get_q(), d, nlist, mt, bbs, own_il);
+     // IndexIVFRaBitQ with optional nb_bits (1-9)
+     // Accepts: "RaBitQ" (default 1-bit) or "RaBitQ{nb_bits}" (e.g., "RaBitQ4")
+     if (match("RaBitQ([1-9])?")) {
+         uint8_t nb_bits = sm[1].length() > 0 ? std::stoi(sm[1].str()) : 1;
+         return new IndexIVFRaBitQ(get_q(), d, nlist, mt, own_il, nb_bits);
+     }
+     // Accepts: "RaBitQfs" (default 1-bit, batch size 32)
+     //          "RaBitQfs{nb_bits}" (e.g., "RaBitQfs4")
+     //          "RaBitQfs_64" (1-bit, batch size 64)
+     //          "RaBitQfs{nb_bits}_{bbs}" (e.g., "RaBitQfs4_64")
+     if (match("RaBitQfs([1-9])?(_[0-9]+)?")) {
+         uint8_t nb_bits = sm[1].length() > 0 ? std::stoi(sm[1].str()) : 1;
+         int bbs = mres_to_int(sm[2], 32, 1);
+         return new IndexIVFRaBitQFastScan(
+                 get_q(), d, nlist, mt, bbs, own_il, nb_bits);
      }
      return nullptr;
  }
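Note: per the comments above, the IVF-side RaBitQ grammar now takes an optional bit width (1-9) and, for the FastScan variant, an optional batch size. The same suffixes without the IVF prefix are handled for the flat variants in parse_other_indexes further down. A hedged sketch of factory strings that should reach this branch (dimension and nlist values are invented):

    #include <faiss/index_factory.h>

    // Hypothetical factory strings matching the patterns above.
    void rabitq_ivf_examples() {
        faiss::Index* a = faiss::index_factory(128, "IVF1024,RaBitQ");       // 1-bit (default)
        faiss::Index* b = faiss::index_factory(128, "IVF1024,RaBitQ4");      // 4-bit codes
        faiss::Index* c = faiss::index_factory(128, "IVF1024,RaBitQfs4_64"); // 4-bit FastScan, bbs = 64
        delete a;
        delete b;
        delete c;
    }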
@@ -485,6 +499,11 @@ IndexHNSW* parse_IndexHNSW(
          return new IndexHNSWFlat(d, hnsw_M, mt);
      }

+     if (match("FlatPanorama([0-9]+)?")) {
+         int nlevels = mres_to_int(sm[1], 8); // default to 8 levels
+         return new IndexHNSWFlatPanorama(d, hnsw_M, nlevels, mt);
+     }
+
      if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
          int M = std::stoi(sm[1].str());
          int nbit = mres_to_int(sm[2], 8, 1);
@@ -551,6 +570,109 @@ IndexNSG* parse_IndexNSG(
      return nullptr;
  }

+ #ifdef FAISS_ENABLE_SVS
+ /***************************************************************
+  * Parse IndexSVS
+  */
+
+ SVSStorageKind parse_lvq(const std::string& lvq_string) {
+     if (lvq_string == "LVQ4x0") {
+         return SVSStorageKind::SVS_LVQ4x0;
+     }
+     if (lvq_string == "LVQ4x4") {
+         return SVSStorageKind::SVS_LVQ4x4;
+     }
+     if (lvq_string == "LVQ4x8") {
+         return SVSStorageKind::SVS_LVQ4x8;
+     }
+     FAISS_ASSERT(!"not supported SVS LVQ level");
+ }
+
+ SVSStorageKind parse_leanvec(const std::string& leanvec_string) {
+     if (leanvec_string == "LeanVec4x4") {
+         return SVSStorageKind::SVS_LeanVec4x4;
+     }
+     if (leanvec_string == "LeanVec4x8") {
+         return SVSStorageKind::SVS_LeanVec4x8;
+     }
+     if (leanvec_string == "LeanVec8x8") {
+         return SVSStorageKind::SVS_LeanVec8x8;
+     }
+     FAISS_ASSERT(!"not supported SVS Leanvec level");
+ }
+
+ Index* parse_svs_datatype(
+         const std::string& index_type,
+         const std::string& arg_string,
+         const std::string& datatype_string,
+         int d,
+         MetricType mt) {
+     std::smatch sm;
+
+     if (datatype_string.empty()) {
+         if (index_type == "Vamana")
+             return new IndexSVSVamana(d, std::stoul(arg_string), mt);
+         if (index_type == "Flat")
+             return new IndexSVSFlat(d, mt);
+         FAISS_ASSERT(!"Unspported SVS index type");
+     }
+     if (re_match(datatype_string, "FP16", sm)) {
+         if (index_type == "Vamana")
+             return new IndexSVSVamana(
+                     d, std::stoul(arg_string), mt, SVSStorageKind::SVS_FP16);
+         FAISS_ASSERT(!"Unspported SVS index type for Float16");
+     }
+     if (re_match(datatype_string, "SQI8", sm)) {
+         if (index_type == "Vamana")
+             return new IndexSVSVamana(
+                     d, std::stoul(arg_string), mt, SVSStorageKind::SVS_SQI8);
+         FAISS_ASSERT(!"Unspported SVS index type for SQI8");
+     }
+     if (re_match(datatype_string, "(LVQ[0-9]+x[0-9]+)", sm)) {
+         if (index_type == "Vamana")
+             return new IndexSVSVamanaLVQ(
+                     d, std::stoul(arg_string), mt, parse_lvq(sm[0].str()));
+         FAISS_ASSERT(!"Unspported SVS index type for LVQ");
+     }
+     if (re_match(datatype_string, "(LeanVec[0-9]+x[0-9]+)(_[0-9]+)?", sm)) {
+         std::string leanvec_d_string =
+                 sm[2].length() > 0 ? sm[2].str().substr(1) : "0";
+         int leanvec_d = std::stoul(leanvec_d_string);
+
+         if (index_type == "Vamana")
+             return new IndexSVSVamanaLeanVec(
+                     d,
+                     std::stoul(arg_string),
+                     mt,
+                     leanvec_d,
+                     parse_leanvec(sm[1].str()));
+         FAISS_ASSERT(!"Unspported SVS index type for LeanVec");
+     }
+     return nullptr;
+ }
+
+ Index* parse_IndexSVS(const std::string& code_string, int d, MetricType mt) {
+     std::smatch sm;
+     if (re_match(code_string, "Flat(,.+)?", sm)) {
+         std::string datatype_string =
+                 sm[1].length() > 0 ? sm[1].str().substr(1) : "";
+         return parse_svs_datatype("Flat", "", datatype_string, d, mt);
+     }
+     if (re_match(code_string, "Vamana([0-9]+)(,.+)?", sm)) {
+         Index* index{nullptr};
+         std::string degree_string = sm[1].str();
+         std::string datatype_string =
+                 sm[2].length() > 0 ? sm[2].str().substr(1) : "";
+         return parse_svs_datatype(
+                 "Vamana", degree_string, datatype_string, d, mt);
+     }
+     if (re_match(code_string, "IVF([0-9]+)(,.+)?", sm)) {
+         FAISS_ASSERT(!"Unspported SVS index type");
+     }
+     return nullptr;
+ }
+ #endif // FAISS_ENABLE_SVS
+
  /***************************************************************
   * Parse basic indexes
   */
@@ -569,6 +691,18 @@ Index* parse_other_indexes(
          return new IndexFlat(d, metric);
      }

+     // IndexFlatL2Panorama
+     if (match("FlatL2Panorama([0-9]+)(_[0-9]+)?")) {
+         FAISS_THROW_IF_NOT(metric == METRIC_L2);
+         int nlevels = std::stoi(sm[1].str());
+         if (sm[2].length() > 0) {
+             int batch_size = std::stoi(sm[2].str().substr(1));
+             return new IndexFlatL2Panorama(d, nlevels, (size_t)batch_size);
+         } else {
+             return new IndexFlatL2Panorama(d, nlevels);
+         }
+     }
+
      // IndexLSH
      if (match("LSH([0-9]*)(r?)(t?)")) {
          int nbits = sm[1].length() > 0 ? std::stoi(sm[1].str()) : d;
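Note: the FlatL2Panorama pattern above requires a level count and accepts an optional batch size after an underscore; it throws unless the metric is L2. It parallels the HNSW-side "FlatPanorama" pattern added earlier. A hedged sketch of strings that should hit this branch (values invented):

    #include <faiss/index_factory.h>

    // Hypothetical examples for "FlatL2Panorama([0-9]+)(_[0-9]+)?".
    void panorama_flat_examples() {
        faiss::Index* a = faiss::index_factory(64, "FlatL2Panorama8");     // 8 levels, default batch size
        faiss::Index* b = faiss::index_factory(64, "FlatL2Panorama8_512"); // 8 levels, batch size 512
        delete a;
        delete b;
    }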
@@ -685,15 +819,17 @@ Index* parse_other_indexes(
          }
      }

-     // IndexRaBitQ
-     if (match(rabitq_pattern)) {
-         return new IndexRaBitQ(d, metric);
+     // IndexRaBitQ with optional nb_bits (1-9)
+     // Accepts: "RaBitQ" (default 1-bit) or "RaBitQ{nb_bits}" (e.g., "RaBitQ4")
+     if (match("RaBitQ([1-9])?")) {
+         uint8_t nb_bits = sm[1].length() > 0 ? std::stoi(sm[1].str()) : 1;
+         return new IndexRaBitQ(d, metric, nb_bits);
      }

-     // IndexRaBitQFastScan
-     if (match("RaBitQfs(_[0-9]+)?")) {
-         int bbs = mres_to_int(sm[1], 32, 1);
-         return new IndexRaBitQFastScan(d, metric, bbs);
+     if (match("RaBitQfs([1-9])?(_[0-9]+)?")) {
+         uint8_t nb_bits = sm[1].length() > 0 ? std::stoi(sm[1].str()) : 1;
+         int bbs = mres_to_int(sm[2], 32, 1);
+         return new IndexRaBitQFastScan(d, metric, bbs, nb_bits);
      }

      return nullptr;
@@ -736,6 +872,18 @@ std::unique_ptr<Index> index_factory_sub(
          return std::unique_ptr<Index>(idmap);
      }

+     // handle refine Panorama
+     // TODO(aknayar): Add tests to test_factory.py
+     if (re_match(description, "(.+),RefinePanorama\\((.+)\\)", sm)) {
+         std::unique_ptr<Index> filter_index =
+                 index_factory_sub(d, sm[1].str(), metric);
+         std::unique_ptr<Index> refine_index =
+                 index_factory_sub(d, sm[2].str(), metric);
+         auto* index_rf = new IndexRefinePanorama(
+                 filter_index.release(), refine_index.release());
+         return std::unique_ptr<Index>(index_rf);
+     }
+
      // handle refines
      if (re_match(description, "(.+),RFlat", sm) ||
          re_match(description, "(.+),Refine\\((.+)\\)", sm)) {
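Note: RefinePanorama is parsed like the existing Refine(...) suffix: the part before the comma builds the base (filter) index, the parenthesized part builds the refinement index, and both are handed to IndexRefinePanorama. Purely as a grammar illustration (this hunk does not say which refine indexes are actually sensible here), a hedged sketch:

    #include <faiss/index_factory.h>

    // Hypothetical string exercising "(.+),RefinePanorama\\((.+)\\)".
    void refine_panorama_example() {
        faiss::Index* idx =
                faiss::index_factory(128, "IVF256,PQ16,RefinePanorama(Flat)");
        delete idx;
    }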
@@ -842,6 +990,25 @@ std::unique_ptr<Index> index_factory_sub(
          return std::unique_ptr<Index>(index);
      }

+ #ifdef FAISS_ENABLE_SVS
+     if (re_match(description, "SVS((?:Flat|Vamana|IVF).*)", sm)) {
+         std::string code_string = sm[1].str();
+         if (verbose) {
+             printf("parsing SVS string %s code_string=%s",
+                    description.c_str(),
+                    code_string.c_str());
+         }
+
+         Index* index = parse_IndexSVS(code_string, d, metric);
+         FAISS_THROW_IF_NOT_FMT(
+                 index,
+                 "could not parse SVS code description %s in %s",
+                 code_string.c_str(),
+                 description.c_str());
+         return std::unique_ptr<Index>(index);
+     }
+ #endif // FAISS_ENABLE_SVS
+
      // NSG variants (it was unclear in the old version that the separator was a
      // "," so we support both "_" and ",")
      if (re_match(description, "NSG([0-9]*)([,_].*)?", sm)) {
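Note: combined with parse_IndexSVS above, the SVS prefix accepts a Flat or Vamana body with an optional storage/datatype suffix (FP16, SQI8, LVQ..., LeanVec...), and is only compiled in when FAISS_ENABLE_SVS is defined. A hedged sketch of strings the regexes appear to allow (dimensions, graph degrees, and the LeanVec target dimension are invented):

    #include <faiss/index_factory.h>

    // Hypothetical SVS factory strings derived from the patterns above.
    void svs_examples() {
    #ifdef FAISS_ENABLE_SVS
        faiss::Index* flat_idx = faiss::index_factory(96, "SVSFlat");
        faiss::Index* vamana_idx = faiss::index_factory(96, "SVSVamana32");          // graph degree 32
        faiss::Index* lvq_idx = faiss::index_factory(96, "SVSVamana32,LVQ4x4");      // LVQ storage
        faiss::Index* lv_idx = faiss::index_factory(96, "SVSVamana64,LeanVec4x8_48"); // LeanVec, dim 48
        delete flat_idx;
        delete vamana_idx;
        delete lvq_idx;
        delete lv_idx;
    #endif
    }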
@@ -20,7 +20,7 @@ struct IDSelector;
  /** Inverted Lists that are organized by blocks.
   *
   * Different from the regular inverted lists, the codes are organized by blocks
-  * of size block_size bytes that reprsent a set of n_per_block. Therefore, code
+  * of size block_size bytes that represent a set of n_per_block. Therefore, code
   * allocations are always rounded up to block_size bytes. The codes are also
   * aligned on 32-byte boundaries for use with SIMD.
   *
@@ -53,7 +53,7 @@ void DirectMap::set_type(
          for (long ofs = 0; ofs < list_size; ofs++) {
              FAISS_THROW_IF_NOT_MSG(
                      0 <= idlist[ofs] && idlist[ofs] < ntotal,
-                     "direct map supported only for seuquential ids");
+                     "direct map supported only for sequential ids");
              array[idlist[ofs]] = lo_build(key, ofs);
          }
      } else if (new_type == Hashtable) {
@@ -229,7 +229,7 @@ bool InvertedLists::is_empty(size_t list_no, void* inverted_list_context)
      }
  }

- // implemnent iterator on top of get_codes / get_ids
+ // implement iterator on top of get_codes / get_ids
  namespace {

  struct CodeArrayIterator : InvertedListsIterator {
@@ -358,7 +358,8 @@ ArrayInvertedListsPanorama::ArrayInvertedListsPanorama(
          n_levels(n_levels),
          level_width(
                  (((code_size / sizeof(float)) + n_levels - 1) / n_levels) *
-                 sizeof(float)) {
+                 sizeof(float)),
+         pano(code_size, n_levels, kBatchSize) {
      FAISS_THROW_IF_NOT(n_levels > 0);
      FAISS_THROW_IF_NOT(code_size % sizeof(float) == 0);
      FAISS_THROW_IF_NOT_MSG(
@@ -390,8 +391,11 @@ size_t ArrayInvertedListsPanorama::add_entries(
      codes[list_no].resize(num_batches * kBatchSize * code_size);
      cum_sums[list_no].resize(num_batches * kBatchSize * (n_levels + 1));

-     copy_codes_to_level_layout(list_no, o, n_entry, code);
-     compute_cumulative_sums(list_no, o, n_entry, code);
+     // Cast to float* is safe here as we guarantee codes are always float
+     // vectors for `IndexIVFFlatPanorama` (verified by the constructor).
+     const float* vectors = reinterpret_cast<const float*>(code);
+     pano.copy_codes_to_level_layout(codes[list_no].data(), o, n_entry, code);
+     pano.compute_cumulative_sums(cum_sums[list_no].data(), o, n_entry, vectors);

      return o;
  }
@@ -406,8 +410,14 @@ void ArrayInvertedListsPanorama::update_entries(
      assert(n_entry + offset <= ids[list_no].size());

      memcpy(&ids[list_no][offset], ids_in, sizeof(ids_in[0]) * n_entry);
-     copy_codes_to_level_layout(list_no, offset, n_entry, code);
-     compute_cumulative_sums(list_no, offset, n_entry, code);
+
+     // Cast to float* is safe here as we guarantee codes are always float
+     // vectors for `IndexIVFFlatPanorama` (verified by the constructor).
+     const float* vectors = reinterpret_cast<const float*>(code);
+     pano.copy_codes_to_level_layout(
+             codes[list_no].data(), offset, n_entry, code);
+     pano.compute_cumulative_sums(
+             cum_sums[list_no].data(), offset, n_entry, vectors);
  }

  void ArrayInvertedListsPanorama::resize(size_t list_no, size_t new_size) {
@@ -426,21 +436,8 @@ const uint8_t* ArrayInvertedListsPanorama::get_single_code(

      uint8_t* recons_buffer = new uint8_t[code_size];

-     const uint8_t* codes_base = codes[list_no].data();
-
-     size_t batch_no = offset / kBatchSize;
-     size_t pos_in_batch = offset % kBatchSize;
-     size_t batch_offset = batch_no * kBatchSize * code_size;
-
-     for (size_t level = 0; level < n_levels; level++) {
-         size_t level_offset = level * level_width * kBatchSize;
-         const uint8_t* src = codes_base + batch_offset + level_offset +
-                 pos_in_batch * level_width;
-         uint8_t* dest = recons_buffer + level * level_width;
-         size_t copy_size =
-                 std::min(level_width, code_size - level * level_width);
-         memcpy(dest, src, copy_size);
-     }
+     float* recons = reinterpret_cast<float*>(recons_buffer);
+     pano.reconstruct(offset, recons, codes[list_no].data());

      return recons_buffer;
  }
@@ -463,94 +460,6 @@ InvertedListsIterator* ArrayInvertedListsPanorama::get_iterator(
      return nullptr;
  }

- void ArrayInvertedListsPanorama::compute_cumulative_sums(
-         size_t list_no,
-         size_t offset,
-         size_t n_entry,
-         const uint8_t* code) {
-     // Cast to float* is safe here as we guarantee codes are always float
-     // vectors for `IndexIVFFlatPanorama` (verified by the constructor).
-     const float* vectors = reinterpret_cast<const float*>(code);
-     const size_t d = code_size / sizeof(float);
-
-     std::vector<float> suffix_sums(d + 1);
-
-     for (size_t entry_idx = 0; entry_idx < n_entry; entry_idx++) {
-         size_t current_pos = offset + entry_idx;
-         size_t batch_no = current_pos / kBatchSize;
-         size_t pos_in_batch = current_pos % kBatchSize;
-
-         const float* vector = vectors + entry_idx * d;
-
-         // Compute suffix sums of squared values.
-         suffix_sums[d] = 0.0f;
-         for (int j = d - 1; j >= 0; j--) {
-             float squared_val = vector[j] * vector[j];
-             suffix_sums[j] = suffix_sums[j + 1] + squared_val;
-         }
-
-         // Store cumulative sums in batch-oriented layout.
-         size_t cumsum_batch_offset = batch_no * kBatchSize * (n_levels + 1);
-         float* cumsum_base = cum_sums[list_no].data();
-
-         const size_t level_width_floats = level_width / sizeof(float);
-         for (size_t level = 0; level < n_levels; level++) {
-             size_t start_idx = level * level_width_floats;
-             size_t cumsum_offset =
-                     cumsum_batch_offset + level * kBatchSize + pos_in_batch;
-             if (start_idx < d) {
-                 cumsum_base[cumsum_offset] = sqrt(suffix_sums[start_idx]);
-             } else {
-                 cumsum_base[cumsum_offset] = 0.0f;
-             }
-         }
-
-         // Last level sum is always 0.
-         size_t cumsum_offset =
-                 cumsum_batch_offset + n_levels * kBatchSize + pos_in_batch;
-         cumsum_base[cumsum_offset] = 0.0f;
-     }
- }
-
- // Helper method to copy codes into level-oriented batch layout at a given
- // offset in the list.
- void ArrayInvertedListsPanorama::copy_codes_to_level_layout(
-         size_t list_no,
-         size_t offset,
-         size_t n_entry,
-         const uint8_t* code) {
-     uint8_t* codes_base = codes[list_no].data();
-     size_t current_pos = offset;
-     for (size_t entry_idx = 0; entry_idx < n_entry;) {
-         // Determine which batch we're in and position within that batch.
-         size_t batch_no = current_pos / kBatchSize;
-         size_t pos_in_batch = current_pos % kBatchSize;
-         size_t entries_in_this_batch =
-                 std::min(n_entry - entry_idx, kBatchSize - pos_in_batch);
-
-         // Copy entries into level-oriented layout for this batch.
-         size_t batch_offset = batch_no * kBatchSize * code_size;
-         for (size_t level = 0; level < n_levels; level++) {
-             size_t level_offset = level * level_width * kBatchSize;
-             size_t start_byte = level * level_width;
-             size_t copy_size =
-                     std::min(level_width, code_size - level * level_width);
-
-             for (size_t i = 0; i < entries_in_this_batch; i++) {
-                 const uint8_t* src =
-                         code + (entry_idx + i) * code_size + start_byte;
-                 uint8_t* dest = codes_base + batch_offset + level_offset +
-                         (pos_in_batch + i) * level_width;
-
-                 memcpy(dest, src, copy_size);
-             }
-         }
-
-         entry_idx += entries_in_this_batch;
-         current_pos += entries_in_this_batch;
-     }
- }
-
  /*****************************************************************
   * Meta-inverted list implementations
   *****************************************************************/
@@ -18,6 +18,7 @@
  #include <vector>

  #include <faiss/MetricType.h>
+ #include <faiss/impl/Panorama.h>
  #include <faiss/impl/maybe_owned_vector.h>

  namespace faiss {
@@ -283,6 +284,7 @@ struct ArrayInvertedListsPanorama : ArrayInvertedLists {
      std::vector<MaybeOwnedVector<float>> cum_sums;
      const size_t n_levels;
      const size_t level_width; // in code units
+     Panorama pano;

      ArrayInvertedListsPanorama(size_t nlist, size_t code_size, size_t n_levels);

@@ -318,24 +320,6 @@ struct ArrayInvertedListsPanorama : ArrayInvertedLists {

      /// Frees codes returned by `get_single_code`.
      void release_codes(size_t list_no, const uint8_t* codes) const override;
-
-   private:
-     /// Helper method to copy codes into level-oriented batch layout at a given
-     /// offset in the list.
-     void copy_codes_to_level_layout(
-             size_t list_no,
-             size_t offset,
-             size_t n_entry,
-             const uint8_t* code);
-
-     /// Helper method to compute the cumulative sums of the codes.
-     /// The cumsums also follow the level-oriented batch layout to minimize the
-     /// number of random memory accesses.
-     void compute_cumulative_sums(
-             size_t list_no,
-             size_t offset,
-             size_t n_entry,
-             const uint8_t* code);
  };

  /*****************************************************************
@@ -372,7 +372,7 @@ OnDiskInvertedLists::~OnDiskInvertedLists() {
      if (ptr != nullptr) {
          int err = munmap(ptr, totsize);
          if (err != 0) {
-             fprintf(stderr, "mumap error: %s", strerror(errno));
+             fprintf(stderr, "munmap error: %s", strerror(errno));
          }
      }
      delete locks;
@@ -121,7 +121,7 @@ struct OnDiskInvertedLists : InvertedLists {

      LockLevels* locks;

-     // encapsulates the threads that are busy prefeteching
+     // encapsulates the threads that are busy prefetching
      struct OngoingPrefetch;
      OngoingPrefetch* pf;
      int prefetch_nthread;