deskit 1.2.0__tar.gz → 1.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {deskit-1.2.0/src/deskit.egg-info → deskit-1.2.2}/PKG-INFO +1 -1
- {deskit-1.2.0 → deskit-1.2.2}/pyproject.toml +1 -1
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/des/dewsi.py +2 -2
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/des/dewsiv.py +2 -2
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/des/dewst.py +5 -4
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/des/dewsu.py +2 -2
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/des/dewsv.py +2 -2
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/des/knorae.py +2 -2
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/des/knoraiu.py +2 -2
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/des/knorau.py +2 -2
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/des/lwsei.py +2 -2
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/des/lwseu.py +2 -2
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/des/ola.py +2 -2
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/neighbors.py +203 -64
- {deskit-1.2.0 → deskit-1.2.2/src/deskit.egg-info}/PKG-INFO +1 -1
- {deskit-1.2.0 → deskit-1.2.2}/LICENSE +0 -0
- {deskit-1.2.0 → deskit-1.2.2}/README.md +0 -0
- {deskit-1.2.0 → deskit-1.2.2}/setup.cfg +0 -0
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/__init__.py +0 -0
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/_config.py +0 -0
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/base/__init__.py +0 -0
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/base/base.py +0 -0
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/base/knnbase.py +0 -0
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/base/predictbase.py +0 -0
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/des/__init__.py +0 -0
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/metrics.py +0 -0
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/router.py +0 -0
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit/utils.py +0 -0
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit.egg-info/SOURCES.txt +0 -0
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit.egg-info/dependency_links.txt +0 -0
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit.egg-info/requires.txt +0 -0
- {deskit-1.2.0 → deskit-1.2.2}/src/deskit.egg-info/top_level.txt +0 -0
|
@@ -69,7 +69,7 @@ class DEWSI(KNNBase):
|
|
|
69
69
|
)
|
|
70
70
|
super().fit(features, y, preds_dict)
|
|
71
71
|
|
|
72
|
-
def _weights_batch(self, x, temperature=None, threshold=None):
|
|
72
|
+
def _weights_batch(self, x, temperature=None, threshold=None, k=None):
|
|
73
73
|
"""
|
|
74
74
|
Core weight computation. x is a 2-D float64 numpy array (batch, n_features).
|
|
75
75
|
Returns (batch, n_models) weight array.
|
|
@@ -79,7 +79,7 @@ class DEWSI(KNNBase):
|
|
|
79
79
|
(0.5 if self.mode == 'min' else 1.0))
|
|
80
80
|
th = threshold if threshold is not None else self.threshold
|
|
81
81
|
|
|
82
|
-
distances, indices = self.model.kneighbors(x) # both (batch, k)
|
|
82
|
+
distances, indices = self.model.kneighbors(x, k=k) # both (batch, k)
|
|
83
83
|
|
|
84
84
|
# Inverse-distance-weighted average of each model's scores over K neighbors
|
|
85
85
|
inv_dist = 1.0 / np.maximum(distances, 1e-8) # (batch, k)
|
|
@@ -84,7 +84,7 @@ class DEWSIV(KNNBase):
|
|
|
84
84
|
preds = np.asarray(preds_dict[name])
|
|
85
85
|
self._var_matrix[:, j] = np.vectorize(_signed_residual)(y, preds)
|
|
86
86
|
|
|
87
|
-
def _weights_batch(self, x, temperature=None, threshold=None):
|
|
87
|
+
def _weights_batch(self, x, temperature=None, threshold=None, k=None):
|
|
88
88
|
"""
|
|
89
89
|
Core weight computation. x is a 2-D float64 numpy array (batch, n_features).
|
|
90
90
|
Returns (batch, n_models) weight array.
|
|
@@ -94,7 +94,7 @@ class DEWSIV(KNNBase):
|
|
|
94
94
|
(0.5 if self.mode == 'min' else 1.0))
|
|
95
95
|
th = threshold if threshold is not None else self.threshold
|
|
96
96
|
|
|
97
|
-
distances, indices = self.model.kneighbors(x) # both (batch, k)
|
|
97
|
+
distances, indices = self.model.kneighbors(x, k=k) # both (batch, k)
|
|
98
98
|
|
|
99
99
|
# Inverse-distance weights
|
|
100
100
|
inv_dist = 1.0 / np.maximum(distances, 1e-8) # (batch, k)
|
|
@@ -86,7 +86,7 @@ class DEWST(KNNBase):
|
|
|
86
86
|
)
|
|
87
87
|
super().fit(features, y, preds_dict)
|
|
88
88
|
|
|
89
|
-
def _weights_batch(self, x, temperature=None, threshold=None):
|
|
89
|
+
def _weights_batch(self, x, temperature=None, threshold=None, k=None, r2_threshold=None):
|
|
90
90
|
"""
|
|
91
91
|
Core weight computation. x is a 2-D float64 numpy array (batch, n_features).
|
|
92
92
|
Returns (batch, n_models) weight array.
|
|
@@ -94,9 +94,10 @@ class DEWST(KNNBase):
|
|
|
94
94
|
t = temperature if temperature is not None else (
|
|
95
95
|
self._temperature if self._temperature is not None else
|
|
96
96
|
(0.5 if self._real_mode == 'min' else 1.0))
|
|
97
|
-
th
|
|
97
|
+
th = threshold if threshold is not None else self.threshold
|
|
98
|
+
r2_th = r2_threshold if r2_threshold is not None else self.r2_threshold
|
|
98
99
|
|
|
99
|
-
distances, indices = self.model.kneighbors(x)
|
|
100
|
+
distances, indices = self.model.kneighbors(x, k=k) # (batch, k)
|
|
100
101
|
k = distances.shape[1]
|
|
101
102
|
|
|
102
103
|
# Inverse-distance weights
|
|
@@ -157,7 +158,7 @@ class DEWST(KNNBase):
|
|
|
157
158
|
trend_scores = intercept
|
|
158
159
|
|
|
159
160
|
# Blend: trust trend where R² ≥ threshold, fall back otherwise
|
|
160
|
-
use_trend = r2 >=
|
|
161
|
+
use_trend = r2 >= r2_th
|
|
161
162
|
avg_scores = np.where(use_trend, trend_scores, dewsi_scores)
|
|
162
163
|
|
|
163
164
|
# Standard DEWS softmax
|
|
@@ -65,7 +65,7 @@ class DEWSU(KNNBase):
|
|
|
65
65
|
)
|
|
66
66
|
super().fit(features, y, preds_dict)
|
|
67
67
|
|
|
68
|
-
def _weights_batch(self, x, temperature=None, threshold=None):
|
|
68
|
+
def _weights_batch(self, x, temperature=None, threshold=None, k=None):
|
|
69
69
|
"""
|
|
70
70
|
Core weight computation. x is a 2-D float64 numpy array (batch, n_features).
|
|
71
71
|
Returns (batch, n_models) weight array.
|
|
@@ -75,7 +75,7 @@ class DEWSU(KNNBase):
|
|
|
75
75
|
(0.5 if self.mode == 'min' else 1.0))
|
|
76
76
|
th = threshold if threshold is not None else self.threshold
|
|
77
77
|
|
|
78
|
-
_, indices = self.model.kneighbors(x) # (batch, k)
|
|
78
|
+
_, indices = self.model.kneighbors(x, k=k) # (batch, k)
|
|
79
79
|
|
|
80
80
|
# Average each model's scores over the K neighbors
|
|
81
81
|
avg_scores = self.matrix[indices].mean(axis=1) # (batch, n_models)
|
|
@@ -84,7 +84,7 @@ class DEWSV(KNNBase):
|
|
|
84
84
|
preds = np.asarray(preds_dict[name])
|
|
85
85
|
self._var_matrix[:, j] = np.vectorize(_signed_residual)(y, preds)
|
|
86
86
|
|
|
87
|
-
def _weights_batch(self, x, temperature=None, threshold=None):
|
|
87
|
+
def _weights_batch(self, x, temperature=None, threshold=None, k=None):
|
|
88
88
|
"""
|
|
89
89
|
Core weight computation. x is a 2-D float64 numpy array (batch, n_features).
|
|
90
90
|
Returns (batch, n_models) weight array.
|
|
@@ -94,7 +94,7 @@ class DEWSV(KNNBase):
|
|
|
94
94
|
(0.5 if self.mode == 'min' else 1.0))
|
|
95
95
|
th = threshold if threshold is not None else self.threshold
|
|
96
96
|
|
|
97
|
-
_, indices = self.model.kneighbors(x) # (batch, k)
|
|
97
|
+
_, indices = self.model.kneighbors(x, k=k) # (batch, k)
|
|
98
98
|
|
|
99
99
|
# Uniform average of each model's scores over K neighbors
|
|
100
100
|
neighbor_scores = self.matrix[indices] # (batch, k, n_models)
|
|
@@ -55,7 +55,7 @@ class KNORAE(KNNBase):
|
|
|
55
55
|
)
|
|
56
56
|
super().fit(features, y, preds_dict)
|
|
57
57
|
|
|
58
|
-
def _weights_batch(self, x, temperature=None, threshold=None):
|
|
58
|
+
def _weights_batch(self, x, temperature=None, threshold=None, k=None):
|
|
59
59
|
"""
|
|
60
60
|
Core weight computation. x is a 2-D float64 numpy array (batch, n_features).
|
|
61
61
|
Returns (batch, n_models) weight array.
|
|
@@ -64,7 +64,7 @@ class KNORAE(KNNBase):
|
|
|
64
64
|
th = threshold if threshold is not None else self.threshold
|
|
65
65
|
n_models = len(self.models)
|
|
66
66
|
|
|
67
|
-
_, indices = self.model.kneighbors(x)
|
|
67
|
+
_, indices = self.model.kneighbors(x, k=k)
|
|
68
68
|
k = indices.shape[1]
|
|
69
69
|
neighbor_scores = self.matrix[indices] # (batch, k, n_models)
|
|
70
70
|
|
|
@@ -56,7 +56,7 @@ class KNORAIU(KNNBase):
|
|
|
56
56
|
)
|
|
57
57
|
super().fit(features, y, preds_dict)
|
|
58
58
|
|
|
59
|
-
def _weights_batch(self, x, temperature=None, threshold=None):
|
|
59
|
+
def _weights_batch(self, x, temperature=None, threshold=None, k=None):
|
|
60
60
|
"""
|
|
61
61
|
Core weight computation. x is a 2-D float64 numpy array (batch, n_features).
|
|
62
62
|
Returns (batch, n_models) weight array.
|
|
@@ -64,7 +64,7 @@ class KNORAIU(KNNBase):
|
|
|
64
64
|
"""
|
|
65
65
|
th = threshold if threshold is not None else self.threshold
|
|
66
66
|
|
|
67
|
-
distances, indices = self.model.kneighbors(x) # both (batch, k)
|
|
67
|
+
distances, indices = self.model.kneighbors(x, k=k) # both (batch, k)
|
|
68
68
|
neighbor_scores = self.matrix[indices] # (batch, k, n_models)
|
|
69
69
|
|
|
70
70
|
# Normalize per neighbor: best model = 1.0, worst = 0.0
|
|
@@ -56,7 +56,7 @@ class KNORAU(KNNBase):
|
|
|
56
56
|
)
|
|
57
57
|
super().fit(features, y, preds_dict)
|
|
58
58
|
|
|
59
|
-
def _weights_batch(self, x, temperature=None, threshold=None):
|
|
59
|
+
def _weights_batch(self, x, temperature=None, threshold=None, k=None):
|
|
60
60
|
"""
|
|
61
61
|
Core weight computation. x is a 2-D float64 numpy array (batch, n_features).
|
|
62
62
|
Returns (batch, n_models) weight array.
|
|
@@ -64,7 +64,7 @@ class KNORAU(KNNBase):
|
|
|
64
64
|
"""
|
|
65
65
|
th = threshold if threshold is not None else self.threshold
|
|
66
66
|
|
|
67
|
-
_, indices = self.model.kneighbors(x)
|
|
67
|
+
_, indices = self.model.kneighbors(x, k=k)
|
|
68
68
|
neighbor_scores = self.matrix[indices] # (batch, k, n_models)
|
|
69
69
|
|
|
70
70
|
# Normalize per neighbor: best model = 1.0, worst = 0.0
|
|
@@ -74,7 +74,7 @@ class LWSEI(PredictBase):
|
|
|
74
74
|
self._y_val = y
|
|
75
75
|
self._finder.fit(features)
|
|
76
76
|
|
|
77
|
-
def _weights_batch(self, x, temperature=None, **kwargs):
|
|
77
|
+
def _weights_batch(self, x, temperature=None, k=None, **kwargs):
|
|
78
78
|
"""
|
|
79
79
|
Core weight computation. x is a 2-D float64 numpy array (batch, n_features).
|
|
80
80
|
Returns (batch, n_models) weight array.
|
|
@@ -87,7 +87,7 @@ class LWSEI(PredictBase):
|
|
|
87
87
|
n_models = len(self.models)
|
|
88
88
|
uniform = np.full(n_models, 1.0 / n_models)
|
|
89
89
|
|
|
90
|
-
distances, indices = self._finder.kneighbors(x) # (batch, k)
|
|
90
|
+
distances, indices = self._finder.kneighbors(x, k=k) # (batch, k)
|
|
91
91
|
weights_out = np.empty((batch_size, n_models))
|
|
92
92
|
|
|
93
93
|
for b in range(batch_size):
|
|
@@ -74,7 +74,7 @@ class LWSEU(PredictBase):
|
|
|
74
74
|
self._y_val = y
|
|
75
75
|
self._finder.fit(features)
|
|
76
76
|
|
|
77
|
-
def _weights_batch(self, x, temperature=None, **kwargs):
|
|
77
|
+
def _weights_batch(self, x, temperature=None, k=None, **kwargs):
|
|
78
78
|
"""
|
|
79
79
|
Core weight computation. x is a 2-D float64 numpy array (batch, n_features).
|
|
80
80
|
Returns (batch, n_models) weight array.
|
|
@@ -87,7 +87,7 @@ class LWSEU(PredictBase):
|
|
|
87
87
|
n_models = len(self.models)
|
|
88
88
|
uniform = np.full(n_models, 1.0 / n_models)
|
|
89
89
|
|
|
90
|
-
distances, indices = self._finder.kneighbors(x) # (batch, k)
|
|
90
|
+
distances, indices = self._finder.kneighbors(x, k=k) # (batch, k)
|
|
91
91
|
weights_out = np.empty((batch_size, n_models))
|
|
92
92
|
|
|
93
93
|
for b in range(batch_size):
|
|
@@ -54,7 +54,7 @@ class OLA(KNNBase):
|
|
|
54
54
|
if mat_max > mat_min:
|
|
55
55
|
self.matrix = (self.matrix - mat_min) / (mat_max - mat_min)
|
|
56
56
|
|
|
57
|
-
def _weights_batch(self, x, temperature=None, threshold=None):
|
|
57
|
+
def _weights_batch(self, x, temperature=None, threshold=None, k=None):
|
|
58
58
|
"""
|
|
59
59
|
Core weight computation. x is a 2-D float64 numpy array (batch, n_features).
|
|
60
60
|
Returns (batch, n_models) weight array.
|
|
@@ -63,7 +63,7 @@ class OLA(KNNBase):
|
|
|
63
63
|
"""
|
|
64
64
|
batch_size = x.shape[0]
|
|
65
65
|
|
|
66
|
-
_, indices = self.model.kneighbors(x)
|
|
66
|
+
_, indices = self.model.kneighbors(x, k=k)
|
|
67
67
|
avg_scores = self.matrix[indices].mean(axis=1) # (batch, n_models)
|
|
68
68
|
best_indices = np.argmax(avg_scores, axis=1)
|
|
69
69
|
|
|
@@ -13,18 +13,28 @@ _FAISS_MIN_SAMPLES_PER_CELL = 40
|
|
|
13
13
|
# 'euclidean' is the universal default and always available.
|
|
14
14
|
#
|
|
15
15
|
# Choosing a distance metric:
|
|
16
|
-
# euclidean
|
|
17
|
-
# manhattan
|
|
18
|
-
#
|
|
19
|
-
#
|
|
20
|
-
# chebyshev
|
|
21
|
-
#
|
|
22
|
-
#
|
|
23
|
-
# minkowski
|
|
24
|
-
#
|
|
25
|
-
#
|
|
26
|
-
#
|
|
27
|
-
#
|
|
16
|
+
# euclidean – The standard L2 norm. Best default for most tabular data.
|
|
17
|
+
# manhattan – L1 norm (sum of absolute differences). More robust to
|
|
18
|
+
# outliers and tends to work better in moderately high-
|
|
19
|
+
# dimensional spaces because it doesn't square large diffs.
|
|
20
|
+
# chebyshev – L∞ norm (maximum absolute difference across features).
|
|
21
|
+
# Useful when a single feature dominating the distance is
|
|
22
|
+
# acceptable; common in game-grid / chess-style problems.
|
|
23
|
+
# minkowski – Generalisation of L1/L2 (controlled by p). p=1 →
|
|
24
|
+
# manhattan, p=2 → euclidean. Use when you want to tune
|
|
25
|
+
# between them.
|
|
26
|
+
# cosine – Angle between vectors, ignoring magnitude. Excellent for
|
|
27
|
+
# embeddings (text, image, audio) where direction matters
|
|
28
|
+
# more than raw scale.
|
|
29
|
+
# canberra – Weighted L1. Sensitive to small values near zero.
|
|
30
|
+
# braycurtis – Normalised L1 bounded to [0,1]. Common in ecology.
|
|
31
|
+
# jensenshannon – Symmetric KL divergence on probability distributions.
|
|
32
|
+
# Requires non-negative vectors. Supported by FAISS flat/
|
|
33
|
+
# HNSW/GPU indices natively.
|
|
34
|
+
# dot – Raw inner/dot product. Not a true metric; distances are
|
|
35
|
+
# not comparable across queries. Use for max inner-product
|
|
36
|
+
# search (recommendation systems). Prefer 'cosine' for
|
|
37
|
+
# normalised embeddings.
|
|
28
38
|
|
|
29
39
|
# Metrics that every backend supports natively.
|
|
30
40
|
_UNIVERSAL_METRICS = {'euclidean', 'manhattan', 'chebyshev', 'minkowski', 'cosine'}
|
|
@@ -33,31 +43,54 @@ _UNIVERSAL_METRICS = {'euclidean', 'manhattan', 'chebyshev', 'minkowski', 'cosin
|
|
|
33
43
|
# KNN (sklearn) supports all scipy metrics — this is the complete curated list.
|
|
34
44
|
_KNN_METRICS = _UNIVERSAL_METRICS | {'correlation', 'hamming', 'canberra', 'braycurtis'}
|
|
35
45
|
|
|
36
|
-
# FAISS
|
|
37
|
-
#
|
|
38
|
-
|
|
39
|
-
|
|
46
|
+
# FAISS native metric support:
|
|
47
|
+
# IndexFlat, IndexHNSW, and GpuIndexFlat support METRIC_L1, METRIC_Linf,
|
|
48
|
+
# METRIC_Lp (with metric_arg for p), METRIC_Canberra, METRIC_BrayCurtis,
|
|
49
|
+
# and METRIC_JensenShannon in addition to L2 and inner product.
|
|
50
|
+
# IndexIVFFlat only supports L2 and inner product.
|
|
51
|
+
# 'ivf' index_type will still fall back for non-L2/cosine metrics.
|
|
52
|
+
_FAISS_FLAT_HNSW_NATIVE_METRICS = {
|
|
53
|
+
'euclidean', 'cosine', 'manhattan', 'chebyshev', 'minkowski',
|
|
54
|
+
'canberra', 'braycurtis', 'jensenshannon',
|
|
55
|
+
}
|
|
56
|
+
_FAISS_IVF_NATIVE_METRICS = {'euclidean', 'cosine'}
|
|
57
|
+
|
|
58
|
+
# For backwards compatibility: the overall set accepted by FaissNeighborFinder.
|
|
59
|
+
_FAISS_METRICS = _FAISS_FLAT_HNSW_NATIVE_METRICS | {'correlation', 'hamming'}
|
|
40
60
|
|
|
41
61
|
# Annoy metric names (library-specific).
|
|
62
|
+
# Annoy natively supports: euclidean, manhattan, cosine (angular), hamming,
|
|
63
|
+
# and dot (inner product). chebyshev and minkowski have no Annoy equivalent.
|
|
42
64
|
_ANNOY_METRIC_MAP = {
|
|
43
65
|
'euclidean': 'euclidean',
|
|
44
66
|
'manhattan': 'manhattan',
|
|
45
67
|
'cosine': 'angular',
|
|
46
68
|
'hamming': 'hamming',
|
|
47
|
-
|
|
69
|
+
'dot': 'dot',
|
|
48
70
|
}
|
|
49
71
|
_ANNOY_METRICS = set(_ANNOY_METRIC_MAP)
|
|
50
72
|
|
|
51
|
-
#
|
|
52
|
-
|
|
73
|
+
# hnswlib space names — only three native spaces exist.
|
|
74
|
+
# 'ip' is inner product (not a true metric; used for max inner-product search).
|
|
75
|
+
_HNSWLIB_METRIC_MAP = {
|
|
53
76
|
'euclidean': 'l2',
|
|
54
77
|
'cosine': 'cosine',
|
|
55
|
-
|
|
78
|
+
'dot': 'ip',
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
# nmslib space names for DENSE_VECTOR + HNSW.
|
|
82
|
+
# l1/linf/angulardist are confirmed supported by nmslib's integration tests.
|
|
83
|
+
# 'dot' maps to negdotprod (nmslib maximises inner product via negative distance).
|
|
84
|
+
_NMSLIB_METRIC_MAP = {
|
|
85
|
+
'euclidean': 'l2',
|
|
86
|
+
'cosine': 'cosinesimil',
|
|
87
|
+
'manhattan': 'l1',
|
|
88
|
+
'chebyshev': 'linf',
|
|
89
|
+
'dot': 'negdotprod',
|
|
56
90
|
}
|
|
57
|
-
_HNSW_METRICS = _UNIVERSAL_METRICS # partial — see fit() for fallback note
|
|
58
91
|
|
|
59
92
|
# All metrics callable from the public API.
|
|
60
|
-
ALL_METRICS = _KNN_METRICS
|
|
93
|
+
ALL_METRICS = _KNN_METRICS | {'jensenshannon', 'dot'}
|
|
61
94
|
|
|
62
95
|
|
|
63
96
|
def list_distance_metrics():
|
|
@@ -65,19 +98,21 @@ def list_distance_metrics():
|
|
|
65
98
|
print("\nAvailable Distance Metrics:")
|
|
66
99
|
print("=" * 70)
|
|
67
100
|
rows = [
|
|
68
|
-
("euclidean",
|
|
69
|
-
("manhattan",
|
|
70
|
-
("chebyshev",
|
|
71
|
-
("minkowski",
|
|
72
|
-
("cosine",
|
|
73
|
-
("
|
|
74
|
-
("
|
|
75
|
-
("
|
|
76
|
-
("
|
|
101
|
+
("euclidean", "Default. L2 norm. Best for most tabular data.", "all"),
|
|
102
|
+
("manhattan", "L1 norm. More robust to outliers; good for high-dim data.", "KNN, FAISS (flat/hnsw), Annoy, HNSW-nmslib"),
|
|
103
|
+
("chebyshev", "L∞ norm. Max absolute diff across features.", "KNN, FAISS (flat/hnsw), HNSW-nmslib"),
|
|
104
|
+
("minkowski", "Generalises L1/L2 via p-param. Set minkowski_p=<float>.", "KNN, FAISS (flat/hnsw)"),
|
|
105
|
+
("cosine", "Angle between vectors. Ideal for embeddings (NLP, vision).", "all"),
|
|
106
|
+
("dot", "Inner/dot product. Not a metric; used for max-IP search.", "Annoy, HNSW (hnswlib ip / nmslib negdotprod)"),
|
|
107
|
+
("canberra", "Weighted L1. Sensitive to small values near zero.", "KNN, FAISS (flat/hnsw/gpu)"),
|
|
108
|
+
("braycurtis", "Normalised L1 bounded to [0,1]. Ecological data.", "KNN, FAISS (flat/hnsw/gpu)"),
|
|
109
|
+
("jensenshannon", "Symmetric KL divergence. Requires non-negative vectors.", "FAISS (flat/hnsw/gpu)"),
|
|
110
|
+
("correlation", "Pearson correlation distance. Good for time series.", "KNN only"),
|
|
111
|
+
("hamming", "Fraction of differing components. For binary/categorical data.", "KNN, Annoy"),
|
|
77
112
|
]
|
|
78
113
|
for name, desc, backends in rows:
|
|
79
|
-
print(f"\n {name:<
|
|
80
|
-
print(f" {'':
|
|
114
|
+
print(f"\n {name:<16} {desc}")
|
|
115
|
+
print(f" {'':16} Backends: {backends}")
|
|
81
116
|
print("\n" + "=" * 70)
|
|
82
117
|
|
|
83
118
|
|
|
@@ -160,21 +195,36 @@ class FaissNeighborFinder(NeighborFinder):
|
|
|
160
195
|
"""
|
|
161
196
|
Approximate nearest neighbors via FAISS (flat, IVF, or HNSW index).
|
|
162
197
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
198
|
+
Native metric support depends on index_type:
|
|
199
|
+
|
|
200
|
+
flat / hnsw / gpu-flat
|
|
201
|
+
FAISS IndexFlat, IndexHNSW, and GpuIndexFlat natively support:
|
|
202
|
+
euclidean, cosine, manhattan (L1), chebyshev (Linf), minkowski (Lp),
|
|
203
|
+
canberra, braycurtis, jensenshannon.
|
|
204
|
+
|
|
205
|
+
ivf
|
|
206
|
+
IndexIVFFlat only supports L2 and inner-product (cosine). All other
|
|
207
|
+
metrics fall back to an exact sklearn KNN with a warning.
|
|
208
|
+
|
|
209
|
+
correlation and hamming always fall back to sklearn for all index types.
|
|
166
210
|
"""
|
|
167
211
|
|
|
168
212
|
def __init__(self, k=10, index_type='flat', n_cells=None, n_probes=50,
|
|
169
213
|
hnsw_M=32, hnsw_efConstruction=400, hnsw_efSearch=200,
|
|
170
|
-
distance_metric='euclidean'):
|
|
214
|
+
distance_metric='euclidean', minkowski_p=2):
|
|
171
215
|
"""
|
|
172
216
|
Parameters
|
|
173
217
|
----------
|
|
174
218
|
distance_metric : str
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
219
|
+
Metric to use. flat/hnsw/gpu index types natively support:
|
|
220
|
+
'euclidean', 'cosine', 'manhattan', 'chebyshev', 'minkowski',
|
|
221
|
+
'canberra', 'braycurtis', 'jensenshannon'.
|
|
222
|
+
'ivf' only natively supports 'euclidean' and 'cosine'; all others
|
|
223
|
+
fall back to exact sklearn KNN with a warning.
|
|
224
|
+
'correlation' and 'hamming' always fall back to sklearn.
|
|
225
|
+
minkowski_p : float
|
|
226
|
+
The p-parameter for the Minkowski metric. p=1 → manhattan,
|
|
227
|
+
p=2 → euclidean. Ignored for all other metrics.
|
|
178
228
|
"""
|
|
179
229
|
if k <= 0:
|
|
180
230
|
raise ValueError(f"k must be positive, got k={k}")
|
|
@@ -193,6 +243,7 @@ class FaissNeighborFinder(NeighborFinder):
|
|
|
193
243
|
self.hnsw_efConstruction = hnsw_efConstruction
|
|
194
244
|
self.hnsw_efSearch = hnsw_efSearch
|
|
195
245
|
self.distance_metric = metric
|
|
246
|
+
self.minkowski_p = minkowski_p
|
|
196
247
|
self.index_ = None
|
|
197
248
|
self._fallback_finder = None # used for non-native metrics
|
|
198
249
|
self._check_availability()
|
|
@@ -204,6 +255,31 @@ class FaissNeighborFinder(NeighborFinder):
|
|
|
204
255
|
except ImportError:
|
|
205
256
|
raise ImportError("FAISS not found. Install with: pip install faiss-cpu")
|
|
206
257
|
|
|
258
|
+
@staticmethod
|
|
259
|
+
def _faiss_metric_type(faiss, metric, minkowski_p=2):
|
|
260
|
+
"""
|
|
261
|
+
Return (faiss_metric_constant, metric_arg) for a given metric name.
|
|
262
|
+
metric_arg is only meaningful for METRIC_Lp (minkowski).
|
|
263
|
+
Raises KeyError for metrics that have no FAISS MetricType constant
|
|
264
|
+
(i.e. those that must be handled via fallback).
|
|
265
|
+
"""
|
|
266
|
+
_MAP = {
|
|
267
|
+
'euclidean': (faiss.METRIC_L2, None),
|
|
268
|
+
'cosine': (faiss.METRIC_INNER_PRODUCT, None),
|
|
269
|
+
'manhattan': (faiss.METRIC_L1, None),
|
|
270
|
+
'chebyshev': (faiss.METRIC_Linf, None),
|
|
271
|
+
'minkowski': (faiss.METRIC_Lp, None), # metric_arg set below
|
|
272
|
+
'canberra': (faiss.METRIC_Canberra, None),
|
|
273
|
+
'braycurtis': (faiss.METRIC_BrayCurtis, None),
|
|
274
|
+
'jensenshannon':(faiss.METRIC_JensenShannon, None),
|
|
275
|
+
}
|
|
276
|
+
if metric not in _MAP:
|
|
277
|
+
raise KeyError(metric)
|
|
278
|
+
ft, arg = _MAP[metric]
|
|
279
|
+
if metric == 'minkowski':
|
|
280
|
+
arg = float(minkowski_p)
|
|
281
|
+
return ft, arg
|
|
282
|
+
|
|
207
283
|
@staticmethod
|
|
208
284
|
def _l2_normalize(X):
|
|
209
285
|
"""Row-wise L2 normalisation in-place (for cosine similarity)."""
|
|
@@ -221,11 +297,18 @@ class FaissNeighborFinder(NeighborFinder):
|
|
|
221
297
|
f"{n_samples} samples. Reduce k to at most {n_samples}."
|
|
222
298
|
)
|
|
223
299
|
|
|
224
|
-
#
|
|
225
|
-
|
|
300
|
+
# Determine whether the chosen metric is natively supported by this index type.
|
|
301
|
+
ivf_native = self.distance_metric in _FAISS_IVF_NATIVE_METRICS
|
|
302
|
+
flat_hnsw_native = self.distance_metric in _FAISS_FLAT_HNSW_NATIVE_METRICS
|
|
303
|
+
is_ivf = (self.index_type == 'ivf')
|
|
304
|
+
|
|
305
|
+
needs_fallback = is_ivf and not ivf_native
|
|
306
|
+
needs_fallback = needs_fallback or (not is_ivf and not flat_hnsw_native)
|
|
307
|
+
|
|
308
|
+
if needs_fallback:
|
|
226
309
|
warnings.warn(
|
|
227
310
|
f"distance_metric='{self.distance_metric}' is not natively supported by "
|
|
228
|
-
f"FAISS. Falling back to exact sklearn KNN
|
|
311
|
+
f"FAISS {self.index_type} index. Falling back to exact sklearn KNN. "
|
|
229
312
|
f"Use preset='exact' to avoid this overhead.",
|
|
230
313
|
UserWarning,
|
|
231
314
|
)
|
|
@@ -235,7 +318,7 @@ class FaissNeighborFinder(NeighborFinder):
|
|
|
235
318
|
self._fallback_finder.fit(X)
|
|
236
319
|
return self
|
|
237
320
|
|
|
238
|
-
# Cosine similarity: normalise all vectors
|
|
321
|
+
# Cosine similarity: normalise all vectors so inner-product == cosine.
|
|
239
322
|
if self.distance_metric == 'cosine':
|
|
240
323
|
X = self._l2_normalize(X)
|
|
241
324
|
|
|
@@ -247,12 +330,20 @@ class FaissNeighborFinder(NeighborFinder):
|
|
|
247
330
|
UserWarning
|
|
248
331
|
)
|
|
249
332
|
if self.distance_metric == 'cosine':
|
|
250
|
-
self.index_ = self.faiss.IndexFlatIP(dim)
|
|
251
|
-
|
|
333
|
+
self.index_ = self.faiss.IndexFlatIP(dim)
|
|
334
|
+
elif self.distance_metric == 'euclidean':
|
|
252
335
|
self.index_ = self.faiss.IndexFlatL2(dim)
|
|
336
|
+
else:
|
|
337
|
+
ft, metric_arg = self._faiss_metric_type(
|
|
338
|
+
self.faiss, self.distance_metric, self.minkowski_p
|
|
339
|
+
)
|
|
340
|
+
self.index_ = self.faiss.IndexFlat(dim, ft)
|
|
341
|
+
if metric_arg is not None:
|
|
342
|
+
self.index_.metric_arg = metric_arg
|
|
253
343
|
self.index_.add(X)
|
|
254
344
|
|
|
255
345
|
elif self.index_type == 'ivf':
|
|
346
|
+
# IVF only supports L2 / inner-product (guarded above).
|
|
256
347
|
if self.n_cells is None:
|
|
257
348
|
self.n_cells = min(int(np.sqrt(n_samples)), 4096)
|
|
258
349
|
|
|
@@ -300,7 +391,19 @@ class FaissNeighborFinder(NeighborFinder):
|
|
|
300
391
|
f"{n_samples} samples. Consider ef_construction >= 400.",
|
|
301
392
|
UserWarning
|
|
302
393
|
)
|
|
303
|
-
self.
|
|
394
|
+
if self.distance_metric == 'cosine':
|
|
395
|
+
self.index_ = self.faiss.IndexHNSWFlat(
|
|
396
|
+
dim, self.hnsw_M, self.faiss.METRIC_INNER_PRODUCT
|
|
397
|
+
)
|
|
398
|
+
elif self.distance_metric == 'euclidean':
|
|
399
|
+
self.index_ = self.faiss.IndexHNSWFlat(dim, self.hnsw_M)
|
|
400
|
+
else:
|
|
401
|
+
ft, metric_arg = self._faiss_metric_type(
|
|
402
|
+
self.faiss, self.distance_metric, self.minkowski_p
|
|
403
|
+
)
|
|
404
|
+
self.index_ = self.faiss.IndexHNSWFlat(dim, self.hnsw_M, ft)
|
|
405
|
+
if metric_arg is not None:
|
|
406
|
+
self.index_.metric_arg = metric_arg
|
|
304
407
|
self.index_.hnsw.efConstruction = self.hnsw_efConstruction
|
|
305
408
|
self.index_.hnsw.efSearch = self.hnsw_efSearch
|
|
306
409
|
self.index_.add(X)
|
|
@@ -328,10 +431,14 @@ class FaissNeighborFinder(NeighborFinder):
|
|
|
328
431
|
# Inner product on normalised vectors: similarity ∈ [-1, 1].
|
|
329
432
|
# Convert to a proper distance (0 = identical, 2 = opposite).
|
|
330
433
|
distances = 1.0 - scores
|
|
331
|
-
|
|
434
|
+
elif self.distance_metric == 'euclidean':
|
|
332
435
|
distances, indices = self.index_.search(X, k)
|
|
333
436
|
# FAISS returns squared L2; clamp to 0 before sqrt.
|
|
334
437
|
distances = np.sqrt(np.maximum(distances, 0))
|
|
438
|
+
else:
|
|
439
|
+
# All other native metrics (manhattan, chebyshev, minkowski, canberra,
|
|
440
|
+
# braycurtis, jensenshannon) use the raw FAISS search output with no conversion. NOTE(review): METRIC_Lp may be the un-rooted sum of p-th powers — confirm.
|
|
441
|
+
distances, indices = self.index_.search(X, k)
|
|
335
442
|
|
|
336
443
|
return distances.astype(np.float32), indices
|
|
337
444
|
|
|
@@ -340,8 +447,13 @@ class AnnoyNeighborFinder(NeighborFinder):
|
|
|
340
447
|
"""
|
|
341
448
|
Approximate nearest neighbors via Annoy.
|
|
342
449
|
|
|
343
|
-
Supports: euclidean, manhattan, cosine, hamming
|
|
450
|
+
Supports: euclidean, manhattan, cosine (stored as 'angular'), hamming,
|
|
451
|
+
and dot (inner product, stored as 'dot').
|
|
344
452
|
chebyshev and minkowski are not available in Annoy — use preset='exact' for those.
|
|
453
|
+
|
|
454
|
+
Note on 'dot': Annoy's dot-product space is not a true metric. Distances
|
|
455
|
+
returned are reduced inner-product values, not raw dot products — see
|
|
456
|
+
Bachrach et al. (2014). Prefer 'cosine' for normalised embeddings.
|
|
345
457
|
"""
|
|
346
458
|
|
|
347
459
|
def __init__(self, k=10, n_trees=100, distance_metric='euclidean', search_k=-1):
|
|
@@ -349,7 +461,8 @@ class AnnoyNeighborFinder(NeighborFinder):
|
|
|
349
461
|
Parameters
|
|
350
462
|
----------
|
|
351
463
|
distance_metric : str
|
|
352
|
-
One of 'euclidean', 'manhattan', 'cosine', 'hamming'
|
|
464
|
+
One of 'euclidean', 'manhattan', 'cosine', 'hamming', 'dot'.
|
|
465
|
+
Default: 'euclidean'.
|
|
353
466
|
"""
|
|
354
467
|
if k <= 0:
|
|
355
468
|
raise ValueError(f"k must be positive, got k={k}")
|
|
@@ -440,26 +553,55 @@ class HNSWNeighborFinder(NeighborFinder):
|
|
|
440
553
|
"""
|
|
441
554
|
Approximate nearest neighbors via HNSW (hnswlib or nmslib backend).
|
|
442
555
|
|
|
443
|
-
|
|
444
|
-
|
|
556
|
+
Native metric support depends on the backend:
|
|
557
|
+
|
|
558
|
+
hnswlib
|
|
559
|
+
Supports 'euclidean' (l2), 'cosine', and 'dot' (ip / inner product).
|
|
560
|
+
All other metrics raise an error — use preset='exact' instead.
|
|
561
|
+
|
|
562
|
+
nmslib
|
|
563
|
+
Supports 'euclidean' (l2), 'cosine' (cosinesimil), 'manhattan' (l1),
|
|
564
|
+
'chebyshev' (linf), and 'dot' (negdotprod / max inner-product search).
|
|
565
|
+
All other metrics raise an error — use preset='exact' instead.
|
|
566
|
+
|
|
567
|
+
Note on 'dot': inner product is not a true distance metric. Results are
|
|
568
|
+
ranked by descending similarity, not ascending distance. Use 'cosine' for
|
|
569
|
+
normalised embeddings where you want a proper distance.
|
|
445
570
|
"""
|
|
446
571
|
|
|
572
|
+
# Per-backend accepted metrics (validated in __init__).
|
|
573
|
+
_HNSWLIB_METRICS = set(_HNSWLIB_METRIC_MAP) # euclidean, cosine, dot
|
|
574
|
+
_NMSLIB_METRICS = set(_NMSLIB_METRIC_MAP) # euclidean, cosine, manhattan, chebyshev, dot
|
|
575
|
+
|
|
447
576
|
def __init__(self, k=10, M=32, ef_construction=400,
|
|
448
577
|
ef_search=200, backend='hnswlib', distance_metric='euclidean'):
|
|
449
578
|
"""
|
|
450
579
|
Parameters
|
|
451
580
|
----------
|
|
452
581
|
distance_metric : str
|
|
453
|
-
'euclidean'
|
|
454
|
-
|
|
582
|
+
hnswlib: 'euclidean', 'cosine', or 'dot'.
|
|
583
|
+
nmslib: 'euclidean', 'cosine', 'manhattan', 'chebyshev', or 'dot'.
|
|
584
|
+
Default: 'euclidean'.
|
|
585
|
+
backend : str
|
|
586
|
+
'hnswlib' (default) or 'nmslib'.
|
|
455
587
|
"""
|
|
456
588
|
if k <= 0:
|
|
457
589
|
raise ValueError(f"k must be positive, got k={k}")
|
|
458
590
|
metric = distance_metric.lower()
|
|
459
|
-
|
|
591
|
+
backend_str = backend.lower()
|
|
592
|
+
|
|
593
|
+
if backend_str == 'hnswlib':
|
|
594
|
+
allowed = self._HNSWLIB_METRICS
|
|
595
|
+
elif backend_str == 'nmslib':
|
|
596
|
+
allowed = self._NMSLIB_METRICS
|
|
597
|
+
else:
|
|
598
|
+
raise ValueError(f"Unknown backend: '{backend}'. Choose 'hnswlib' or 'nmslib'.")
|
|
599
|
+
|
|
600
|
+
if metric not in allowed:
|
|
460
601
|
raise ValueError(
|
|
461
|
-
f"distance_metric='{distance_metric}' is not
|
|
462
|
-
f"HNSWNeighborFinder
|
|
602
|
+
f"distance_metric='{distance_metric}' is not supported by "
|
|
603
|
+
f"HNSWNeighborFinder (backend='{backend_str}'). "
|
|
604
|
+
f"Available: {sorted(allowed)}. "
|
|
463
605
|
f"For other metrics use preset='exact' (KNNNeighborFinder)."
|
|
464
606
|
)
|
|
465
607
|
self.n_neighbors = k
|
|
@@ -467,7 +609,7 @@ class HNSWNeighborFinder(NeighborFinder):
|
|
|
467
609
|
self.M = M
|
|
468
610
|
self.ef_construction = ef_construction
|
|
469
611
|
self.ef_search = ef_search
|
|
470
|
-
self.backend =
|
|
612
|
+
self.backend = backend_str
|
|
471
613
|
self.index_ = None
|
|
472
614
|
self._check_availability()
|
|
473
615
|
|
|
@@ -504,7 +646,7 @@ class HNSWNeighborFinder(NeighborFinder):
|
|
|
504
646
|
)
|
|
505
647
|
|
|
506
648
|
if self.backend == 'hnswlib':
|
|
507
|
-
space =
|
|
649
|
+
space = _HNSWLIB_METRIC_MAP[self.distance_metric]
|
|
508
650
|
self.index_ = self.hnswlib.Index(space=space, dim=dim)
|
|
509
651
|
self.index_.init_index(
|
|
510
652
|
max_elements=n_samples, M=self.M, ef_construction=self.ef_construction
|
|
@@ -513,13 +655,10 @@ class HNSWNeighborFinder(NeighborFinder):
|
|
|
513
655
|
self.index_.add_items(X, np.arange(n_samples))
|
|
514
656
|
|
|
515
657
|
else: # nmslib
|
|
516
|
-
|
|
517
|
-
'euclidean': 'l2',
|
|
518
|
-
'cosine': 'cosinesimil',
|
|
519
|
-
}
|
|
658
|
+
space = _NMSLIB_METRIC_MAP[self.distance_metric]
|
|
520
659
|
self.index_ = self.nmslib.init(
|
|
521
660
|
method='hnsw',
|
|
522
|
-
space=
|
|
661
|
+
space=space,
|
|
523
662
|
data_type=self.nmslib.DataType.DENSE_VECTOR
|
|
524
663
|
)
|
|
525
664
|
self.index_.addDataPointBatch(X)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|