faiss 0.5.3 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -0
  3. data/ext/faiss/ext.cpp +1 -1
  4. data/ext/faiss/extconf.rb +5 -6
  5. data/ext/faiss/index_binary.cpp +38 -28
  6. data/ext/faiss/{index.cpp → index_rb.cpp} +64 -46
  7. data/ext/faiss/kmeans.cpp +10 -9
  8. data/ext/faiss/pca_matrix.cpp +10 -8
  9. data/ext/faiss/product_quantizer.cpp +14 -12
  10. data/ext/faiss/{utils.cpp → utils_rb.cpp} +5 -3
  11. data/ext/faiss/{utils.h → utils_rb.h} +4 -0
  12. data/lib/faiss/version.rb +1 -1
  13. data/lib/faiss.rb +1 -1
  14. data/vendor/faiss/faiss/AutoTune.cpp +130 -11
  15. data/vendor/faiss/faiss/AutoTune.h +14 -1
  16. data/vendor/faiss/faiss/Clustering.cpp +59 -10
  17. data/vendor/faiss/faiss/Clustering.h +12 -0
  18. data/vendor/faiss/faiss/IVFlib.cpp +31 -28
  19. data/vendor/faiss/faiss/Index.cpp +20 -8
  20. data/vendor/faiss/faiss/Index.h +25 -3
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
  22. data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
  25. data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
  26. data/vendor/faiss/faiss/IndexFastScan.h +10 -1
  27. data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
  28. data/vendor/faiss/faiss/IndexFlat.h +16 -1
  29. data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
  30. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  31. data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
  32. data/vendor/faiss/faiss/IndexHNSW.h +14 -12
  33. data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
  34. data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
  35. data/vendor/faiss/faiss/IndexIVF.h +14 -4
  36. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
  37. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
  38. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
  39. data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
  40. data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
  41. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
  42. data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
  43. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
  44. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
  45. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
  46. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
  47. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
  48. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
  49. data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
  50. data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
  51. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
  52. data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
  53. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  54. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
  55. data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
  56. data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
  57. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
  58. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
  59. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
  60. data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
  61. data/vendor/faiss/faiss/IndexShards.cpp +3 -4
  62. data/vendor/faiss/faiss/MetricType.h +16 -0
  63. data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
  64. data/vendor/faiss/faiss/VectorTransform.h +23 -0
  65. data/vendor/faiss/faiss/clone_index.cpp +7 -4
  66. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
  67. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  68. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
  69. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
  70. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  71. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  72. data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
  73. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  74. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  75. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  76. data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
  77. data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
  78. data/vendor/faiss/faiss/impl/HNSW.h +8 -6
  79. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
  80. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  81. data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
  82. data/vendor/faiss/faiss/impl/NSG.h +17 -7
  83. data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
  84. data/vendor/faiss/faiss/impl/Panorama.h +22 -6
  85. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
  86. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
  87. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
  88. data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
  89. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
  90. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  91. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  92. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  93. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
  94. data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
  95. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
  96. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
  97. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  98. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
  99. data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
  100. data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
  101. data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
  102. data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
  103. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
  104. data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
  105. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
  106. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
  107. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
  108. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
  109. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
  110. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
  111. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
  112. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
  113. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
  114. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
  115. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  116. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
  117. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
  118. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
  119. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  120. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
  121. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
  122. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
  123. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
  124. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
  125. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
  126. data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
  127. data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
  128. data/vendor/faiss/faiss/index_factory.cpp +35 -16
  129. data/vendor/faiss/faiss/index_io.h +29 -3
  130. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
  131. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  132. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  133. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  134. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
  135. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
  136. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
  137. data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
  138. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  139. data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
  140. data/vendor/faiss/faiss/utils/distances.cpp +141 -23
  141. data/vendor/faiss/faiss/utils/distances.h +98 -0
  142. data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
  143. data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
  144. data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
  145. data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
  146. data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
  147. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
  148. data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
  149. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  150. data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
  151. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
  152. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
  153. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
  154. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
  155. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
  156. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
  157. data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
  158. data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
  159. data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
  160. data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
  161. data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
  162. data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
  163. data/vendor/faiss/faiss/utils/utils.cpp +16 -9
  164. metadata +47 -18
  165. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  166. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  167. data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -0,0 +1,139 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ /**
11
+ * @file simd_dispatch.h
12
+ * @brief Internal dispatch macros for SIMD level selection.
13
+ *
14
+ * This is a PRIVATE header - do not include in public APIs or user code.
15
+ * Only faiss internal .cpp files should include this header.
16
+ *
17
+ * For the public API (SIMDLevel enum, SIMDConfig class), use:
18
+ * #include <faiss/utils/simd_levels.h>
19
+ */
20
+
21
+ #include <faiss/impl/FaissAssert.h>
22
+ #include <faiss/utils/simd_levels.h>
23
+
24
+ namespace faiss {
25
+
26
+ /*********************** x86 SIMD dispatch cases: each macro expands to `case SIMDLevel::X: return f<SIMDLevel::X>(...)` when COMPILE_SIMD_X is defined, and to nothing otherwise */
27
+
28
+ #ifdef COMPILE_SIMD_AVX2
29
+ #define DISPATCH_SIMDLevel_AVX2(f, ...) \
30
+ case SIMDLevel::AVX2: \
31
+ return f<SIMDLevel::AVX2>(__VA_ARGS__)
32
+ #else
33
+ #define DISPATCH_SIMDLevel_AVX2(f, ...)
34
+ #endif
35
+
36
+ #ifdef COMPILE_SIMD_AVX512
37
+ #define DISPATCH_SIMDLevel_AVX512(f, ...) \
38
+ case SIMDLevel::AVX512: \
39
+ return f<SIMDLevel::AVX512>(__VA_ARGS__)
40
+ #else
41
+ #define DISPATCH_SIMDLevel_AVX512(f, ...)
42
+ #endif
43
+
44
+ #ifdef COMPILE_SIMD_AVX512_SPR
45
+ #define DISPATCH_SIMDLevel_AVX512_SPR(f, ...) \
46
+ case SIMDLevel::AVX512_SPR: \
47
+ return f<SIMDLevel::AVX512_SPR>(__VA_ARGS__)
48
+ #else
49
+ #define DISPATCH_SIMDLevel_AVX512_SPR(f, ...)
50
+ #endif
51
+
52
+ /*********************** ARM SIMD dispatch cases: same pattern as the x86 macros, keyed on COMPILE_SIMD_ARM_NEON / COMPILE_SIMD_ARM_SVE */
53
+
54
+ #ifdef COMPILE_SIMD_ARM_NEON
55
+ #define DISPATCH_SIMDLevel_ARM_NEON(f, ...) \
56
+ case SIMDLevel::ARM_NEON: \
57
+ return f<SIMDLevel::ARM_NEON>(__VA_ARGS__)
58
+ #else
59
+ #define DISPATCH_SIMDLevel_ARM_NEON(f, ...)
60
+ #endif
61
+
62
+ #ifdef COMPILE_SIMD_ARM_SVE
63
+ #define DISPATCH_SIMDLevel_ARM_SVE(f, ...) \
64
+ case SIMDLevel::ARM_SVE: \
65
+ return f<SIMDLevel::ARM_SVE>(__VA_ARGS__)
66
+ #else
67
+ #define DISPATCH_SIMDLevel_ARM_SVE(f, ...)
68
+ #endif
69
+
70
+ /*********************** Main dispatch macro: expands to `return f<level>(...)` statements, so it must be used inside a function body */
71
+
72
+ #ifdef FAISS_ENABLE_DD
73
+
74
+ // DD mode: runtime switch on SIMDConfig::level; a level with no compiled-in case reaches `default` (FAISS_THROW_MSG)
75
+ #define DISPATCH_SIMDLevel(f, ...) \
76
+ switch (SIMDConfig::level) { \
77
+ case SIMDLevel::NONE: \
78
+ return f<SIMDLevel::NONE>(__VA_ARGS__); \
79
+ DISPATCH_SIMDLevel_AVX2(f, __VA_ARGS__); \
80
+ DISPATCH_SIMDLevel_AVX512(f, __VA_ARGS__); \
81
+ DISPATCH_SIMDLevel_AVX512_SPR(f, __VA_ARGS__); \
82
+ DISPATCH_SIMDLevel_ARM_NEON(f, __VA_ARGS__); \
83
+ DISPATCH_SIMDLevel_ARM_SVE(f, __VA_ARGS__); \
84
+ default: \
85
+ FAISS_THROW_MSG("Invalid SIMD level"); \
86
+ }
87
+
88
+ #else // Static mode (FAISS_ENABLE_DD not defined)
89
+
90
+ // Static mode: direct call, no runtime switch. First defined of AVX512_SPR, AVX512, AVX2, ARM_SVE, ARM_NEON wins; else NONE
91
+ #if defined(COMPILE_SIMD_AVX512_SPR)
92
+ #define DISPATCH_SIMDLevel(f, ...) return f<SIMDLevel::AVX512_SPR>(__VA_ARGS__)
93
+ #elif defined(COMPILE_SIMD_AVX512)
94
+ #define DISPATCH_SIMDLevel(f, ...) return f<SIMDLevel::AVX512>(__VA_ARGS__)
95
+ #elif defined(COMPILE_SIMD_AVX2)
96
+ #define DISPATCH_SIMDLevel(f, ...) return f<SIMDLevel::AVX2>(__VA_ARGS__)
97
+ #elif defined(COMPILE_SIMD_ARM_SVE)
98
+ #define DISPATCH_SIMDLevel(f, ...) return f<SIMDLevel::ARM_SVE>(__VA_ARGS__)
99
+ #elif defined(COMPILE_SIMD_ARM_NEON)
100
+ #define DISPATCH_SIMDLevel(f, ...) return f<SIMDLevel::ARM_NEON>(__VA_ARGS__)
101
+ #else
102
+ #define DISPATCH_SIMDLevel(f, ...) return f<SIMDLevel::NONE>(__VA_ARGS__)
103
+ #endif
104
+
105
+ #endif // FAISS_ENABLE_DD
106
+
107
+ /**
108
+ * Invoke a lambda with the active SIMDLevel as a compile-time constant.
109
+ *
110
+ * Calls the provided templated lambda with the active SIMD level as its
111
+ * template argument: SIMDConfig::level when FAISS_ENABLE_DD is defined,
112
+ * otherwise the level fixed at build time by the COMPILE_SIMD_* macros.
113
+ * The selection logic itself lives in the DISPATCH_SIMDLevel macro.
114
+ *
115
+ * The key benefit is that the SIMD dispatch happens once, outside any loops,
116
+ * so the loop body runs with the selected SIMD implementation without
117
+ * per-iteration dispatch overhead.
118
+ *
119
+ * Example with a loop (the dispatch happens once, not per iteration):
120
+ *
121
+ * std::vector<float> distances(n);
122
+ * with_simd_level([&]<SIMDLevel level>() {
123
+ * for (size_t i = 0; i < n; i++) {
124
+ * distances[i] = fvec_L2sqr<level>(query, vectors + i * d, d);
125
+ * }
126
+ * });
127
+ *
128
+ * The lambda must declare a SIMDLevel template parameter and take no arguments.
129
+ *
130
+ * @param action A lambda whose call operator is `template<SIMDLevel> T
131
+ * operator()()`; it is invoked as `action.template operator()<level>()`
132
+ * @return The return value of the lambda for the dispatched level
133
+ */
134
+ template <typename LambdaType>
135
+ inline auto with_simd_level(LambdaType&& action) {
136
+ DISPATCH_SIMDLevel(action.template operator());
137
+ }
138
+
139
+ } // namespace faiss
@@ -126,8 +126,8 @@ struct StoreResultHandler : SIMDResultHandler {
126
126
 
127
127
  void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
128
128
  size_t ofs = (q + i0) * ld + j0 + b * 32;
129
- d0.store(data + ofs);
130
- d1.store(data + ofs + 16);
129
+ d0.storeu(data + ofs);
130
+ d1.storeu(data + ofs + 16);
131
131
  }
132
132
 
133
133
  void set_block_origin(size_t i0_in, size_t j0_in) final {
@@ -406,10 +406,10 @@ struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
406
406
  auto real_idx = this->adjust_id(b, j);
407
407
  lt_mask -= 1 << j;
408
408
  if (this->sel->is_member(real_idx)) {
409
- T dis_2 = d32tab[j];
410
- if (C::cmp(heap_dis[0], dis_2)) {
409
+ T dis_for_j = d32tab[j];
410
+ if (C::cmp(heap_dis[0], dis_for_j)) {
411
411
  heap_replace_top<C>(
412
- k, heap_dis, heap_ids, dis_2, real_idx);
412
+ k, heap_dis, heap_ids, dis_for_j, real_idx);
413
413
  nup++;
414
414
  }
415
415
  }
@@ -419,10 +419,10 @@ struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
419
419
  // find first non-zero
420
420
  int j = __builtin_ctz(lt_mask);
421
421
  lt_mask -= 1 << j;
422
- T dis_2 = d32tab[j];
423
- if (C::cmp(heap_dis[0], dis_2)) {
422
+ T dis_for_j = d32tab[j];
423
+ if (C::cmp(heap_dis[0], dis_for_j)) {
424
424
  int64_t idx = this->adjust_id(b, j);
425
- heap_replace_top<C>(k, heap_dis, heap_ids, dis_2, idx);
425
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis_for_j, idx);
426
426
  nup++;
427
427
  }
428
428
  }
@@ -524,8 +524,8 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
524
524
  auto real_idx = this->adjust_id(b, j);
525
525
  lt_mask -= 1 << j;
526
526
  if (this->sel->is_member(real_idx)) {
527
- T dis_2 = d32tab[j];
528
- res.add(dis_2, real_idx);
527
+ T dis_for_j = d32tab[j];
528
+ res.add(dis_for_j, real_idx);
529
529
  }
530
530
  }
531
531
  } else {
@@ -533,8 +533,8 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
533
533
  // find first non-zero
534
534
  int j = __builtin_ctz(lt_mask);
535
535
  lt_mask -= 1 << j;
536
- T dis_2 = d32tab[j];
537
- res.add(dis_2, this->adjust_id(b, j));
536
+ T dis_for_j = d32tab[j];
537
+ res.add(dis_for_j, this->adjust_id(b, j));
538
538
  }
539
539
  }
540
540
  }
@@ -761,12 +761,12 @@ void dispatch_SIMDResultHandler_fixedCW(
761
761
  SIMDResultHandler& res,
762
762
  Consumer& consumer,
763
763
  Types... args) {
764
- if (auto resh = dynamic_cast<SingleResultHandler<C, W>*>(&res)) {
765
- consumer.template f<SingleResultHandler<C, W>>(*resh, args...);
766
- } else if (auto resh_2 = dynamic_cast<HeapHandler<C, W>*>(&res)) {
767
- consumer.template f<HeapHandler<C, W>>(*resh_2, args...);
768
- } else if (auto resh_2 = dynamic_cast<ReservoirHandler<C, W>*>(&res)) {
769
- consumer.template f<ReservoirHandler<C, W>>(*resh_2, args...);
764
+ if (auto resh_sh = dynamic_cast<SingleResultHandler<C, W>*>(&res)) {
765
+ consumer.template f<SingleResultHandler<C, W>>(*resh_sh, args...);
766
+ } else if (auto resh_hh = dynamic_cast<HeapHandler<C, W>*>(&res)) {
767
+ consumer.template f<HeapHandler<C, W>>(*resh_hh, args...);
768
+ } else if (auto resh_rh = dynamic_cast<ReservoirHandler<C, W>*>(&res)) {
769
+ consumer.template f<ReservoirHandler<C, W>>(*resh_rh, args...);
770
770
  } else { // generic handler -- will not be inlined
771
771
  FAISS_THROW_IF_NOT_FMT(
772
772
  simd_result_handlers_accept_virtual,
@@ -220,6 +220,9 @@ VectorTransform* parse_VectorTransform(const std::string& description, int d) {
220
220
  if (match("RR([0-9]+)?")) {
221
221
  return new RandomRotationMatrix(d, mres_to_int(sm[1], d));
222
222
  }
223
+ if (match("HR([0-9]+)?")) {
224
+ return new HadamardRotation(d, mres_to_int(sm[1], 12345));
225
+ }
223
226
  if (match("ITQ([0-9]+)?")) {
224
227
  return new ITQTransform(d, mres_to_int(sm[1], d), sm[1].length() > 0);
225
228
  }
@@ -585,7 +588,7 @@ SVSStorageKind parse_lvq(const std::string& lvq_string) {
585
588
  if (lvq_string == "LVQ4x8") {
586
589
  return SVSStorageKind::SVS_LVQ4x8;
587
590
  }
588
- FAISS_ASSERT(!"not supported SVS LVQ level");
591
+ FAISS_ASSERT(false && "not supported SVS LVQ level");
589
592
  }
590
593
 
591
594
  SVSStorageKind parse_leanvec(const std::string& leanvec_string) {
@@ -598,7 +601,7 @@ SVSStorageKind parse_leanvec(const std::string& leanvec_string) {
598
601
  if (leanvec_string == "LeanVec8x8") {
599
602
  return SVSStorageKind::SVS_LeanVec8x8;
600
603
  }
601
- FAISS_ASSERT(!"not supported SVS Leanvec level");
604
+ FAISS_ASSERT(false && "not supported SVS Leanvec level");
602
605
  }
603
606
 
604
607
  Index* parse_svs_datatype(
@@ -610,43 +613,49 @@ Index* parse_svs_datatype(
610
613
  std::smatch sm;
611
614
 
612
615
  if (datatype_string.empty()) {
613
- if (index_type == "Vamana")
616
+ if (index_type == "Vamana") {
614
617
  return new IndexSVSVamana(d, std::stoul(arg_string), mt);
615
- if (index_type == "Flat")
618
+ }
619
+ if (index_type == "Flat") {
616
620
  return new IndexSVSFlat(d, mt);
617
- FAISS_ASSERT(!"Unspported SVS index type");
621
+ }
622
+ FAISS_ASSERT(false && "Unspported SVS index type");
618
623
  }
619
624
  if (re_match(datatype_string, "FP16", sm)) {
620
- if (index_type == "Vamana")
625
+ if (index_type == "Vamana") {
621
626
  return new IndexSVSVamana(
622
627
  d, std::stoul(arg_string), mt, SVSStorageKind::SVS_FP16);
623
- FAISS_ASSERT(!"Unspported SVS index type for Float16");
628
+ }
629
+ FAISS_ASSERT(false && "Unspported SVS index type for Float16");
624
630
  }
625
631
  if (re_match(datatype_string, "SQI8", sm)) {
626
- if (index_type == "Vamana")
632
+ if (index_type == "Vamana") {
627
633
  return new IndexSVSVamana(
628
634
  d, std::stoul(arg_string), mt, SVSStorageKind::SVS_SQI8);
629
- FAISS_ASSERT(!"Unspported SVS index type for SQI8");
635
+ }
636
+ FAISS_ASSERT(false && "Unspported SVS index type for SQI8");
630
637
  }
631
638
  if (re_match(datatype_string, "(LVQ[0-9]+x[0-9]+)", sm)) {
632
- if (index_type == "Vamana")
639
+ if (index_type == "Vamana") {
633
640
  return new IndexSVSVamanaLVQ(
634
641
  d, std::stoul(arg_string), mt, parse_lvq(sm[0].str()));
635
- FAISS_ASSERT(!"Unspported SVS index type for LVQ");
642
+ }
643
+ FAISS_ASSERT(false && "Unspported SVS index type for LVQ");
636
644
  }
637
645
  if (re_match(datatype_string, "(LeanVec[0-9]+x[0-9]+)(_[0-9]+)?", sm)) {
638
646
  std::string leanvec_d_string =
639
647
  sm[2].length() > 0 ? sm[2].str().substr(1) : "0";
640
- int leanvec_d = std::stoul(leanvec_d_string);
648
+ int leanvec_d = static_cast<int>(std::stoul(leanvec_d_string));
641
649
 
642
- if (index_type == "Vamana")
650
+ if (index_type == "Vamana") {
643
651
  return new IndexSVSVamanaLeanVec(
644
652
  d,
645
653
  std::stoul(arg_string),
646
654
  mt,
647
655
  leanvec_d,
648
656
  parse_leanvec(sm[1].str()));
649
- FAISS_ASSERT(!"Unspported SVS index type for LeanVec");
657
+ }
658
+ FAISS_ASSERT(false && "Unspported SVS index type for LeanVec");
650
659
  }
651
660
  return nullptr;
652
661
  }
@@ -659,7 +668,6 @@ Index* parse_IndexSVS(const std::string& code_string, int d, MetricType mt) {
659
668
  return parse_svs_datatype("Flat", "", datatype_string, d, mt);
660
669
  }
661
670
  if (re_match(code_string, "Vamana([0-9]+)(,.+)?", sm)) {
662
- Index* index{nullptr};
663
671
  std::string degree_string = sm[1].str();
664
672
  std::string datatype_string =
665
673
  sm[2].length() > 0 ? sm[2].str().substr(1) : "";
@@ -667,7 +675,7 @@ Index* parse_IndexSVS(const std::string& code_string, int d, MetricType mt) {
667
675
  "Vamana", degree_string, datatype_string, d, mt);
668
676
  }
669
677
  if (re_match(code_string, "IVF([0-9]+)(,.+)?", sm)) {
670
- FAISS_ASSERT(!"Unspported SVS index type");
678
+ FAISS_ASSERT(false && "Unspported SVS index type");
671
679
  }
672
680
  return nullptr;
673
681
  }
@@ -703,6 +711,17 @@ Index* parse_other_indexes(
703
711
  }
704
712
  }
705
713
 
714
+ // IndexFlatIPPanorama
715
+ if (match("FlatIPPanorama([0-9]+)(_[0-9]+)?")) {
716
+ FAISS_THROW_IF_NOT(metric == METRIC_INNER_PRODUCT);
717
+ int nlevels = std::stoi(sm[1].str());
718
+ if (sm[2].length() == 0) {
719
+ return new IndexFlatIPPanorama(d, nlevels);
720
+ }
721
+ int batch_size = std::stoi(sm[2].str().substr(1));
722
+ return new IndexFlatIPPanorama(d, nlevels, (size_t)batch_size);
723
+ }
724
+
706
725
  // IndexLSH
707
726
  if (match("LSH([0-9]*)(r?)(t?)")) {
708
727
  int nbits = sm[1].length() > 0 ? std::stoi(sm[1].str()) : d;
@@ -11,13 +11,17 @@
11
11
  #define FAISS_INDEX_IO_H
12
12
 
13
13
  #include <cstdio>
14
+ #include <memory>
14
15
 
15
16
  /** I/O functions can read/write to a filename, a file handle or to an
16
17
  * object that abstracts the medium.
17
18
  *
18
- * The read functions return objects that should be deallocated with
19
- * delete. All references within these objects are owned by the
20
- * object.
19
+ * The read functions come in two forms:
20
+ * - read_*_up() returns a std::unique_ptr that owns the result.
21
+ * - read_*() returns a raw pointer for backward compatibility.
22
+ * The caller is responsible for deleting the returned object.
23
+ *
24
+ * All references within these objects are owned by the object.
21
25
  */
22
26
 
23
27
  namespace faiss {
@@ -68,25 +72,47 @@ Index* read_index(const char* fname, int io_flags = 0);
68
72
  Index* read_index(FILE* f, int io_flags = 0);
69
73
  Index* read_index(IOReader* reader, int io_flags = 0);
70
74
 
75
+ std::unique_ptr<Index> read_index_up(const char* fname, int io_flags = 0);
76
+ std::unique_ptr<Index> read_index_up(FILE* f, int io_flags = 0);
77
+ std::unique_ptr<Index> read_index_up(IOReader* reader, int io_flags = 0);
78
+
71
79
  IndexBinary* read_index_binary(const char* fname, int io_flags = 0);
72
80
  IndexBinary* read_index_binary(FILE* f, int io_flags = 0);
73
81
  IndexBinary* read_index_binary(IOReader* reader, int io_flags = 0);
74
82
 
83
+ std::unique_ptr<IndexBinary> read_index_binary_up(
84
+ const char* fname,
85
+ int io_flags = 0);
86
+ std::unique_ptr<IndexBinary> read_index_binary_up(FILE* f, int io_flags = 0);
87
+ std::unique_ptr<IndexBinary> read_index_binary_up(
88
+ IOReader* reader,
89
+ int io_flags = 0);
90
+
75
91
  void write_VectorTransform(const VectorTransform* vt, const char* fname);
76
92
  void write_VectorTransform(const VectorTransform* vt, IOWriter* f);
77
93
 
78
94
  VectorTransform* read_VectorTransform(const char* fname);
79
95
  VectorTransform* read_VectorTransform(IOReader* f);
80
96
 
97
+ std::unique_ptr<VectorTransform> read_VectorTransform_up(const char* fname);
98
+ std::unique_ptr<VectorTransform> read_VectorTransform_up(IOReader* f);
99
+
81
100
  ProductQuantizer* read_ProductQuantizer(const char* fname);
82
101
  ProductQuantizer* read_ProductQuantizer(IOReader* reader);
83
102
 
103
+ std::unique_ptr<ProductQuantizer> read_ProductQuantizer_up(const char* fname);
104
+ std::unique_ptr<ProductQuantizer> read_ProductQuantizer_up(IOReader* reader);
105
+
84
106
  void write_ProductQuantizer(const ProductQuantizer* pq, const char* fname);
85
107
  void write_ProductQuantizer(const ProductQuantizer* pq, IOWriter* f);
86
108
 
87
109
  void write_InvertedLists(const InvertedLists* ils, IOWriter* f);
88
110
  InvertedLists* read_InvertedLists(IOReader* reader, int io_flags = 0);
89
111
 
112
+ std::unique_ptr<InvertedLists> read_InvertedLists_up(
113
+ IOReader* reader,
114
+ int io_flags = 0);
115
+
90
116
  } // namespace faiss
91
117
 
92
118
  #endif
@@ -7,6 +7,8 @@
7
7
 
8
8
  #include <faiss/invlists/BlockInvertedLists.h>
9
9
 
10
+ #include <memory>
11
+
10
12
  #include <faiss/impl/CodePacker.h>
11
13
  #include <faiss/impl/FaissAssert.h>
12
14
  #include <faiss/impl/IDSelector.h>
@@ -81,7 +83,7 @@ const uint8_t* BlockInvertedLists::get_codes(size_t list_no) const {
81
83
 
82
84
  size_t BlockInvertedLists::remove_ids(const IDSelector& sel) {
83
85
  idx_t nremove = 0;
84
- #pragma omp parallel for
86
+ #pragma omp parallel for reduction(+ : nremove)
85
87
  for (idx_t i = 0; i < nlist; i++) {
86
88
  std::vector<uint8_t> buffer(packer->code_size);
87
89
  idx_t l = ids[i].size(), j = 0;
@@ -95,8 +97,9 @@ size_t BlockInvertedLists::remove_ids(const IDSelector& sel) {
95
97
  j++;
96
98
  }
97
99
  }
100
+ idx_t orig_size = ids[i].size();
98
101
  resize(i, l);
99
- nremove += ids[i].size() - l;
102
+ nremove += orig_size - l;
100
103
  }
101
104
 
102
105
  return nremove;
@@ -160,7 +163,7 @@ void BlockInvertedListsIOHook::write(const InvertedLists* ils_in, IOWriter* f)
160
163
 
161
164
  InvertedLists* BlockInvertedListsIOHook::read(IOReader* f, int /* io_flags */)
162
165
  const {
163
- BlockInvertedLists* il = new BlockInvertedLists();
166
+ auto il = std::make_unique<BlockInvertedLists>();
164
167
  READ1(il->nlist);
165
168
  READ1(il->code_size);
166
169
  READ1(il->n_per_block);
@@ -174,7 +177,7 @@ InvertedLists* BlockInvertedListsIOHook::read(IOReader* f, int /* io_flags */)
174
177
  READVECTOR(il->codes[i]);
175
178
  }
176
179
 
177
- return il;
180
+ return il.release();
178
181
  }
179
182
 
180
183
  } // namespace faiss
@@ -314,7 +314,7 @@ void OnDiskInvertedLists::update_totsize(size_t new_size) {
314
314
  slots.push_back(Slot(totsize, new_size - totsize));
315
315
  }
316
316
  } else {
317
- assert(!"not implemented");
317
+ assert(false && "not implemented");
318
318
  }
319
319
 
320
320
  totsize = new_size;
@@ -45,18 +45,6 @@
45
45
  // create svs_runtime as alias for svs::runtime::FAISS_SVS_RUNTIME_VERSION
46
46
  SVS_RUNTIME_CREATE_API_ALIAS(svs_runtime, FAISS_SVS_RUNTIME_VERSION);
47
47
 
48
- // SVS forward declarations
49
- namespace svs {
50
- namespace runtime {
51
- inline namespace v0 {
52
- struct FlatIndex;
53
- struct VamanaIndex;
54
- struct DynamicVamanaIndex;
55
- struct LeanVecTrainingData;
56
- } // namespace v0
57
- } // namespace runtime
58
- } // namespace svs
59
-
60
48
  namespace faiss {
61
49
 
62
50
  inline svs_runtime::MetricType to_svs_metric(faiss::MetricType metric) {
@@ -66,7 +54,7 @@ inline svs_runtime::MetricType to_svs_metric(faiss::MetricType metric) {
66
54
  case METRIC_L2:
67
55
  return svs_runtime::MetricType::L2;
68
56
  default:
69
- FAISS_ASSERT(!"not supported SVS distance");
57
+ FAISS_ASSERT(false && "not supported SVS distance");
70
58
  }
71
59
  }
72
60
 
@@ -93,7 +81,8 @@ template <typename T, typename U, typename = void>
93
81
  struct InputBufferConverter {
94
82
  InputBufferConverter(std::span<const U> data = {}) : buffer(data.size()) {
95
83
  FAISS_ASSERT(
96
- !"InputBufferConverter: there is no suitable user code for this type conversion");
84
+ false &&
85
+ "InputBufferConverter: there is no suitable user code for this type conversion");
97
86
  std::transform(
98
87
  data.begin(), data.end(), buffer.begin(), [](const U& val) {
99
88
  return static_cast<T>(val);
@@ -118,8 +107,8 @@ struct InputBufferConverter {
118
107
  std::vector<T> buffer;
119
108
  };
120
109
 
121
- // Specialization for reinterpret cast when types are integral and have the same
122
- // size
110
+ // Specialization for reinterpret cast when types are integral and have
111
+ // the same size
123
112
  template <typename T, typename U>
124
113
  struct InputBufferConverter<
125
114
  T,
@@ -153,7 +142,8 @@ struct OutputBufferConverter {
153
142
  OutputBufferConverter(std::span<U> data = {})
154
143
  : data_span(data), buffer(data.size()) {
155
144
  FAISS_ASSERT(
156
- !"OutputBufferConverter: there is no suitable user code for this type conversion");
145
+ false &&
146
+ "OutputBufferConverter: there is no suitable user code for this type conversion");
157
147
  }
158
148
 
159
149
  ~OutputBufferConverter() {
@@ -176,8 +166,8 @@ struct OutputBufferConverter {
176
166
  std::vector<T> buffer;
177
167
  };
178
168
 
179
- // Specialization for reinterpret cast when types are integral and have the same
180
- // size
169
+ // Specialization for reinterpret cast when types are integral and have
170
+ // the same size
181
171
  template <typename T, typename U>
182
172
  struct OutputBufferConverter<
183
173
  T,
@@ -26,6 +26,8 @@
26
26
  #include <faiss/Index.h>
27
27
  #include <faiss/svs/IndexSVSFaissUtils.h>
28
28
 
29
+ #include <svs/runtime/flat_index.h>
30
+
29
31
  #include <iostream>
30
32
 
31
33
  namespace faiss {
@@ -27,6 +27,7 @@
27
27
  #include <faiss/svs/IndexSVSFaissUtils.h>
28
28
 
29
29
  #include <svs/runtime/api_defs.h>
30
+ #include <svs/runtime/dynamic_vamana_index.h>
30
31
 
31
32
  #include <iostream>
32
33
 
@@ -71,7 +72,7 @@ inline svs_runtime::StorageKind to_svs_storage_kind(SVSStorageKind kind) {
71
72
  case SVS_LeanVec8x8:
72
73
  return svs_runtime::StorageKind::LeanVec8x8;
73
74
  default:
74
- FAISS_ASSERT(!"not supported SVS storage kind");
75
+ FAISS_ASSERT(false && "not supported SVS storage kind");
75
76
  }
76
77
  }
77
78
 
@@ -66,6 +66,14 @@ void IndexSVSVamanaLeanVec::add(idx_t n, const float* x) {
66
66
  }
67
67
 
68
68
  void IndexSVSVamanaLeanVec::train(idx_t n, const float* x) {
69
+ train(n, x, 0, nullptr);
70
+ }
71
+
72
+ void IndexSVSVamanaLeanVec::train(
73
+ idx_t n,
74
+ const float* x,
75
+ idx_t n_train_q,
76
+ const float* queries) {
69
77
  FAISS_THROW_IF_MSG(
70
78
  training_data || impl, "Index already trained or contains data.");
71
79
 
@@ -74,7 +82,7 @@ void IndexSVSVamanaLeanVec::train(idx_t n, const float* x) {
74
82
  "LVQ/LeanVec support not available on this platform or build");
75
83
 
76
84
  auto status = svs_runtime::LeanVecTrainingData::build(
77
- &training_data, d, n, x, leanvec_d);
85
+ &training_data, d, n, x, n_train_q, queries, leanvec_d);
78
86
  if (!status.ok()) {
79
87
  FAISS_THROW_MSG(status.message());
80
88
  }
@@ -41,8 +41,17 @@ struct IndexSVSVamanaLeanVec : IndexSVSVamana {
41
41
 
42
42
  void add(idx_t n, const float* x) override;
43
43
 
44
+ /* Default train assumes in-distribution data */
44
45
  void train(idx_t n, const float* x) override;
45
46
 
47
+ /* Generic train with out-of-distribution parameters.
48
+ * Out-of-distribution (OOD) means database vectors and queries _can_ be
49
+ * sampled from different distributions (e.g., cross-modal). More details in
50
+ * the original publication, arXiv:2312.16335.
51
+ */
52
+ void train(idx_t n, const float* x, idx_t n_train_q, const float* xq_train)
53
+ override;
54
+
46
55
  void serialize_training_data(std::ostream& out) const;
47
56
  void deserialize_training_data(std::istream& in);
48
57
 
@@ -254,4 +254,50 @@ INSTANTIATE(CMax, float);
254
254
  INSTANTIATE(CMin, int32_t);
255
255
  INSTANTIATE(CMax, int32_t);
256
256
 
257
+ /**********************************************************
258
+ * reorder_2_heaps: for each of n queries, rebuild k output entries from that query's k_base base entries
259
+ **********************************************************/
260
+
261
+ template <class C>
262
+ void reorder_2_heaps(
263
+ int64_t n,
264
+ int64_t k,
265
+ typename C::TI* __restrict labels,
266
+ float* __restrict distances,
267
+ int64_t k_base,
268
+ const typename C::TI* __restrict base_labels,
269
+ const float* __restrict base_distances) {
270
+ #pragma omp parallel for if (n > 1)
271
+ for (int64_t i = 0; i < n; i++) {
272
+ typename C::TI* idxo = labels + i * k;
273
+ float* diso = distances + i * k;
274
+ const typename C::TI* idxi = base_labels + i * k_base;
275
+ const float* disi = base_distances + i * k_base;
276
+
277
+ heap_heapify<C>(k, diso, idxo, disi, idxi, k);
278
+ if (k_base != k) { // add the remaining k_base - k base entries (presumably k_base >= k; not checked here -- verify callers)
279
+ heap_addn<C>(k, diso, idxo, disi + k, idxi + k, k_base - k);
280
+ }
281
+ heap_reorder<C>(k, diso, idxo);
282
+ }
283
+ }
284
+
285
+ template void reorder_2_heaps<CMax<float, int64_t>>(
286
+ int64_t n,
287
+ int64_t k,
288
+ int64_t* __restrict labels,
289
+ float* __restrict distances,
290
+ int64_t k_base,
291
+ const int64_t* __restrict base_labels,
292
+ const float* __restrict base_distances);
293
+
294
+ template void reorder_2_heaps<CMin<float, int64_t>>(
295
+ int64_t n,
296
+ int64_t k,
297
+ int64_t* __restrict labels,
298
+ float* __restrict distances,
299
+ int64_t k_base,
300
+ const int64_t* __restrict base_labels,
301
+ const float* __restrict base_distances);
302
+
257
303
  } // namespace faiss