faiss 0.5.2 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +5 -6
  6. data/ext/faiss/index_binary.cpp +76 -17
  7. data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
  8. data/ext/faiss/kmeans.cpp +12 -9
  9. data/ext/faiss/numo.hpp +11 -9
  10. data/ext/faiss/pca_matrix.cpp +10 -8
  11. data/ext/faiss/product_quantizer.cpp +14 -12
  12. data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
  13. data/ext/faiss/{utils.h → utils_rb.h} +6 -0
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +130 -11
  17. data/vendor/faiss/faiss/AutoTune.h +14 -1
  18. data/vendor/faiss/faiss/Clustering.cpp +59 -10
  19. data/vendor/faiss/faiss/Clustering.h +12 -0
  20. data/vendor/faiss/faiss/IVFlib.cpp +31 -28
  21. data/vendor/faiss/faiss/Index.cpp +20 -8
  22. data/vendor/faiss/faiss/Index.h +25 -3
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
  24. data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
  28. data/vendor/faiss/faiss/IndexFastScan.h +10 -1
  29. data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
  30. data/vendor/faiss/faiss/IndexFlat.h +16 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
  34. data/vendor/faiss/faiss/IndexHNSW.h +14 -12
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
  37. data/vendor/faiss/faiss/IndexIVF.h +14 -4
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
  40. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
  41. data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
  42. data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
  43. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
  44. data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
  46. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
  51. data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
  52. data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
  54. data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
  55. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  56. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
  58. data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
  59. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
  60. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
  62. data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
  63. data/vendor/faiss/faiss/IndexShards.cpp +3 -4
  64. data/vendor/faiss/faiss/MetricType.h +16 -0
  65. data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
  66. data/vendor/faiss/faiss/VectorTransform.h +23 -0
  67. data/vendor/faiss/faiss/clone_index.cpp +7 -4
  68. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  70. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
  71. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
  72. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  73. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  74. data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
  75. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  76. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  77. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  78. data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
  79. data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
  80. data/vendor/faiss/faiss/impl/HNSW.h +8 -6
  81. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
  82. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  83. data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
  84. data/vendor/faiss/faiss/impl/NSG.h +17 -7
  85. data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
  86. data/vendor/faiss/faiss/impl/Panorama.h +22 -6
  87. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
  88. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
  89. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
  90. data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
  91. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
  92. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  93. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  94. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  95. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
  96. data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
  97. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
  98. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
  99. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  100. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
  101. data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
  102. data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
  103. data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
  104. data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
  105. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
  106. data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
  107. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
  108. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
  109. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
  110. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
  111. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
  112. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
  113. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
  114. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
  115. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
  116. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
  117. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  118. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
  119. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
  120. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
  121. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  122. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
  123. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
  124. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
  125. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
  126. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
  127. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
  128. data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
  129. data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
  130. data/vendor/faiss/faiss/index_factory.cpp +35 -16
  131. data/vendor/faiss/faiss/index_io.h +29 -3
  132. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
  133. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  134. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  135. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  136. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
  137. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
  138. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
  139. data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
  140. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  141. data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
  142. data/vendor/faiss/faiss/utils/distances.cpp +141 -23
  143. data/vendor/faiss/faiss/utils/distances.h +98 -0
  144. data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
  145. data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
  146. data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
  147. data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
  148. data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
  149. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
  150. data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
  151. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  152. data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
  153. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
  154. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
  155. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
  156. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
  157. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
  158. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
  159. data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
  160. data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
  161. data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
  162. data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
  163. data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
  164. data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
  165. data/vendor/faiss/faiss/utils/utils.cpp +16 -9
  166. metadata +47 -18
  167. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  168. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  169. data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -12,217 +12,224 @@
12
12
 
13
13
  #include <faiss/MetricType.h>
14
14
  #include <faiss/impl/FaissAssert.h>
15
+ #include <faiss/impl/simd_dispatch.h>
15
16
  #include <faiss/utils/distances.h>
16
17
  #include <cmath>
17
18
  #include <type_traits>
18
19
 
19
20
  namespace faiss {
20
21
 
22
+ /***************************************************************************
23
+ * VectorDistance base class - contains common data members and type defs
24
+ **************************************************************************/
25
+
21
26
  template <MetricType mt>
22
- struct VectorDistance {
27
+ struct VectorDistanceBase {
23
28
  size_t d;
24
29
  float metric_arg;
30
+ static constexpr MetricType metric = mt;
25
31
  static constexpr bool is_similarity = is_similarity_metric(mt);
26
32
 
27
- inline float operator()(const float* x, const float* y) const;
28
-
29
- // heap template to use for this type of metric
30
33
  using C = typename std::conditional<
31
34
  is_similarity_metric(mt),
32
35
  CMin<float, int64_t>,
33
36
  CMax<float, int64_t>>::type;
34
37
  };
35
38
 
36
- template <>
37
- inline float VectorDistance<METRIC_L2>::operator()(
38
- const float* x,
39
- const float* y) const {
40
- return fvec_L2sqr(x, y, d);
41
- }
39
+ /***************************************************************************
40
+ * VectorDistance struct template - specializations for each metric type
41
+ **************************************************************************/
42
42
 
43
- template <>
44
- inline float VectorDistance<METRIC_INNER_PRODUCT>::operator()(
45
- const float* x,
46
- const float* y) const {
47
- return fvec_inner_product(x, y, d);
48
- }
43
+ template <MetricType mt, SIMDLevel level>
44
+ struct VectorDistance : VectorDistanceBase<mt> {
45
+ inline float operator()(const float* x, const float* y) const;
46
+ };
49
47
 
50
- template <>
51
- inline float VectorDistance<METRIC_L1>::operator()(
52
- const float* x,
53
- const float* y) const {
54
- return fvec_L1(x, y, d);
55
- }
48
+ template <SIMDLevel level>
49
+ struct VectorDistance<METRIC_L2, level> : VectorDistanceBase<METRIC_L2> {
50
+ inline float operator()(const float* x, const float* y) const {
51
+ return fvec_L2sqr<level>(x, y, this->d);
52
+ }
53
+ };
56
54
 
57
- template <>
58
- inline float VectorDistance<METRIC_Linf>::operator()(
59
- const float* x,
60
- const float* y) const {
61
- return fvec_Linf(x, y, d);
62
- /*
63
- float vmax = 0;
64
- for (size_t i = 0; i < d; i++) {
65
- float diff = fabs (x[i] - y[i]);
66
- if (diff > vmax) vmax = diff;
67
- }
68
- return vmax;*/
69
- }
55
+ template <SIMDLevel level>
56
+ struct VectorDistance<METRIC_INNER_PRODUCT, level>
57
+ : VectorDistanceBase<METRIC_INNER_PRODUCT> {
58
+ inline float operator()(const float* x, const float* y) const {
59
+ return fvec_inner_product<level>(x, y, this->d);
60
+ }
61
+ };
70
62
 
71
- template <>
72
- inline float VectorDistance<METRIC_Lp>::operator()(
73
- const float* x,
74
- const float* y) const {
75
- float accu = 0;
76
- for (size_t i = 0; i < d; i++) {
77
- float diff = fabs(x[i] - y[i]);
78
- accu += powf(diff, metric_arg);
63
+ template <SIMDLevel level>
64
+ struct VectorDistance<METRIC_L1, level> : VectorDistanceBase<METRIC_L1> {
65
+ inline float operator()(const float* x, const float* y) const {
66
+ return fvec_L1<level>(x, y, this->d);
79
67
  }
80
- return accu;
81
- }
68
+ };
82
69
 
83
- template <>
84
- inline float VectorDistance<METRIC_Canberra>::operator()(
85
- const float* x,
86
- const float* y) const {
87
- float accu = 0;
88
- for (size_t i = 0; i < d; i++) {
89
- float xi = x[i], yi = y[i];
90
- accu += fabs(xi - yi) / (fabs(xi) + fabs(yi));
70
+ template <SIMDLevel level>
71
+ struct VectorDistance<METRIC_Linf, level> : VectorDistanceBase<METRIC_Linf> {
72
+ inline float operator()(const float* x, const float* y) const {
73
+ return fvec_Linf<level>(x, y, this->d);
91
74
  }
92
- return accu;
93
- }
75
+ };
94
76
 
95
77
  template <>
96
- inline float VectorDistance<METRIC_BrayCurtis>::operator()(
97
- const float* x,
98
- const float* y) const {
99
- float accu_num = 0, accu_den = 0;
100
- for (size_t i = 0; i < d; i++) {
101
- float xi = x[i], yi = y[i];
102
- accu_num += fabs(xi - yi);
103
- accu_den += fabs(xi + yi);
78
+ struct VectorDistance<METRIC_Lp, SIMDLevel::NONE>
79
+ : VectorDistanceBase<METRIC_Lp> {
80
+ inline float operator()(const float* x, const float* y) const {
81
+ float accu = 0;
82
+ for (size_t i = 0; i < this->d; i++) {
83
+ float diff = fabs(x[i] - y[i]);
84
+ accu += powf(diff, this->metric_arg);
85
+ }
86
+ return accu;
104
87
  }
105
- return accu_num / accu_den;
106
- }
88
+ };
107
89
 
108
90
  template <>
109
- inline float VectorDistance<METRIC_JensenShannon>::operator()(
110
- const float* x,
111
- const float* y) const {
112
- float accu = 0;
113
- for (size_t i = 0; i < d; i++) {
114
- float xi = x[i], yi = y[i];
115
- float mi = 0.5 * (xi + yi);
116
- float kl1 = -xi * log(mi / xi);
117
- float kl2 = -yi * log(mi / yi);
118
- accu += kl1 + kl2;
91
+ struct VectorDistance<METRIC_Canberra, SIMDLevel::NONE>
92
+ : VectorDistanceBase<METRIC_Canberra> {
93
+ inline float operator()(const float* x, const float* y) const {
94
+ float accu = 0;
95
+ for (size_t i = 0; i < this->d; i++) {
96
+ float xi = x[i], yi = y[i];
97
+ accu += fabs(xi - yi) / (fabs(xi) + fabs(yi));
98
+ }
99
+ return accu;
119
100
  }
120
- return 0.5 * accu;
121
- }
101
+ };
122
102
 
123
103
  template <>
124
- inline float VectorDistance<METRIC_Jaccard>::operator()(
125
- const float* x,
126
- const float* y) const {
127
- // WARNING: this distance is defined only for positive input vectors.
128
- // Providing vectors with negative values would lead to incorrect results.
129
- float accu_num = 0, accu_den = 0;
130
- for (size_t i = 0; i < d; i++) {
131
- accu_num += fmin(x[i], y[i]);
132
- accu_den += fmax(x[i], y[i]);
104
+ struct VectorDistance<METRIC_BrayCurtis, SIMDLevel::NONE>
105
+ : VectorDistanceBase<METRIC_BrayCurtis> {
106
+ inline float operator()(const float* x, const float* y) const {
107
+ float accu_num = 0, accu_den = 0;
108
+ for (size_t i = 0; i < this->d; i++) {
109
+ float xi = x[i], yi = y[i];
110
+ accu_num += fabs(xi - yi);
111
+ accu_den += fabs(xi + yi);
112
+ }
113
+ return accu_num / accu_den;
133
114
  }
134
- return accu_num / accu_den;
135
- }
115
+ };
136
116
 
137
117
  template <>
138
- inline float VectorDistance<METRIC_NaNEuclidean>::operator()(
139
- const float* x,
140
- const float* y) const {
141
- // https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise.nan_euclidean_distances.html
142
- float accu = 0;
143
- size_t present = 0;
144
- for (size_t i = 0; i < d; i++) {
145
- if (!std::isnan(x[i]) && !std::isnan(y[i])) {
146
- float diff = x[i] - y[i];
147
- accu += diff * diff;
148
- present++;
118
+ struct VectorDistance<METRIC_JensenShannon, SIMDLevel::NONE>
119
+ : VectorDistanceBase<METRIC_JensenShannon> {
120
+ inline float operator()(const float* x, const float* y) const {
121
+ float accu = 0;
122
+ for (size_t i = 0; i < this->d; i++) {
123
+ float xi = x[i], yi = y[i];
124
+ float mi = 0.5 * (xi + yi);
125
+ float kl1 = -xi * log(mi / xi);
126
+ float kl2 = -yi * log(mi / yi);
127
+ accu += kl1 + kl2;
149
128
  }
129
+ return 0.5 * accu;
150
130
  }
151
- if (present == 0) {
152
- return NAN;
131
+ };
132
+
133
+ template <>
134
+ struct VectorDistance<METRIC_Jaccard, SIMDLevel::NONE>
135
+ : VectorDistanceBase<METRIC_Jaccard> {
136
+ inline float operator()(const float* x, const float* y) const {
137
+ // WARNING: this distance is defined only for positive input vectors.
138
+ // Providing vectors with negative values would lead to incorrect
139
+ // results.
140
+ float accu_num = 0, accu_den = 0;
141
+ for (size_t i = 0; i < this->d; i++) {
142
+ accu_num += fmin(x[i], y[i]);
143
+ accu_den += fmax(x[i], y[i]);
144
+ }
145
+ return accu_num / accu_den;
153
146
  }
154
- return float(d) / float(present) * accu;
155
- }
147
+ };
156
148
 
157
149
  template <>
158
- inline float VectorDistance<METRIC_GOWER>::operator()(
159
- const float* x,
160
- const float* y) const {
161
- float accu = 0;
162
- size_t valid_dims = 0;
163
-
164
- for (size_t i = 0; i < d; i++) {
165
- if (std::isnan(x[i]) || std::isnan(y[i])) {
166
- continue;
150
+ struct VectorDistance<METRIC_NaNEuclidean, SIMDLevel::NONE>
151
+ : VectorDistanceBase<METRIC_NaNEuclidean> {
152
+ inline float operator()(const float* x, const float* y) const {
153
+ // https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise.nan_euclidean_distances.html
154
+ float accu = 0;
155
+ size_t present = 0;
156
+ for (size_t i = 0; i < this->d; i++) {
157
+ if (!std::isnan(x[i]) && !std::isnan(y[i])) {
158
+ float diff = x[i] - y[i];
159
+ accu += diff * diff;
160
+ present++;
161
+ }
162
+ }
163
+ if (present == 0) {
164
+ return NAN;
167
165
  }
166
+ return float(this->d) / float(present) * accu;
167
+ }
168
+ };
168
169
 
169
- if (x[i] >= 0 && y[i] >= 0) {
170
- if (x[i] > 1 || y[i] > 1) {
170
+ template <>
171
+ struct VectorDistance<METRIC_GOWER, SIMDLevel::NONE>
172
+ : VectorDistanceBase<METRIC_GOWER> {
173
+ inline float operator()(const float* x, const float* y) const {
174
+ float accu = 0;
175
+ size_t valid_dims = 0;
176
+
177
+ for (size_t i = 0; i < this->d; i++) {
178
+ if (std::isnan(x[i]) || std::isnan(y[i])) {
179
+ continue;
180
+ }
181
+
182
+ if (x[i] >= 0 && y[i] >= 0) {
183
+ if (x[i] > 1 || y[i] > 1) {
184
+ return std::numeric_limits<float>::quiet_NaN();
185
+ }
186
+ accu += fabs(x[i] - y[i]);
187
+ } else if (x[i] < 0 && y[i] < 0) {
188
+ accu += float(int(x[i] != y[i]));
189
+ } else {
171
190
  return std::numeric_limits<float>::quiet_NaN();
172
191
  }
173
- // Numeric dimensions are in [0,1]
174
- accu += fabs(x[i] - y[i]);
175
- } else if (x[i] < 0 && y[i] < 0) {
176
- // Categorical dimensions are negative values
177
- accu += float(int(x[i] != y[i]));
178
- } else {
179
- // Invalid representation
180
- return std::numeric_limits<float>::quiet_NaN();
192
+ valid_dims++;
181
193
  }
182
- valid_dims++;
183
- }
184
194
 
185
- if (valid_dims == 0) {
186
- return std::numeric_limits<float>::quiet_NaN();
195
+ if (valid_dims == 0) {
196
+ return std::numeric_limits<float>::quiet_NaN();
197
+ }
198
+ return accu / valid_dims;
187
199
  }
188
- return accu / valid_dims;
189
- }
200
+ };
190
201
 
191
202
  /***************************************************************************
192
- * Dispatching function that takes a metric type and a consumer object
193
- * the consumer object should contain a return type T and a operation template
194
- * function f() that is called to perform the operation. The first argument
195
- * of the function is the VectorDistance object. The rest are passed in as is.
203
+ * Dispatching function that takes a lambda directly.
204
+ * The lambda should be templated on VectorDistance, eg.:
205
+ *
206
+ * auto result = with_VectorDistance(
207
+ * d, metric, metric_arg, [&]<class VD>(VD vd) {
208
+ * return vd(x, y);
209
+ * });
196
210
  **************************************************************************/
197
211
 
198
- template <class Consumer, class... Types>
199
- typename Consumer::T dispatch_VectorDistance(
212
+ template <typename LambdaType>
213
+ auto with_VectorDistance(
200
214
  size_t d,
201
215
  MetricType metric,
202
216
  float metric_arg,
203
- Consumer& consumer,
204
- Types... args) {
205
- switch (metric) {
206
- #define DISPATCH_VD(mt) \
207
- case mt: { \
208
- VectorDistance<mt> vd = {d, metric_arg}; \
209
- return consumer.template f<VectorDistance<mt>>(vd, args...); \
210
- }
211
- DISPATCH_VD(METRIC_INNER_PRODUCT);
212
- DISPATCH_VD(METRIC_L2);
213
- DISPATCH_VD(METRIC_L1);
214
- DISPATCH_VD(METRIC_Linf);
215
- DISPATCH_VD(METRIC_Lp);
216
- DISPATCH_VD(METRIC_Canberra);
217
- DISPATCH_VD(METRIC_BrayCurtis);
218
- DISPATCH_VD(METRIC_JensenShannon);
219
- DISPATCH_VD(METRIC_Jaccard);
220
- DISPATCH_VD(METRIC_NaNEuclidean);
221
- DISPATCH_VD(METRIC_GOWER);
222
- default:
223
- FAISS_THROW_FMT("Invalid metric %d", metric);
224
- }
225
- #undef DISPATCH_VD
217
+ LambdaType&& action) {
218
+ auto dispatch_metric = [&]<MetricType mt>() {
219
+ auto call = [&]<SIMDLevel level>() {
220
+ VectorDistance<mt, level> vd = {d, metric_arg};
221
+ return action(vd);
222
+ };
223
+
224
+ constexpr bool has_simd = mt == METRIC_INNER_PRODUCT ||
225
+ mt == METRIC_L2 || mt == METRIC_L1 || mt == METRIC_Linf;
226
+ if constexpr (!has_simd) {
227
+ return call.template operator()<SIMDLevel::NONE>();
228
+ } else {
229
+ DISPATCH_SIMDLevel(call.template operator());
230
+ }
231
+ };
232
+ return with_metric_type(metric, dispatch_metric);
226
233
  }
227
234
 
228
235
  } // namespace faiss
@@ -11,10 +11,10 @@
11
11
 
12
12
  #include <omp.h>
13
13
  #include <algorithm>
14
- #include <cmath>
15
14
 
16
15
  #include <faiss/impl/AuxIndexStructures.h>
17
16
  #include <faiss/impl/DistanceComputer.h>
17
+ #include <faiss/impl/IDSelector.h>
18
18
  #include <faiss/utils/utils.h>
19
19
 
20
20
  namespace faiss {
@@ -25,78 +25,6 @@ namespace faiss {
25
25
 
26
26
  namespace {
27
27
 
28
- struct Run_pairwise_extra_distances {
29
- using T = void;
30
-
31
- template <class VD>
32
- void f(VD vd,
33
- int64_t nq,
34
- const float* xq,
35
- int64_t nb,
36
- const float* xb,
37
- float* dis,
38
- int64_t ldq,
39
- int64_t ldb,
40
- int64_t ldd) {
41
- #pragma omp parallel for if (nq > 10)
42
- for (int64_t i = 0; i < nq; i++) {
43
- const float* xqi = xq + i * ldq;
44
- const float* xbj = xb;
45
- float* disi = dis + ldd * i;
46
-
47
- for (int64_t j = 0; j < nb; j++) {
48
- disi[j] = vd(xqi, xbj);
49
- xbj += ldb;
50
- }
51
- }
52
- }
53
- };
54
-
55
- struct Run_knn_extra_metrics {
56
- using T = void;
57
- template <class VD>
58
- void f(VD vd,
59
- const float* x,
60
- const float* y,
61
- size_t nx,
62
- size_t ny,
63
- size_t k,
64
- float* distances,
65
- int64_t* labels) {
66
- size_t d = vd.d;
67
- using C = typename VD::C;
68
- size_t check_period = InterruptCallback::get_period_hint(ny * d);
69
- check_period *= omp_get_max_threads();
70
-
71
- for (size_t i0 = 0; i0 < nx; i0 += check_period) {
72
- size_t i1 = std::min(i0 + check_period, nx);
73
-
74
- #pragma omp parallel for
75
- for (int64_t i = i0; i < i1; i++) {
76
- const float* x_i = x + i * d;
77
- const float* y_j = y;
78
- size_t j;
79
- float* simi = distances + k * i;
80
- int64_t* idxi = labels + k * i;
81
-
82
- // maxheap_heapify(k, simi, idxi);
83
- heap_heapify<C>(k, simi, idxi);
84
- for (j = 0; j < ny; j++) {
85
- float disij = vd(x_i, y_j);
86
-
87
- if (C::cmp(simi[0], disij)) {
88
- heap_replace_top<C>(k, simi, idxi, disij, j);
89
- }
90
- y_j += d;
91
- }
92
- // maxheap_reorder(k, simi, idxi);
93
- heap_reorder<C>(k, simi, idxi);
94
- }
95
- InterruptCallback::check();
96
- }
97
- }
98
- };
99
-
100
28
  template <class VD>
101
29
  struct ExtraDistanceComputer : FlatCodesDistanceComputer {
102
30
  VD vd;
@@ -128,19 +56,6 @@ struct ExtraDistanceComputer : FlatCodesDistanceComputer {
128
56
  }
129
57
  };
130
58
 
131
- struct Run_get_distance_computer {
132
- using T = FlatCodesDistanceComputer*;
133
-
134
- template <class VD>
135
- FlatCodesDistanceComputer* f(
136
- VD vd,
137
- const float* xb,
138
- size_t nb,
139
- const float* q = nullptr) {
140
- return new ExtraDistanceComputer<VD>(vd, xb, nb, q);
141
- }
142
- };
143
-
144
59
  } // anonymous namespace
145
60
 
146
61
  void pairwise_extra_distances(
@@ -168,9 +83,19 @@ void pairwise_extra_distances(
168
83
  ldd = nb;
169
84
  }
170
85
 
171
- Run_pairwise_extra_distances run;
172
- dispatch_VectorDistance(
173
- d, mt, metric_arg, run, nq, xq, nb, xb, dis, ldq, ldb, ldd);
86
+ with_VectorDistance(d, mt, metric_arg, [&](auto vd) {
87
+ #pragma omp parallel for if (nq > 10)
88
+ for (int64_t i = 0; i < nq; i++) {
89
+ const float* xqi = xq + i * ldq;
90
+ const float* xbj = xb;
91
+ float* disi = dis + ldd * i;
92
+
93
+ for (int64_t j = 0; j < nb; j++) {
94
+ disi[j] = vd(xqi, xbj);
95
+ xbj += ldb;
96
+ }
97
+ }
98
+ });
174
99
  }
175
100
 
176
101
  void knn_extra_metrics(
@@ -183,10 +108,40 @@ void knn_extra_metrics(
183
108
  float metric_arg,
184
109
  size_t k,
185
110
  float* distances,
186
- int64_t* indexes) {
187
- Run_knn_extra_metrics run;
188
- dispatch_VectorDistance(
189
- d, mt, metric_arg, run, x, y, nx, ny, k, distances, indexes);
111
+ int64_t* indexes,
112
+ const IDSelector* sel) {
113
+ with_VectorDistance(d, mt, metric_arg, [&](auto vd) {
114
+ using C = typename decltype(vd)::C;
115
+ size_t check_period = InterruptCallback::get_period_hint(ny * d);
116
+ check_period *= omp_get_max_threads();
117
+
118
+ for (size_t i0 = 0; i0 < nx; i0 += check_period) {
119
+ size_t i1 = std::min(i0 + check_period, nx);
120
+
121
+ #pragma omp parallel for
122
+ for (int64_t i = i0; i < i1; i++) {
123
+ const float* x_i = x + i * d;
124
+ const float* y_j = y;
125
+ size_t j;
126
+ float* simi = distances + k * i;
127
+ int64_t* idxi = indexes + k * i;
128
+
129
+ heap_heapify<C>(k, simi, idxi);
130
+ for (j = 0; j < ny; j++) {
131
+ if (!sel || sel->is_member(j)) {
132
+ float disij = vd(x_i, y_j);
133
+
134
+ if (C::cmp(simi[0], disij)) {
135
+ heap_replace_top<C>(k, simi, idxi, disij, j);
136
+ }
137
+ }
138
+ y_j += d;
139
+ }
140
+ heap_reorder<C>(k, simi, idxi);
141
+ }
142
+ InterruptCallback::check();
143
+ }
144
+ });
190
145
  }
191
146
 
192
147
  FlatCodesDistanceComputer* get_extra_distance_computer(
@@ -195,8 +150,10 @@ FlatCodesDistanceComputer* get_extra_distance_computer(
195
150
  float metric_arg,
196
151
  size_t nb,
197
152
  const float* xb) {
198
- Run_get_distance_computer run;
199
- return dispatch_VectorDistance(d, mt, metric_arg, run, xb, nb);
153
+ return with_VectorDistance(
154
+ d, mt, metric_arg, [&](auto vd) -> FlatCodesDistanceComputer* {
155
+ return new ExtraDistanceComputer<decltype(vd)>(vd, xb, nb);
156
+ });
200
157
  }
201
158
 
202
159
  } // namespace faiss
@@ -13,6 +13,7 @@
13
13
  #include <stdint.h>
14
14
 
15
15
  #include <faiss/Index.h>
16
+ #include <faiss/impl/IDSelector.h>
16
17
 
17
18
  #include <faiss/utils/Heap.h>
18
19
 
@@ -43,7 +44,8 @@ void knn_extra_metrics(
43
44
  float metric_arg,
44
45
  size_t k,
45
46
  float* distances,
46
- int64_t* indexes);
47
+ int64_t* indexes,
48
+ const IDSelector* sel = nullptr);
47
49
 
48
50
  /** get a DistanceComputer that refers to this type of distance and
49
51
  * indexes a flat array of size nb */
@@ -54,6 +56,50 @@ FlatCodesDistanceComputer* get_extra_distance_computer(
54
56
  size_t nb,
55
57
  const float* xb);
56
58
 
59
+ /// Dispatch to a lambda with MetricType as a compile-time constant.
60
+ /// This allows writing generic code that works with different metrics
61
+ /// while maintaining compile-time optimization.
62
+ ///
63
+ /// Example usage:
64
+ /// auto result = with_metric_type(runtime_metric, [&]<MetricType M>() {
65
+ /// // M is a constant expression here (action is called as operator()<M>())
66
+ /// return compute_distance<M>(x, y);
67
+ /// });
68
+ #ifndef SWIG
69
+
70
+ template <typename LambdaType>
71
+ inline auto with_metric_type(MetricType metric, LambdaType&& action) {
72
+ switch (metric) {
73
+ case METRIC_INNER_PRODUCT:
74
+ return action.template operator()<METRIC_INNER_PRODUCT>();
75
+ case METRIC_L2:
76
+ return action.template operator()<METRIC_L2>();
77
+ case METRIC_L1:
78
+ return action.template operator()<METRIC_L1>();
79
+ case METRIC_Linf:
80
+ return action.template operator()<METRIC_Linf>();
81
+ case METRIC_Lp:
82
+ return action.template operator()<METRIC_Lp>();
83
+ case METRIC_Canberra:
84
+ return action.template operator()<METRIC_Canberra>();
85
+ case METRIC_BrayCurtis:
86
+ return action.template operator()<METRIC_BrayCurtis>();
87
+ case METRIC_JensenShannon:
88
+ return action.template operator()<METRIC_JensenShannon>();
89
+ case METRIC_Jaccard:
90
+ return action.template operator()<METRIC_Jaccard>();
91
+ case METRIC_NaNEuclidean:
92
+ return action.template operator()<METRIC_NaNEuclidean>();
93
+ case METRIC_GOWER:
94
+ return action.template operator()<METRIC_GOWER>();
95
+ default:
96
+ FAISS_THROW_FMT(
97
+ "with_metric_type called with unknown metric %d",
98
+ int(metric));
99
+ }
100
+ }
101
+ #endif // SWIG
102
+
57
103
  } // namespace faiss
58
104
 
59
105
  #include <faiss/utils/extra_distances-inl.h>
@@ -312,7 +312,6 @@ struct HammingComputerDefault {
312
312
  const uint8_t* a = a8 + 8 * quotient8;
313
313
  const uint8_t* b = b8 + 8 * quotient8;
314
314
  switch (remainder8) {
315
- [[fallthrough]];
316
315
  case 7:
317
316
  accu += hamdis_tab_ham_bytes[a[6] ^ b[6]];
318
317
  [[fallthrough]];
@@ -627,7 +627,7 @@ uint16_t simd_partition_fuzzy_with_bounds_histogram(
627
627
  n_lt = sum_below - hist[i];
628
628
  n_gt = n - sum_below;
629
629
  } else {
630
- assert(!"not implemented");
630
+ assert(false && "not implemented");
631
631
  }
632
632
 
633
633
  IFV printf(