faiss 0.6.1 → 0.6.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/Index.h +1 -1
  5. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +6 -7
  6. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
  7. data/vendor/faiss/faiss/IndexHNSW.cpp +173 -143
  8. data/vendor/faiss/faiss/IndexIVF.cpp +2 -2
  9. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  10. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -1
  11. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +3 -3
  12. data/vendor/faiss/faiss/IndexIVFPQ.cpp +2 -3
  13. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -3
  14. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +4 -13
  15. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  16. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  17. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +68 -6
  18. data/vendor/faiss/faiss/IndexScalarQuantizer.h +10 -0
  19. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +1 -1
  20. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +902 -12
  21. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +702 -10
  22. data/vendor/faiss/faiss/factory_tools.cpp +4 -0
  23. data/vendor/faiss/faiss/gpu/GpuResources.h +3 -2
  24. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +11 -12
  25. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +3 -3
  26. data/vendor/faiss/faiss/gpu_metal/MetalDistance.h +87 -0
  27. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +7 -0
  28. data/vendor/faiss/faiss/gpu_metal/MetalIndexIVFFlat.h +181 -0
  29. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +48 -3
  30. data/vendor/faiss/faiss/gpu_metal/MetalPythonBridge.h +45 -0
  31. data/vendor/faiss/faiss/gpu_metal/impl/MetalIVFFlat.h +193 -0
  32. data/vendor/faiss/faiss/impl/HNSW.cpp +556 -199
  33. data/vendor/faiss/faiss/impl/HNSW.h +51 -13
  34. data/vendor/faiss/faiss/impl/NSG.cpp +15 -11
  35. data/vendor/faiss/faiss/impl/Panorama.h +11 -0
  36. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -2
  37. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +1 -1
  38. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +7 -1
  39. data/vendor/faiss/faiss/impl/ResultHandler.h +1 -0
  40. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +271 -8
  41. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +50 -0
  42. data/vendor/faiss/faiss/impl/VisitedTable.cpp +10 -10
  43. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -34
  44. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +3 -1
  45. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +35 -43
  46. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -15
  47. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +86 -40
  48. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +81 -50
  49. data/vendor/faiss/faiss/impl/index_read.cpp +100 -39
  50. data/vendor/faiss/faiss/impl/index_write.cpp +1 -0
  51. data/vendor/faiss/faiss/impl/io_macros.h +25 -0
  52. data/vendor/faiss/faiss/impl/platform_macros.h +12 -8
  53. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +2 -0
  54. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +2 -0
  55. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +2 -0
  56. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +20 -0
  57. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +36 -0
  58. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -0
  59. data/vendor/faiss/faiss/impl/pq_code_distance/pq_scan_impl.h +105 -0
  60. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +2 -0
  61. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +6 -0
  62. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +327 -18
  63. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +264 -27
  64. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-impl.h +553 -0
  65. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-spr.cpp +559 -0
  66. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +199 -27
  67. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +366 -3
  68. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +144 -19
  69. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +26 -0
  70. data/vendor/faiss/faiss/impl/simd_dispatch.h +65 -8
  71. data/vendor/faiss/faiss/index_factory.cpp +5 -1
  72. data/vendor/faiss/faiss/index_io.h +16 -0
  73. data/vendor/faiss/faiss/invlists/DirectMap.cpp +4 -1
  74. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +13 -13
  75. data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -2
  76. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +119 -22
  77. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +15 -5
  78. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +3 -2
  79. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +2 -1
  80. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +65 -24
  81. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +3 -2
  82. data/vendor/faiss/faiss/utils/bf16.h +34 -0
  83. data/vendor/faiss/faiss/utils/distances_simd.cpp +0 -1
  84. data/vendor/faiss/faiss/utils/hamming.cpp +8 -8
  85. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +2 -1
  86. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512_spr.cpp +15 -0
  87. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +6 -30
  88. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512_spr.h +171 -0
  89. data/vendor/faiss/faiss/utils/partitioning.cpp +0 -2
  90. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +14 -68
  91. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512_spr.cpp +343 -0
  92. data/vendor/faiss/faiss/utils/simd_levels.cpp +12 -2
  93. metadata +12 -2
@@ -7,6 +7,8 @@
7
7
 
8
8
  #pragma once
9
9
 
10
+ #include <cstring>
11
+
10
12
  #include <faiss/impl/AuxIndexStructures.h>
11
13
  #include <faiss/impl/DistanceComputer.h>
12
14
  #include <faiss/impl/Quantizer.h>
@@ -39,6 +41,10 @@ struct ScalarQuantizer : Quantizer {
39
41
  QT_3bit_tqmse, ///< TurboQuant MSE-optimized, 3 bits per component
40
42
  QT_4bit_tqmse, ///< TurboQuant MSE-optimized, 4 bits per component
41
43
  QT_8bit_tqmse, ///< TurboQuant MSE-optimized, 8 bits per component
44
+ QT_2bit_tq, ///< Full TurboQuant (1-bit MSE + 1-bit QJL + factors)
45
+ QT_3bit_tq, ///< Full TurboQuant (2-bit MSE + 1-bit QJL + factors)
46
+ QT_4bit_tq, ///< Full TurboQuant (3-bit MSE + 1-bit QJL + factors)
47
+ QT_5bit_tq, ///< Full TurboQuant (4-bit MSE + 1-bit QJL + factors)
42
48
  QT_count
43
49
  };
44
50
 
@@ -131,6 +137,50 @@ struct ScalarQuantizer : Quantizer {
131
137
  }
132
138
  };
133
139
 
140
+ /// TurboQuant full (QT_*_tq) refinement state, isolated from the
141
+ /// main ScalarQuantizer to avoid polluting it with TQ-specific data.
142
+ struct TurboQuantRefine {
143
+ static bool is_turboq_full(QuantizerType qt) {
144
+ return qt >= QT_2bit_tq && qt <= QT_5bit_tq;
145
+ }
146
+
147
+ static void pack_seed(uint64_t seed, float out[2]) {
148
+ static_assert(sizeof(uint64_t) == 2 * sizeof(float));
149
+ std::memcpy(out, &seed, sizeof(uint64_t));
150
+ }
151
+
152
+ static uint64_t unpack_seed(float lo, float hi) {
153
+ float tmp[2] = {lo, hi};
154
+ uint64_t s;
155
+ static_assert(sizeof(uint64_t) == 2 * sizeof(float));
156
+ std::memcpy(&s, tmp, sizeof(uint64_t));
157
+ return s;
158
+ }
159
+
160
+ uint8_t qjl_type = 0;
161
+ uint64_t seed = 42;
162
+ size_t padded_d = 0;
163
+ std::vector<float> fwht_signs;
164
+ std::vector<float> rr_matrix;
165
+ size_t nb_bits_lo = 0;
166
+ size_t n_hi_dims = 0;
167
+
168
+ void init_projection(size_t d);
169
+ bool use_fwht() const {
170
+ return qjl_type == 0;
171
+ }
172
+
173
+ struct DistanceComputer : SQDistanceComputer {
174
+ virtual void configure(uint8_t qb, bool int_qjl) = 0;
175
+ virtual void set_prescreen_threshold(
176
+ const float* t,
177
+ bool minimize) = 0;
178
+ virtual void clear_prescreen_threshold() = 0;
179
+ };
180
+ };
181
+
182
+ TurboQuantRefine turboq_refine;
183
+
134
184
  SQDistanceComputer* get_distance_computer(
135
185
  MetricType metric = METRIC_L2) const;
136
186
 
@@ -18,19 +18,19 @@ namespace faiss {
18
18
  // A size of ~1M seems to be the threshold where the hash set wins.
19
19
  size_t visited_table_hashset_threshold = 500000;
20
20
 
21
- VisitedTable::VisitedTable(size_t size, std::optional<bool> use_hashset)
22
- : visno(use_hashset.value_or(size >= visited_table_hashset_threshold)
23
- ? 0
24
- : 1) {
25
- if (visno != 0) {
26
- visited.resize(size, 0);
21
+ std::unique_ptr<VisitedTable> VisitedTable::create(
22
+ size_t size,
23
+ std::optional<bool> use_hashset) {
24
+ bool use_set =
25
+ use_hashset.value_or(size >= visited_table_hashset_threshold);
26
+ if (use_set) {
27
+ return std::make_unique<VisitedTableSet>();
27
28
  }
29
+ return std::make_unique<VisitedTableVector>(size);
28
30
  }
29
31
 
30
- void VisitedTable::advance() {
31
- if (visno == 0) {
32
- visited_set.clear();
33
- } else if (visno < 254) {
32
+ void VisitedTableVector::advance() {
33
+ if (visno < 254) {
34
34
  // 254 rather than 255 because sometimes we use visno and visno+1
35
35
  ++visno;
36
36
  } else {
@@ -10,6 +10,7 @@
10
10
 
11
11
  #include <stdint.h>
12
12
 
13
+ #include <memory>
13
14
  #include <optional>
14
15
  #include <unordered_set>
15
16
  #include <vector>
@@ -21,54 +22,88 @@ namespace faiss {
21
22
 
22
23
  FAISS_API extern size_t visited_table_hashset_threshold;
23
24
 
24
- /// A fast, reusable Visited Set for graph search algorithms.
25
+ /// Abstract base class for a fast, reusable Visited Set for graph search
26
+ /// algorithms.
25
27
  struct VisitedTable {
26
- std::vector<uint8_t> visited;
27
- std::unordered_set<size_t> visited_set;
28
- uint8_t visno; // 0 if using visited_set, 1..250 if using vector.
28
+ virtual ~VisitedTable() = default;
29
+
30
+ /// set flag #no to true, return whether this changed it.
31
+ virtual bool set(size_t no) = 0;
32
+
33
+ /// get flag #no
34
+ virtual bool get(size_t no) const = 0;
35
+
36
+ /// prefetch flag #no
37
+ virtual void prefetch(size_t no) const = 0;
38
+
39
+ /// pre-allocate bucket space to avoid rehashing during repeated set() calls
40
+ virtual void reserve(size_t /*n*/) {}
29
41
 
30
- // If use_hashset is nullopt, the use of a hashset will be determined by
31
- // size >= visited_table_hashset_threshold.
32
- explicit VisitedTable(
42
+ /// reset all flags to false
43
+ virtual void advance() = 0;
44
+
45
+ /// Factory method to create appropriate implementation.
46
+ /// If use_hashset is nullopt, the use of a hashset will be determined by
47
+ /// size >= visited_table_hashset_threshold.
48
+ static std::unique_ptr<VisitedTable> create(
33
49
  size_t size,
34
50
  std::optional<bool> use_hashset = std::nullopt);
51
+ };
35
52
 
36
- /// set flag #no to true, return whether this changed it.
37
- bool set(size_t no) {
38
- if (visno == 0) {
39
- return visited_set.insert(no).second;
40
- } else if (visited[no] == visno) {
41
- return false;
42
- } else {
43
- visited[no] = visno;
44
- return true;
45
- }
53
+ /// Set-based implementation using unordered_set.
54
+ /// O(1) to construct and O(visits) to advance.
55
+ struct VisitedTableSet FAISS_FINAL : VisitedTable {
56
+ std::unordered_set<size_t> visited_set;
57
+
58
+ VisitedTableSet() = default;
59
+
60
+ bool set(size_t no) final {
61
+ return visited_set.insert(no).second;
46
62
  }
47
63
 
48
- /// pre-allocate bucket space to avoid rehashing during repeated set() calls
49
- void reserve(size_t n) {
50
- if (visno == 0) {
51
- visited_set.reserve(n);
52
- }
64
+ bool get(size_t no) const final {
65
+ return visited_set.count(no) != 0;
53
66
  }
54
67
 
55
- /// get flag #no
56
- bool get(size_t no) const {
57
- if (visno == 0) {
58
- return visited_set.count(no) != 0;
59
- } else {
60
- return visited[no] == visno;
61
- }
68
+ void prefetch(size_t /*no*/) const final {
69
+ // No-op for set-based implementation
62
70
  }
63
71
 
64
- void prefetch(size_t no) const {
65
- if (visno != 0) {
66
- prefetch_L2(&visited[no]);
72
+ void reserve(size_t n) final {
73
+ visited_set.reserve(n);
74
+ }
75
+
76
+ void advance() final {
77
+ visited_set.clear();
78
+ }
79
+ };
80
+
81
+ /// Vector-based implementation using a versioned byte array.
82
+ /// Faster for get()/set(), but O(size) to initialize.
83
+ /// advance() is O(1) except every 250 calls, which are O(size).
84
+ struct VisitedTableVector FAISS_FINAL : VisitedTable {
85
+ std::vector<uint8_t> visited;
86
+ uint8_t visno{1}; // Version number, 1..254
87
+
88
+ explicit VisitedTableVector(size_t size) : visited(size, 0) {}
89
+
90
+ bool set(size_t no) final {
91
+ if (visited[no] == visno) {
92
+ return false;
67
93
  }
94
+ visited[no] = visno;
95
+ return true;
68
96
  }
69
97
 
70
- /// reset all flags to false
71
- void advance();
98
+ bool get(size_t no) const final {
99
+ return visited[no] == visno;
100
+ }
101
+
102
+ void prefetch(size_t no) const final {
103
+ prefetch_L2(&visited[no]);
104
+ }
105
+
106
+ void advance() final;
72
107
  };
73
108
 
74
109
  } // namespace faiss
@@ -48,7 +48,9 @@ using namespace simd_result_handlers;
48
48
  * so callers don't need to know the handler type.
49
49
  ***************************************************************/
50
50
 
51
- template <class Handler>
51
+ // SIMDLevel SL = THE_LEVEL_TO_DISPATCH added to make the mangled
52
+ // symbol name unique per translation unit.
53
+ template <class Handler, SIMDLevel SL = THE_LEVEL_TO_DISPATCH>
52
54
  struct ScannerMixIn : FastScanCodeScanner {
53
55
  Handler handler_;
54
56
 
@@ -5,39 +5,32 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- #include <cmath>
9
-
10
8
  #include <faiss/impl/hnsw/MinimaxHeap.h>
11
9
 
12
- #include <cassert>
13
-
14
10
  #include <faiss/impl/simd_dispatch.h>
15
11
 
16
12
  namespace faiss {
17
13
 
18
- void MinimaxHeap::push(storage_idx_t i, float v) {
19
- // Treat NaN distances as infinitely far away so heap ordering is preserved.
20
- if (std::isnan(v)) {
21
- v = HC::neutral();
22
- }
23
- if (k == n) {
24
- if (v >= dis[0]) {
25
- return;
26
- }
27
- if (ids[0] != -1) {
28
- --nvalid;
29
- }
30
- faiss::heap_pop<HC>(k--, dis.data(), ids.data());
31
- }
32
- faiss::heap_push<HC>(++k, dis.data(), ids.data(), v, i);
33
- ++nvalid;
14
+ // Runtime-dispatched pop_min (NONE + AVX2 + AVX512 only).
15
+ constexpr int MINIMAX_HEAP_SIMD_LEVELS = (1 << int(SIMDLevel::NONE)) |
16
+ (1 << int(SIMDLevel::AVX2)) | (1 << int(SIMDLevel::AVX512));
17
+
18
+ template <class HC_>
19
+ int MinimaxHeapT<HC_>::pop_min(float* vmin_out) {
20
+ return with_selected_simd_levels<MINIMAX_HEAP_SIMD_LEVELS>(
21
+ [&]<SIMDLevel SL>() {
22
+ return pop_min_tpl<HC_, SL>(this, vmin_out);
23
+ });
34
24
  }
35
25
 
36
- // Scalar (NONE) specialization of pop_min_tpl
37
- template <>
38
- int MinimaxHeap::pop_min_tpl<SIMDLevel::NONE>(float* vmin_out) {
26
+ // Primary-template scalar implementation. Used directly when SL==NONE
27
+ template <class HC>
28
+ int pop_min_simd_none(MinimaxHeapT<HC>* heap, float* vmin_out) {
29
+ int k = heap->k;
30
+ int* ids = heap->ids.data();
31
+ float* dis = heap->dis.data();
39
32
  assert(k > 0);
40
- // returns min. This is an O(n) operation
33
+ // Returns the "best" entry. This is an O(n) operation.
41
34
  int i = k - 1;
42
35
  while (i >= 0) {
43
36
  if (ids[i] != -1) {
@@ -52,7 +45,8 @@ int MinimaxHeap::pop_min_tpl<SIMDLevel::NONE>(float* vmin_out) {
52
45
  float vmin = dis[i];
53
46
  i--;
54
47
  while (i >= 0) {
55
- if (ids[i] != -1 && dis[i] < vmin) {
48
+ // HC::cmp(vmin, dis[i]) "dis[i] is better than vmin".
49
+ if (ids[i] != -1 && HC::cmp(vmin, dis[i])) {
56
50
  vmin = dis[i];
57
51
  imin = i;
58
52
  }
@@ -63,29 +57,27 @@ int MinimaxHeap::pop_min_tpl<SIMDLevel::NONE>(float* vmin_out) {
63
57
  }
64
58
  int ret = ids[imin];
65
59
  ids[imin] = -1;
66
- --nvalid;
67
-
60
+ --heap->nvalid;
68
61
  return ret;
69
62
  }
70
63
 
71
- // Runtime-dispatched pop_min (NONE + AVX2 + AVX512 only)
72
- constexpr int MINIMAX_HEAP_SIMD_LEVELS = (1 << int(SIMDLevel::NONE)) |
73
- (1 << int(SIMDLevel::AVX2)) | (1 << int(SIMDLevel::AVX512));
74
-
75
- int MinimaxHeap::pop_min(float* vmin_out) {
76
- return with_selected_simd_levels<MINIMAX_HEAP_SIMD_LEVELS>(
77
- [&]<SIMDLevel SL>() { return pop_min_tpl<SL>(vmin_out); });
64
+ // declare for min and max heap at simd level NONE
65
+ template <>
66
+ int pop_min_tpl<CMin<float, int32_t>, SIMDLevel::NONE>(
67
+ MinimaxHeapT<CMin<float, int32_t>>* heap,
68
+ float* vmin_out) {
69
+ return pop_min_simd_none(heap, vmin_out);
78
70
  }
79
71
 
80
- int MinimaxHeap::count_below(float thresh) {
81
- int n_below = 0;
82
- for (int i = 0; i < k; i++) {
83
- if (dis[i] < thresh) {
84
- n_below++;
85
- }
86
- }
87
-
88
- return n_below;
72
+ template <>
73
+ int pop_min_tpl<CMax<float, int32_t>, SIMDLevel::NONE>(
74
+ MinimaxHeapT<CMax<float, int32_t>>* heap,
75
+ float* vmin_out) {
76
+ return pop_min_simd_none(heap, vmin_out);
89
77
  }
90
78
 
79
+ // Explicit instantiations of pop_min for the two HC variants
80
+ template int MinimaxHeapT<CMax<float, int32_t>>::pop_min(float*);
81
+ template int MinimaxHeapT<CMin<float, int32_t>>::pop_min(float*);
82
+
91
83
  } // namespace faiss
@@ -7,21 +7,30 @@
7
7
 
8
8
  #pragma once
9
9
 
10
+ #include <cassert>
11
+ #include <cmath>
10
12
  #include <cstdint>
11
13
  #include <vector>
12
14
 
13
15
  #include <faiss/utils/Heap.h>
16
+ #include <faiss/utils/ordered_key_value.h>
14
17
  #include <faiss/utils/simd_levels.h>
15
18
 
16
19
  namespace faiss {
17
20
 
18
21
  /** Heap structure that allows fast access and updates.
19
22
  *
20
- * Supports both max-heap operations (via the underlying CMax heap)
21
- * and efficient min extraction via linear scan (with optional SIMD
22
- * acceleration).
23
+ * Templated on the comparator HC_ so that the same data structure can
24
+ * service both distance-style searches (HC_ = CMax<float, int32_t>, smaller
25
+ * is better) and similarity-style searches (HC_ = CMin<float, int32_t>,
26
+ * larger is better). For the distance variant the underlying heap is a
27
+ * max-heap and "pop_min" returns the closest element; for similarity the
28
+ * underlying heap is a min-heap and "pop_min" returns the most similar
29
+ * element.
23
30
  */
24
- struct MinimaxHeap {
31
+ template <class HC_ = CMax<float, int32_t>>
32
+ struct MinimaxHeapT {
33
+ using HC = HC_;
25
34
  using storage_idx_t = int32_t;
26
35
 
27
36
  int n;
@@ -30,12 +39,34 @@ struct MinimaxHeap {
30
39
 
31
40
  std::vector<storage_idx_t> ids;
32
41
  std::vector<float> dis;
33
- using HC = faiss::CMax<float, storage_idx_t>;
34
42
 
35
- explicit MinimaxHeap(int n_in)
43
+ explicit MinimaxHeapT(int n_in)
36
44
  : n(n_in), k(0), nvalid(0), ids(n_in), dis(n_in) {}
37
45
 
38
- void push(storage_idx_t i, float v);
46
+ void push(storage_idx_t i, float v) {
47
+ // Treat NaN distances as the "worst" value so heap ordering is
48
+ // preserved (insertion is then guaranteed to fall through the
49
+ // not-better-than-top early-reject branch when the heap is full).
50
+ if (std::isnan(v)) {
51
+ v = HC::neutral();
52
+ }
53
+ if (k == n) {
54
+ // top of the heap is the "worst" entry under HC. If the new
55
+ // value is not strictly better than the worst, drop it.
56
+ // HC::cmp(top, v) means "v is better than top" for both CMax
57
+ // (cmp = a > b → top > v → v < top) and CMin (cmp = a < b →
58
+ // top < v → v > top).
59
+ if (!HC::cmp(dis[0], v)) {
60
+ return;
61
+ }
62
+ if (ids[0] != -1) {
63
+ --nvalid;
64
+ }
65
+ faiss::heap_pop<HC>(k--, dis.data(), ids.data());
66
+ }
67
+ faiss::heap_push<HC>(++k, dis.data(), ids.data(), v, i);
68
+ ++nvalid;
69
+ }
39
70
 
40
71
  float max() const {
41
72
  return dis[0];
@@ -49,16 +80,34 @@ struct MinimaxHeap {
49
80
  nvalid = k = 0;
50
81
  }
51
82
 
52
- /// SIMD-templated pop_min implementation.
53
- /// Specializations exist for NONE, AVX2, and AVX512.
54
- template <SIMDLevel SL>
55
- int pop_min_tpl(float* vmin_out = nullptr);
56
-
57
- /// Runtime-dispatched pop_min (calls pop_min_tpl with best available
58
- /// SIMD level).
83
+ /// Runtime-dispatched best-element extraction (NONE + AVX2 + AVX512).
59
84
  int pop_min(float* vmin_out = nullptr);
60
85
 
61
- int count_below(float thresh);
86
+ int count_below(float thresh) {
87
+ int n_below = 0;
88
+ for (int i = 0; i < k; i++) {
89
+ // Count entries that are strictly "better than" thresh.
90
+ // HC::cmp(thresh, dis[i]) → for CMax: thresh > dis[i]
91
+ // (i.e., dis[i] < thresh, the historical L2 semantics);
92
+ // for CMin: thresh < dis[i] (similarity above threshold).
93
+ if (HC::cmp(thresh, dis[i])) {
94
+ n_below++;
95
+ }
96
+ }
97
+ return n_below;
98
+ }
62
99
  };
63
100
 
101
+ // Default `MinimaxHeap` keeps the historical max-heap semantics (smaller
102
+ // distance is better). The CMin instantiation is used when the owning
103
+ // HNSW has `is_similarity = true`. The alias itself is declared once,
104
+ // alongside the forward declaration in HNSW.h, to avoid duplicate
105
+ // `using` declarations that SWIG treats as redundant.
106
+
107
+ // Forward declarations of the SIMD specializations. The actual bodies live
108
+ // in the SIMD-specific translation units (avx2.cpp, avx512.cpp) and are
109
+ // resolved at link time.
110
+ template <class HC_, SIMDLevel SL>
111
+ int pop_min_tpl(MinimaxHeapT<HC_>* heap, float* vmin_out);
112
+
64
113
  } // namespace faiss
@@ -16,89 +16,135 @@
16
16
 
17
17
  namespace faiss {
18
18
 
19
- template <>
20
- int MinimaxHeap::pop_min_tpl<SIMDLevel::AVX2>(float* vmin_out) {
21
- assert(k > 0);
19
+ namespace {
20
+
21
+ /// Templated AVX2 implementation of "pop best" for both CMax (returns
22
+ /// the smallest distance) and CMin (returns the largest similarity).
23
+ /// The only differences between the two flavors are: (1) the initial
24
+ /// "worst possible" value, (2) the running-best update comparison
25
+ /// (`_CMP_LT_OS` vs `_CMP_GT_OS`), and (3) the tiebreaker direction.
26
+ template <class HC>
27
+ int pop_best_avx2(MinimaxHeapT<HC>& heap, float* vmin_out) {
28
+ using storage_idx_t = typename MinimaxHeapT<HC>::storage_idx_t;
22
29
  static_assert(
23
30
  std::is_same<storage_idx_t, int32_t>::value,
24
31
  "This code expects storage_idx_t to be int32_t");
32
+ assert(heap.k > 0);
33
+
34
+ // For CMax (distance) the "best" candidate is the smallest value, so
35
+ // we initialize the running best to +inf. For CMin (similarity) the
36
+ // best is the largest value, so we initialize to -inf.
37
+ constexpr float worst_v = HC::is_max
38
+ ? std::numeric_limits<float>::infinity()
39
+ : -std::numeric_limits<float>::infinity();
25
40
 
26
- int32_t min_idx = -1;
27
- float min_dis = std::numeric_limits<float>::infinity();
41
+ int32_t best_idx = -1;
42
+ float best_dis = worst_v;
28
43
 
29
44
  size_t iii = 0;
30
45
 
31
- __m256i min_indices = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1);
32
- __m256 min_distances =
33
- _mm256_set1_ps(std::numeric_limits<float>::infinity());
46
+ __m256i best_indices = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1);
47
+ __m256 best_distances = _mm256_set1_ps(worst_v);
34
48
  __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
35
49
  __m256i offset = _mm256_set1_epi32(8);
36
50
 
37
- // The baseline version is available in the NONE specialization.
38
-
39
- // The following loop tracks the rightmost index with the min distance.
40
- // -1 index values are ignored.
41
- const size_t k8 = (k / 8) * 8;
51
+ // Track the rightmost index whose distance equals the running best.
52
+ // -1 index values are filtered out via m1mask.
53
+ const size_t k8 = (heap.k / 8) * 8;
42
54
  for (; iii < k8; iii += 8) {
43
55
  __m256i indices =
44
- _mm256_loadu_si256((const __m256i*)(ids.data() + iii));
45
- __m256 distances = _mm256_loadu_ps(dis.data() + iii);
56
+ _mm256_loadu_si256((const __m256i*)(heap.ids.data() + iii));
57
+ __m256 distances = _mm256_loadu_ps(heap.dis.data() + iii);
46
58
 
47
- // This mask filters out -1 values among indices.
59
+ // Mask out -1 indices (invalid entries).
48
60
  __m256i m1mask = _mm256_cmpgt_epi32(_mm256_setzero_si256(), indices);
49
61
 
50
- __m256i dmask = _mm256_castps_si256(
51
- _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS));
62
+ // dmask is "true where best is already (strictly) better than the
63
+ // candidate" — entries the candidate should NOT update. For CMax,
64
+ // best < candidate means we keep best (we want the smallest);
65
+ // for CMin we keep best when best > candidate (we want the largest).
66
+ __m256i dmask;
67
+ if constexpr (HC::is_max) {
68
+ dmask = _mm256_castps_si256(
69
+ _mm256_cmp_ps(best_distances, distances, _CMP_LT_OS));
70
+ } else {
71
+ dmask = _mm256_castps_si256(
72
+ _mm256_cmp_ps(best_distances, distances, _CMP_GT_OS));
73
+ }
52
74
  __m256 finalmask = _mm256_castsi256_ps(_mm256_or_si256(m1mask, dmask));
53
75
 
54
- const __m256i min_indices_new = _mm256_castps_si256(_mm256_blendv_ps(
76
+ const __m256i best_indices_new = _mm256_castps_si256(_mm256_blendv_ps(
55
77
  _mm256_castsi256_ps(current_indices),
56
- _mm256_castsi256_ps(min_indices),
78
+ _mm256_castsi256_ps(best_indices),
57
79
  finalmask));
58
80
 
59
- const __m256 min_distances_new =
60
- _mm256_blendv_ps(distances, min_distances, finalmask);
81
+ const __m256 best_distances_new =
82
+ _mm256_blendv_ps(distances, best_distances, finalmask);
61
83
 
62
- min_indices = min_indices_new;
63
- min_distances = min_distances_new;
84
+ best_indices = best_indices_new;
85
+ best_distances = best_distances_new;
64
86
 
65
87
  current_indices = _mm256_add_epi32(current_indices, offset);
66
88
  }
67
89
 
68
- // Vectorizing is doable, but is not practical
90
+ // Vectorizing the horizontal reduction is doable but not practical.
69
91
  int32_t vidx8[8];
70
92
  float vdis8[8];
71
- _mm256_storeu_ps(vdis8, min_distances);
72
- _mm256_storeu_si256((__m256i*)vidx8, min_indices);
93
+ _mm256_storeu_ps(vdis8, best_distances);
94
+ _mm256_storeu_si256((__m256i*)vidx8, best_indices);
73
95
 
74
96
  for (size_t j = 0; j < 8; j++) {
75
- if (min_dis > vdis8[j] || (min_dis == vdis8[j] && min_idx < vidx8[j])) {
76
- min_idx = vidx8[j];
77
- min_dis = vdis8[j];
97
+ const bool strictly_better =
98
+ HC::is_max ? (best_dis > vdis8[j]) : (best_dis < vdis8[j]);
99
+ if (strictly_better || (best_dis == vdis8[j] && best_idx < vidx8[j])) {
100
+ best_idx = vidx8[j];
101
+ best_dis = vdis8[j];
78
102
  }
79
103
  }
80
104
 
81
- // process last values. Vectorizing is doable, but is not practical
82
- for (; iii < static_cast<size_t>(k); iii++) {
83
- if (ids[iii] != -1 && dis[iii] <= min_dis) {
84
- min_dis = dis[iii];
85
- min_idx = iii;
105
+ // Tail (under 8 entries). Vectorizing is doable but not practical.
106
+ for (; iii < static_cast<size_t>(heap.k); iii++) {
107
+ if (heap.ids[iii] == -1) {
108
+ continue;
109
+ }
110
+ const bool weakly_better = HC::is_max ? (best_dis >= heap.dis[iii])
111
+ : (best_dis <= heap.dis[iii]);
112
+ if (weakly_better) {
113
+ best_dis = heap.dis[iii];
114
+ best_idx = iii;
86
115
  }
87
116
  }
88
117
 
89
- if (min_idx == -1) {
118
+ if (best_idx == -1) {
90
119
  return -1;
91
120
  }
92
121
 
93
122
  if (vmin_out) {
94
- *vmin_out = min_dis;
123
+ *vmin_out = best_dis;
95
124
  }
96
- int ret = ids[min_idx];
97
- ids[min_idx] = -1;
98
- --nvalid;
125
+ int ret = heap.ids[best_idx];
126
+ heap.ids[best_idx] = -1;
127
+ --heap.nvalid;
99
128
  return ret;
100
129
  }
101
130
 
131
+ } // namespace
132
+
133
+ // Explicit specializations for AVX2
134
+ template <>
135
+ int pop_min_tpl<CMax<float, int32_t>, SIMDLevel::AVX2>(
136
+ MinimaxHeapT<CMax<float, int32_t>>* heap,
137
+ float* vmin_out) {
138
+ return pop_best_avx2<CMax<float, int32_t>>(*heap, vmin_out);
139
+ }
140
+
141
+ template <>
142
+ int pop_min_tpl<CMin<float, int32_t>, SIMDLevel::AVX2>(
143
+ MinimaxHeapT<CMin<float, int32_t>>* heap,
144
+ float* vmin_out) {
145
+ return pop_best_avx2<CMin<float, int32_t>>(*heap, vmin_out);
146
+ }
147
+
102
148
  } // namespace faiss
103
149
 
104
150
  #endif // COMPILE_SIMD_AVX2