deskit 1.1.1__tar.gz → 1.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {deskit-1.1.1/src/deskit.egg-info → deskit-1.2.0}/PKG-INFO +1 -1
- {deskit-1.1.1 → deskit-1.2.0}/pyproject.toml +1 -1
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/_config.py +14 -3
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/des/dewsi.py +5 -2
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/des/dewsiv.py +5 -2
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/des/dewst.py +4 -1
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/des/dewsu.py +4 -1
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/des/dewsv.py +5 -2
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/des/knorae.py +4 -1
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/des/knoraiu.py +4 -1
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/des/knorau.py +4 -1
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/des/lwsei.py +4 -1
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/des/lwseu.py +4 -1
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/des/ola.py +4 -1
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/neighbors.py +248 -32
- {deskit-1.1.1 → deskit-1.2.0/src/deskit.egg-info}/PKG-INFO +1 -1
- {deskit-1.1.1 → deskit-1.2.0}/LICENSE +0 -0
- {deskit-1.1.1 → deskit-1.2.0}/README.md +0 -0
- {deskit-1.1.1 → deskit-1.2.0}/setup.cfg +0 -0
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/__init__.py +0 -0
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/base/__init__.py +0 -0
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/base/base.py +0 -0
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/base/knnbase.py +0 -0
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/base/predictbase.py +0 -0
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/des/__init__.py +0 -0
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/metrics.py +0 -0
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/router.py +0 -0
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit/utils.py +0 -0
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit.egg-info/SOURCES.txt +0 -0
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit.egg-info/dependency_links.txt +0 -0
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit.egg-info/requires.txt +0 -0
- {deskit-1.1.1 → deskit-1.2.0}/src/deskit.egg-info/top_level.txt +0 -0
|
@@ -88,7 +88,7 @@ def resolve_metric(metric):
|
|
|
88
88
|
# Neighbor finder construction
|
|
89
89
|
# ---------------------------------------------------------------------------
|
|
90
90
|
|
|
91
|
-
def make_finder(preset, k, finder=None, **kwargs):
|
|
91
|
+
def make_finder(preset, k, finder=None, distance_metric='euclidean', **kwargs):
|
|
92
92
|
"""
|
|
93
93
|
Create a NeighborFinder from a preset name or custom finder string.
|
|
94
94
|
|
|
@@ -100,6 +100,17 @@ def make_finder(preset, k, finder=None, **kwargs):
|
|
|
100
100
|
Number of neighbors.
|
|
101
101
|
finder : str, optional
|
|
102
102
|
Required when preset='custom'. One of 'knn', 'faiss', 'annoy', 'hnsw'.
|
|
103
|
+
distance_metric : str
|
|
104
|
+
Distance function to use. Default: 'euclidean'. See
|
|
105
|
+
neighbors.list_distance_metrics() for all options and per-backend
|
|
106
|
+
availability.
|
|
107
|
+
|
|
108
|
+
Quick guide:
|
|
109
|
+
'euclidean' — Best default for most tabular data.
|
|
110
|
+
'manhattan' — More robust to outliers; good for moderate-to-high dims.
|
|
111
|
+
'chebyshev' — Max abs diff; use preset='exact' (KNN only).
|
|
112
|
+
'minkowski' — L1/L2 generalisation; use preset='exact' (KNN only).
|
|
113
|
+
'cosine' — Direction-based; ideal for embeddings (NLP, vision).
|
|
103
114
|
**kwargs
|
|
104
115
|
Forwarded to the finder constructor (e.g. index_type, n_probes).
|
|
105
116
|
"""
|
|
@@ -107,7 +118,7 @@ def make_finder(preset, k, finder=None, **kwargs):
|
|
|
107
118
|
if finder is None:
|
|
108
119
|
raise ValueError("Must specify 'finder' when using preset='custom'.")
|
|
109
120
|
finder_type = finder.lower()
|
|
110
|
-
finder_kwargs = {'k': k, **kwargs}
|
|
121
|
+
finder_kwargs = {'k': k, 'distance_metric': distance_metric, **kwargs}
|
|
111
122
|
else:
|
|
112
123
|
if preset not in SPEED_PRESETS:
|
|
113
124
|
raise ValueError(
|
|
@@ -117,7 +128,7 @@ def make_finder(preset, k, finder=None, **kwargs):
|
|
|
117
128
|
)
|
|
118
129
|
config = SPEED_PRESETS[preset]
|
|
119
130
|
finder_type = config['finder']
|
|
120
|
-
finder_kwargs = {**config['kwargs'], 'k': k, **kwargs}
|
|
131
|
+
finder_kwargs = {**config['kwargs'], 'k': k, 'distance_metric': distance_metric, **kwargs}
|
|
121
132
|
|
|
122
133
|
if finder_type == 'knn':
|
|
123
134
|
from deskit.neighbors import KNNNeighborFinder
|
|
@@ -35,12 +35,15 @@ class DEWSI(KNNBase):
|
|
|
35
35
|
(min-metrics) and 1.0 for classification (max-metrics) at predict time.
|
|
36
36
|
preset : str
|
|
37
37
|
Neighbor search preset. Default: 'balanced'. See list_presets().
|
|
38
|
+
distance_metric : str
|
|
39
|
+
Distance function to use for neighbor search. Default: 'euclidean'. See
|
|
40
|
+
neighbors.list_distance_metrics() for all options and per-backend availability.
|
|
38
41
|
"""
|
|
39
42
|
|
|
40
43
|
def __init__(self, task, metric='mae', mode='min', k=10,
|
|
41
|
-
threshold=0.5, temperature=None, preset='balanced', **kwargs):
|
|
44
|
+
threshold=0.5, temperature=None, preset='balanced', distance_metric='euclidean', **kwargs):
|
|
42
45
|
metric_name, metric_fn = resolve_metric(metric)
|
|
43
|
-
finder = make_finder(preset, k, **kwargs)
|
|
46
|
+
finder = make_finder(preset, k, distance_metric=distance_metric, **kwargs)
|
|
44
47
|
super().__init__(metric=metric_fn, mode=mode, neighbor_finder=finder, task=task)
|
|
45
48
|
self.task = task
|
|
46
49
|
self.threshold = threshold
|
|
@@ -37,12 +37,15 @@ class DEWSIV(KNNBase):
|
|
|
37
37
|
Defaults to 0.5 for min-metrics, 1.0 otherwise.
|
|
38
38
|
preset : str
|
|
39
39
|
Neighbour search preset. Default: 'balanced'. See list_presets().
|
|
40
|
+
distance_metric : str
|
|
41
|
+
Distance function to use for neighbor search. Default: 'euclidean'. See
|
|
42
|
+
neighbors.list_distance_metrics() for all options and per-backend availability.
|
|
40
43
|
"""
|
|
41
44
|
|
|
42
45
|
def __init__(self, task, metric='mae', mode='min', k=10,
|
|
43
|
-
threshold=0.5, temperature=None, preset='balanced', **kwargs):
|
|
46
|
+
threshold=0.5, temperature=None, preset='balanced', distance_metric='euclidean', **kwargs):
|
|
44
47
|
metric_name, metric_fn = resolve_metric(metric)
|
|
45
|
-
finder = make_finder(preset, k, **kwargs)
|
|
48
|
+
finder = make_finder(preset, k, distance_metric=distance_metric, **kwargs)
|
|
46
49
|
|
|
47
50
|
self._use_signed = metric_name in _SIGNED_METRICS
|
|
48
51
|
self._metric_name = metric_name
|
|
@@ -40,11 +40,14 @@ class DEWST(KNNBase):
|
|
|
40
40
|
the sample falls back to DEWS-I scoring for that model. Default: 0.7.
|
|
41
41
|
preset : str
|
|
42
42
|
Neighbour search preset. Default: 'balanced'. See list_presets().
|
|
43
|
+
distance_metric : str
|
|
44
|
+
Distance function to use for neighbor search. Default: 'euclidean'. See
|
|
45
|
+
neighbors.list_distance_metrics() for all options and per-backend availability.
|
|
43
46
|
"""
|
|
44
47
|
|
|
45
48
|
def __init__(self, task, metric='mae', mode='min', k=10,
|
|
46
49
|
threshold=0.5, temperature=None, r2_threshold=0.7,
|
|
47
|
-
preset='balanced', **kwargs):
|
|
50
|
+
preset='balanced', distance_metric='euclidean', **kwargs):
|
|
48
51
|
metric_name, metric_fn = resolve_metric(metric)
|
|
49
52
|
finder = make_finder(preset, k, **kwargs)
|
|
50
53
|
|
|
@@ -31,10 +31,13 @@ class DEWSU(KNNBase):
|
|
|
31
31
|
(min-metrics) and 1.0 for classification (max-metrics) at predict time.
|
|
32
32
|
preset : str
|
|
33
33
|
Neighbor search preset. Default: 'balanced'. See list_presets().
|
|
34
|
+
distance_metric : str
|
|
35
|
+
Distance function to use for neighbor search. Default: 'euclidean'. See
|
|
36
|
+
neighbors.list_distance_metrics() for all options and per-backend availability.
|
|
34
37
|
"""
|
|
35
38
|
|
|
36
39
|
def __init__(self, task, metric='mae', mode='min', k=10,
|
|
37
|
-
threshold=0.5, temperature=None, preset='balanced', **kwargs):
|
|
40
|
+
threshold=0.5, temperature=None, preset='balanced', distance_metric='euclidean', **kwargs):
|
|
38
41
|
metric_name, metric_fn = resolve_metric(metric)
|
|
39
42
|
finder = make_finder(preset, k, **kwargs)
|
|
40
43
|
super().__init__(metric=metric_fn, mode=mode, neighbor_finder=finder, task=task)
|
|
@@ -37,12 +37,15 @@ class DEWSV(KNNBase):
|
|
|
37
37
|
Defaults to 0.5 for min-metrics, 1.0 otherwise.
|
|
38
38
|
preset : str
|
|
39
39
|
Neighbour search preset. Default: 'balanced'. See list_presets().
|
|
40
|
+
distance_metric : str
|
|
41
|
+
Distance function to use for neighbor search. Default: 'euclidean'. See
|
|
42
|
+
neighbors.list_distance_metrics() for all options and per-backend availability.
|
|
40
43
|
"""
|
|
41
44
|
|
|
42
45
|
def __init__(self, task, metric='mae', mode='min', k=10,
|
|
43
|
-
threshold=0.5, temperature=None, preset='balanced', **kwargs):
|
|
46
|
+
threshold=0.5, temperature=None, preset='balanced', distance_metric='euclidean', **kwargs):
|
|
44
47
|
metric_name, metric_fn = resolve_metric(metric)
|
|
45
|
-
finder = make_finder(preset, k, **kwargs)
|
|
48
|
+
finder = make_finder(preset, k, distance_metric=distance_metric, **kwargs)
|
|
46
49
|
|
|
47
50
|
self._use_signed = metric_name in _SIGNED_METRICS
|
|
48
51
|
self._metric_name = metric_name
|
|
@@ -26,10 +26,13 @@ class KNORAE(KNNBase):
|
|
|
26
26
|
Regression: use 1.0.
|
|
27
27
|
preset : str
|
|
28
28
|
Neighbor search preset. Default: 'balanced'. See list_presets().
|
|
29
|
+
distance_metric : str
|
|
30
|
+
Distance function to use for neighbor search. Default: 'euclidean'. See
|
|
31
|
+
neighbors.list_distance_metrics() for all options and per-backend availability.
|
|
29
32
|
"""
|
|
30
33
|
|
|
31
34
|
def __init__(self, task, metric='mae', mode='min', k=10,
|
|
32
|
-
threshold=0.5, preset='balanced', **kwargs):
|
|
35
|
+
threshold=0.5, preset='balanced', distance_metric='euclidean', **kwargs):
|
|
33
36
|
metric_name, metric_fn = resolve_metric(metric)
|
|
34
37
|
finder = make_finder(preset, k, **kwargs)
|
|
35
38
|
super().__init__(metric=metric_fn, mode=mode, neighbor_finder=finder, task=task)
|
|
@@ -27,10 +27,13 @@ class KNORAIU(KNNBase):
|
|
|
27
27
|
Regression: use 1.0.
|
|
28
28
|
preset : str
|
|
29
29
|
Neighbor search preset. Default: 'balanced'. See list_presets().
|
|
30
|
+
distance_metric : str
|
|
31
|
+
Distance function to use for neighbor search. Default: 'euclidean'. See
|
|
32
|
+
neighbors.list_distance_metrics() for all options and per-backend availability.
|
|
30
33
|
"""
|
|
31
34
|
|
|
32
35
|
def __init__(self, task, metric='mae', mode='min', k=10,
|
|
33
|
-
threshold=0.5, preset='balanced', **kwargs):
|
|
36
|
+
threshold=0.5, preset='balanced', distance_metric='euclidean', **kwargs):
|
|
34
37
|
metric_name, metric_fn = resolve_metric(metric)
|
|
35
38
|
finder = make_finder(preset, k, **kwargs)
|
|
36
39
|
super().__init__(metric=metric_fn, mode=mode, neighbor_finder=finder, task=task)
|
|
@@ -27,10 +27,13 @@ class KNORAU(KNNBase):
|
|
|
27
27
|
Regression: use 1.0.
|
|
28
28
|
preset : str
|
|
29
29
|
Neighbor search preset. Default: 'balanced'. See list_presets().
|
|
30
|
+
distance_metric : str
|
|
31
|
+
Distance function to use for neighbor search. Default: 'euclidean'. See
|
|
32
|
+
neighbors.list_distance_metrics() for all options and per-backend availability.
|
|
30
33
|
"""
|
|
31
34
|
|
|
32
35
|
def __init__(self, task, metric='mae', mode='min', k=10,
|
|
33
|
-
threshold=0.5, preset='balanced', **kwargs):
|
|
36
|
+
threshold=0.5, preset='balanced', distance_metric='euclidean', **kwargs):
|
|
34
37
|
metric_name, metric_fn = resolve_metric(metric)
|
|
35
38
|
finder = make_finder(preset, k, **kwargs)
|
|
36
39
|
super().__init__(metric=metric_fn, mode=mode, neighbor_finder=finder, task=task)
|
|
@@ -19,9 +19,12 @@ class LWSEI(PredictBase):
|
|
|
19
19
|
Neighbourhood size. Default: 10.
|
|
20
20
|
preset : str
|
|
21
21
|
Neighbour search preset. Default: 'balanced'. See list_presets().
|
|
22
|
+
distance_metric : str
|
|
23
|
+
Distance function to use for neighbor search. Default: 'euclidean'. See
|
|
24
|
+
neighbors.list_distance_metrics() for all options and per-backend availability.
|
|
22
25
|
"""
|
|
23
26
|
|
|
24
|
-
def __init__(self, task, k=10, preset='balanced', **kwargs):
|
|
27
|
+
def __init__(self, task, k=10, preset='balanced', distance_metric='euclidean', **kwargs):
|
|
25
28
|
self.task = task
|
|
26
29
|
self.k = k
|
|
27
30
|
self._finder = make_finder(preset, k, **kwargs)
|
|
@@ -19,9 +19,12 @@ class LWSEU(PredictBase):
|
|
|
19
19
|
Neighbourhood size. Default: 10.
|
|
20
20
|
preset : str
|
|
21
21
|
Neighbour search preset. Default: 'balanced'. See list_presets().
|
|
22
|
+
distance_metric : str
|
|
23
|
+
Distance function to use for neighbor search. Default: 'euclidean'. See
|
|
24
|
+
neighbors.list_distance_metrics() for all options and per-backend availability.
|
|
22
25
|
"""
|
|
23
26
|
|
|
24
|
-
def __init__(self, task, k=10, preset='balanced', **kwargs):
|
|
27
|
+
def __init__(self, task, k=10, preset='balanced', distance_metric='euclidean', **kwargs):
|
|
25
28
|
self.task = task
|
|
26
29
|
self.k = k
|
|
27
30
|
self._finder = make_finder(preset, k, **kwargs)
|
|
@@ -22,10 +22,13 @@ class OLA(KNNBase):
|
|
|
22
22
|
Neighborhood size. Default: 10.
|
|
23
23
|
preset : str
|
|
24
24
|
Neighbor search preset. Default: 'balanced'. See list_presets().
|
|
25
|
+
distance_metric : str
|
|
26
|
+
Distance function to use for neighbor search. Default: 'euclidean'. See
|
|
27
|
+
neighbors.list_distance_metrics() for all options and per-backend availability.
|
|
25
28
|
"""
|
|
26
29
|
|
|
27
30
|
def __init__(self, task, metric='mae', mode='min', k=10,
|
|
28
|
-
preset='balanced', threshold=None, **kwargs):
|
|
31
|
+
preset='balanced', threshold=None, distance_metric='euclidean', **kwargs):
|
|
29
32
|
metric_name, metric_fn = resolve_metric(metric)
|
|
30
33
|
finder = make_finder(preset, k, **kwargs)
|
|
31
34
|
super().__init__(metric=metric_fn, mode=mode, neighbor_finder=finder, task=task)
|
|
@@ -5,6 +5,82 @@ import warnings
|
|
|
5
5
|
_FAISS_MIN_SAMPLES_PER_CELL = 40
|
|
6
6
|
|
|
7
7
|
|
|
8
|
+
# ---------------------------------------------------------------------------
|
|
9
|
+
# Distance metric registry
|
|
10
|
+
# ---------------------------------------------------------------------------
|
|
11
|
+
|
|
12
|
+
# Metrics supported by each backend.
|
|
13
|
+
# 'euclidean' is the universal default and always available.
|
|
14
|
+
#
|
|
15
|
+
# Choosing a distance metric:
|
|
16
|
+
# euclidean – The standard L2 norm. Best default for most tabular data.
|
|
17
|
+
# manhattan – L1 norm (sum of absolute differences). More robust to outliers
|
|
18
|
+
# and tends to work better in moderately high-dimensional spaces
|
|
19
|
+
# because it doesn't square large differences.
|
|
20
|
+
# chebyshev – L∞ norm (maximum absolute difference across features). Useful
|
|
21
|
+
# when a single feature dominating the distance is acceptable;
|
|
22
|
+
# common in game-grid / chess-style distance problems.
|
|
23
|
+
# minkowski – Generalisation of L1/L2 (controlled by p). p=1 → manhattan,
|
|
24
|
+
# p=2 → euclidean. Use when you want to tune between them.
|
|
25
|
+
# cosine – Angle between vectors, ignoring magnitude. Excellent for
|
|
26
|
+
# embeddings (text, image, audio) where direction matters more
|
|
27
|
+
# than raw scale.
|
|
28
|
+
|
|
29
|
+
# Core metrics exposed by the public API. Not every backend supports all of them natively; see the per-backend maps below.
|
|
30
|
+
_UNIVERSAL_METRICS = {'euclidean', 'manhattan', 'chebyshev', 'minkowski', 'cosine'}
|
|
31
|
+
|
|
32
|
+
# Per-backend metric availability.
|
|
33
|
+
# KNN (sklearn) accepts every metric listed here — this is the complete curated list deskit exposes.
|
|
34
|
+
_KNN_METRICS = _UNIVERSAL_METRICS | {'correlation', 'hamming', 'canberra', 'braycurtis'}
|
|
35
|
+
|
|
36
|
+
# FAISS only has built-in L2 and inner-product (cosine via normalization).
|
|
37
|
+
# All others fall back to an exact search delegated to KNNNeighborFinder (sklearn).
|
|
38
|
+
_FAISS_NATIVE_METRICS = {'euclidean', 'cosine'}
|
|
39
|
+
_FAISS_METRICS = _UNIVERSAL_METRICS # remainder handled via sklearn fallback
|
|
40
|
+
|
|
41
|
+
# Annoy metric names (library-specific).
|
|
42
|
+
_ANNOY_METRIC_MAP = {
|
|
43
|
+
'euclidean': 'euclidean',
|
|
44
|
+
'manhattan': 'manhattan',
|
|
45
|
+
'cosine': 'angular',
|
|
46
|
+
'hamming': 'hamming',
|
|
47
|
+
# chebyshev and minkowski are not supported by Annoy; AnnoyNeighborFinder raises ValueError for them.
|
|
48
|
+
}
|
|
49
|
+
_ANNOY_METRICS = set(_ANNOY_METRIC_MAP)
|
|
50
|
+
|
|
51
|
+
# HNSW (hnswlib) space names.
|
|
52
|
+
_HNSW_METRIC_MAP = {
|
|
53
|
+
'euclidean': 'l2',
|
|
54
|
+
'cosine': 'cosine',
|
|
55
|
+
# Other metrics are not natively supported in hnswlib; HNSWNeighborFinder raises ValueError for them.
|
|
56
|
+
}
|
|
57
|
+
_HNSW_METRICS = _UNIVERSAL_METRICS # partial — see fit() for fallback note
|
|
58
|
+
|
|
59
|
+
# All metrics callable from the public API.
|
|
60
|
+
ALL_METRICS = _KNN_METRICS
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def list_distance_metrics():
|
|
64
|
+
"""Print all available distance metrics with per-backend availability."""
|
|
65
|
+
print("\nAvailable Distance Metrics:")
|
|
66
|
+
print("=" * 70)
|
|
67
|
+
rows = [
|
|
68
|
+
("euclidean", "Default. L2 norm. Best for most tabular data.", "all"),
|
|
69
|
+
("manhattan", "L1 norm. More robust to outliers; good for high-dim data.", "all"),
|
|
70
|
+
("chebyshev", "L∞ norm. Max absolute diff across features.", "KNN only (exact preset)"),
|
|
71
|
+
("minkowski", "Generalises L1/L2 via p-param. Set minkowski_p=<float>.", "KNN only (exact preset)"),
|
|
72
|
+
("cosine", "Angle between vectors. Ideal for embeddings (NLP, vision).", "all"),
|
|
73
|
+
("correlation","Pearson correlation distance. Good for time series.", "KNN only (exact preset)"),
|
|
74
|
+
("hamming", "Fraction of differing components. For binary/categorical data.", "KNN, Annoy"),
|
|
75
|
+
("canberra", "Weighted L1. Sensitive to small values near zero.", "KNN only (exact preset)"),
|
|
76
|
+
("braycurtis", "Normalised L1 bounded to [0,1]. Ecological data.", "KNN only (exact preset)"),
|
|
77
|
+
]
|
|
78
|
+
for name, desc, backends in rows:
|
|
79
|
+
print(f"\n {name:<14} {desc}")
|
|
80
|
+
print(f" {'':14} Backends: {backends}")
|
|
81
|
+
print("\n" + "=" * 70)
|
|
82
|
+
|
|
83
|
+
|
|
8
84
|
class NeighborFinder:
|
|
9
85
|
"""Base class for neighbor search backends."""
|
|
10
86
|
|
|
@@ -16,12 +92,37 @@ class NeighborFinder:
|
|
|
16
92
|
|
|
17
93
|
|
|
18
94
|
class KNNNeighborFinder(NeighborFinder):
|
|
19
|
-
"""
|
|
20
|
-
|
|
21
|
-
|
|
95
|
+
"""
|
|
96
|
+
Exact nearest neighbors via sklearn NearestNeighbors.
|
|
97
|
+
|
|
98
|
+
Supports all distance metrics in deskit (euclidean, manhattan, chebyshev,
|
|
99
|
+
minkowski, cosine, correlation, hamming, canberra, braycurtis).
|
|
100
|
+
"""
|
|
101
|
+
|
|
102
|
+
def __init__(self, k=10, distance_metric='euclidean', minkowski_p=2, **kwargs):
|
|
103
|
+
"""
|
|
104
|
+
Parameters
|
|
105
|
+
----------
|
|
106
|
+
k : int
|
|
107
|
+
Number of neighbors.
|
|
108
|
+
distance_metric : str
|
|
109
|
+
Distance function to use. One of the metrics returned by
|
|
110
|
+
list_distance_metrics(). Default: 'euclidean'.
|
|
111
|
+
minkowski_p : float
|
|
112
|
+
The p-parameter for the Minkowski metric (p=1 → manhattan,
|
|
113
|
+
p=2 → euclidean). Ignored for all other metrics.
|
|
114
|
+
"""
|
|
22
115
|
if k <= 0:
|
|
23
116
|
raise ValueError(f"k must be positive, got k={k}")
|
|
117
|
+
metric = distance_metric.lower()
|
|
118
|
+
if metric not in _KNN_METRICS:
|
|
119
|
+
raise ValueError(
|
|
120
|
+
f"distance_metric='{distance_metric}' is not supported by KNNNeighborFinder. "
|
|
121
|
+
f"Available: {sorted(_KNN_METRICS)}."
|
|
122
|
+
)
|
|
24
123
|
self.n_neighbors = k
|
|
124
|
+
self.distance_metric = metric
|
|
125
|
+
self.minkowski_p = minkowski_p
|
|
25
126
|
self.kwargs = kwargs
|
|
26
127
|
self.model = None
|
|
27
128
|
|
|
@@ -33,7 +134,15 @@ class KNNNeighborFinder(NeighborFinder):
|
|
|
33
134
|
f"Cannot find {self.n_neighbors} neighbors in a dataset with only "
|
|
34
135
|
f"{X.shape[0]} samples. Reduce k to at most {X.shape[0]}."
|
|
35
136
|
)
|
|
36
|
-
|
|
137
|
+
metric_kwargs = {}
|
|
138
|
+
if self.distance_metric == 'minkowski':
|
|
139
|
+
metric_kwargs['p'] = self.minkowski_p
|
|
140
|
+
self.model = NearestNeighbors(
|
|
141
|
+
n_neighbors=self.n_neighbors,
|
|
142
|
+
metric=self.distance_metric,
|
|
143
|
+
metric_params=metric_kwargs if metric_kwargs else None,
|
|
144
|
+
**self.kwargs,
|
|
145
|
+
)
|
|
37
146
|
self.model.fit(X)
|
|
38
147
|
return self
|
|
39
148
|
|
|
@@ -48,12 +157,34 @@ class KNNNeighborFinder(NeighborFinder):
|
|
|
48
157
|
|
|
49
158
|
|
|
50
159
|
class FaissNeighborFinder(NeighborFinder):
|
|
51
|
-
"""
|
|
160
|
+
"""
|
|
161
|
+
Approximate nearest neighbors via FAISS (flat, IVF, or HNSW index).
|
|
162
|
+
|
|
163
|
+
Natively supports 'euclidean' and 'cosine'. All other metrics in
|
|
164
|
+
_UNIVERSAL_METRICS fall back to a sklearn-based exact search with a
|
|
165
|
+
warning, so you can still use them without switching presets.
|
|
166
|
+
"""
|
|
52
167
|
|
|
53
168
|
def __init__(self, k=10, index_type='flat', n_cells=None, n_probes=50,
|
|
54
|
-
hnsw_M=32, hnsw_efConstruction=400, hnsw_efSearch=200
|
|
169
|
+
hnsw_M=32, hnsw_efConstruction=400, hnsw_efSearch=200,
|
|
170
|
+
distance_metric='euclidean'):
|
|
171
|
+
"""
|
|
172
|
+
Parameters
|
|
173
|
+
----------
|
|
174
|
+
distance_metric : str
|
|
175
|
+
'euclidean' (default) or 'cosine'. Other metrics fall back to
|
|
176
|
+
exact sklearn search with a warning — use preset='exact' to avoid
|
|
177
|
+
the overhead.
|
|
178
|
+
"""
|
|
55
179
|
if k <= 0:
|
|
56
180
|
raise ValueError(f"k must be positive, got k={k}")
|
|
181
|
+
metric = distance_metric.lower()
|
|
182
|
+
if metric not in _FAISS_METRICS:
|
|
183
|
+
raise ValueError(
|
|
184
|
+
f"distance_metric='{distance_metric}' is not supported by FaissNeighborFinder. "
|
|
185
|
+
f"Available: {sorted(_FAISS_METRICS)}. "
|
|
186
|
+
f"For other metrics use preset='exact' (KNNNeighborFinder)."
|
|
187
|
+
)
|
|
57
188
|
self.n_neighbors = k
|
|
58
189
|
self.index_type = index_type.lower()
|
|
59
190
|
self.n_cells = n_cells
|
|
@@ -61,7 +192,9 @@ class FaissNeighborFinder(NeighborFinder):
|
|
|
61
192
|
self.hnsw_M = hnsw_M
|
|
62
193
|
self.hnsw_efConstruction = hnsw_efConstruction
|
|
63
194
|
self.hnsw_efSearch = hnsw_efSearch
|
|
195
|
+
self.distance_metric = metric
|
|
64
196
|
self.index_ = None
|
|
197
|
+
self._fallback_finder = None # used for non-native metrics
|
|
65
198
|
self._check_availability()
|
|
66
199
|
|
|
67
200
|
def _check_availability(self):
|
|
@@ -71,6 +204,13 @@ class FaissNeighborFinder(NeighborFinder):
|
|
|
71
204
|
except ImportError:
|
|
72
205
|
raise ImportError("FAISS not found. Install with: pip install faiss-cpu")
|
|
73
206
|
|
|
207
|
+
@staticmethod
|
|
208
|
+
def _l2_normalize(X):
|
|
209
|
+
"""Return a row-wise L2-normalised copy of X (for cosine similarity); all-zero rows are left unchanged."""
|
|
210
|
+
norms = np.linalg.norm(X, axis=1, keepdims=True)
|
|
211
|
+
norms = np.where(norms == 0, 1.0, norms)
|
|
212
|
+
return X / norms.astype(np.float32)
|
|
213
|
+
|
|
74
214
|
def fit(self, X):
|
|
75
215
|
X = np.atleast_2d(X).astype(np.float32)
|
|
76
216
|
n_samples, dim = X.shape
|
|
@@ -81,6 +221,24 @@ class FaissNeighborFinder(NeighborFinder):
|
|
|
81
221
|
f"{n_samples} samples. Reduce k to at most {n_samples}."
|
|
82
222
|
)
|
|
83
223
|
|
|
224
|
+
# Non-native metrics: delegate entirely to KNNNeighborFinder.
|
|
225
|
+
if self.distance_metric not in _FAISS_NATIVE_METRICS:
|
|
226
|
+
warnings.warn(
|
|
227
|
+
f"distance_metric='{self.distance_metric}' is not natively supported by "
|
|
228
|
+
f"FAISS. Falling back to exact sklearn KNN for this metric. "
|
|
229
|
+
f"Use preset='exact' to avoid this overhead.",
|
|
230
|
+
UserWarning,
|
|
231
|
+
)
|
|
232
|
+
self._fallback_finder = KNNNeighborFinder(
|
|
233
|
+
k=self.n_neighbors, distance_metric=self.distance_metric
|
|
234
|
+
)
|
|
235
|
+
self._fallback_finder.fit(X)
|
|
236
|
+
return self
|
|
237
|
+
|
|
238
|
+
# Cosine similarity: normalise all vectors, then use inner-product index.
|
|
239
|
+
if self.distance_metric == 'cosine':
|
|
240
|
+
X = self._l2_normalize(X)
|
|
241
|
+
|
|
84
242
|
if self.index_type == 'flat':
|
|
85
243
|
if dim <= 2:
|
|
86
244
|
warnings.warn(
|
|
@@ -88,14 +246,16 @@ class FaissNeighborFinder(NeighborFinder):
|
|
|
88
246
|
f"Consider KNNNeighborFinder for low-dimensional data.",
|
|
89
247
|
UserWarning
|
|
90
248
|
)
|
|
91
|
-
self.
|
|
249
|
+
if self.distance_metric == 'cosine':
|
|
250
|
+
self.index_ = self.faiss.IndexFlatIP(dim) # inner product on normalised vecs
|
|
251
|
+
else:
|
|
252
|
+
self.index_ = self.faiss.IndexFlatL2(dim)
|
|
92
253
|
self.index_.add(X)
|
|
93
254
|
|
|
94
255
|
elif self.index_type == 'ivf':
|
|
95
256
|
if self.n_cells is None:
|
|
96
257
|
self.n_cells = min(int(np.sqrt(n_samples)), 4096)
|
|
97
258
|
|
|
98
|
-
# Reduce n_cells if dataset is too small
|
|
99
259
|
min_required = self.n_cells * _FAISS_MIN_SAMPLES_PER_CELL
|
|
100
260
|
if n_samples < min_required:
|
|
101
261
|
safe_cells = max(1, n_samples // _FAISS_MIN_SAMPLES_PER_CELL)
|
|
@@ -121,8 +281,14 @@ class FaissNeighborFinder(NeighborFinder):
|
|
|
121
281
|
UserWarning
|
|
122
282
|
)
|
|
123
283
|
|
|
124
|
-
|
|
125
|
-
|
|
284
|
+
if self.distance_metric == 'cosine':
|
|
285
|
+
quantizer = self.faiss.IndexFlatIP(dim)
|
|
286
|
+
self.index_ = self.faiss.IndexIVFFlat(
|
|
287
|
+
quantizer, dim, self.n_cells, self.faiss.METRIC_INNER_PRODUCT
|
|
288
|
+
)
|
|
289
|
+
else:
|
|
290
|
+
quantizer = self.faiss.IndexFlatL2(dim)
|
|
291
|
+
self.index_ = self.faiss.IndexIVFFlat(quantizer, dim, self.n_cells)
|
|
126
292
|
self.index_.train(X)
|
|
127
293
|
self.index_.add(X)
|
|
128
294
|
self.index_.nprobe = effective_probes
|
|
@@ -151,21 +317,53 @@ class FaissNeighborFinder(NeighborFinder):
|
|
|
151
317
|
X = np.atleast_2d(X).astype(np.float32)
|
|
152
318
|
if X.shape[0] == 0:
|
|
153
319
|
return np.empty((0, k), dtype=np.float32), np.empty((0, k), dtype=np.int64)
|
|
154
|
-
distances, indices = self.index_.search(X, k)
|
|
155
|
-
# FAISS returns squared L2; clamp to 0 before sqrt
|
|
156
|
-
return np.sqrt(np.maximum(distances, 0)), indices
|
|
157
320
|
|
|
321
|
+
# Non-native metric fallback.
|
|
322
|
+
if self._fallback_finder is not None:
|
|
323
|
+
return self._fallback_finder.kneighbors(X, k=k)
|
|
324
|
+
|
|
325
|
+
if self.distance_metric == 'cosine':
|
|
326
|
+
X = self._l2_normalize(X)
|
|
327
|
+
scores, indices = self.index_.search(X, k)
|
|
328
|
+
# Inner product on normalised vectors: similarity ∈ [-1, 1].
|
|
329
|
+
# Convert to a proper distance (0 = identical, 2 = opposite).
|
|
330
|
+
distances = 1.0 - scores
|
|
331
|
+
else:
|
|
332
|
+
distances, indices = self.index_.search(X, k)
|
|
333
|
+
# FAISS returns squared L2; clamp to 0 before sqrt.
|
|
334
|
+
distances = np.sqrt(np.maximum(distances, 0))
|
|
335
|
+
|
|
336
|
+
return distances.astype(np.float32), indices
|
|
158
337
|
|
|
159
|
-
class AnnoyNeighborFinder(NeighborFinder):
|
|
160
|
-
"""Approximate nearest neighbors via Annoy."""
|
|
161
338
|
|
|
162
|
-
|
|
339
|
+
class AnnoyNeighborFinder(NeighborFinder):
|
|
340
|
+
"""
|
|
341
|
+
Approximate nearest neighbors via Annoy.
|
|
342
|
+
|
|
343
|
+
Supports: euclidean, manhattan, cosine, hamming.
|
|
344
|
+
chebyshev and minkowski are not available in Annoy — use preset='exact' for those.
|
|
345
|
+
"""
|
|
346
|
+
|
|
347
|
+
def __init__(self, k=10, n_trees=100, distance_metric='euclidean', search_k=-1):
|
|
348
|
+
"""
|
|
349
|
+
Parameters
|
|
350
|
+
----------
|
|
351
|
+
distance_metric : str
|
|
352
|
+
One of 'euclidean', 'manhattan', 'cosine', 'hamming'. Default: 'euclidean'.
|
|
353
|
+
"""
|
|
163
354
|
if k <= 0:
|
|
164
355
|
raise ValueError(f"k must be positive, got k={k}")
|
|
356
|
+
metric = distance_metric.lower()
|
|
357
|
+
if metric not in _ANNOY_METRICS:
|
|
358
|
+
raise ValueError(
|
|
359
|
+
f"distance_metric='{distance_metric}' is not supported by AnnoyNeighborFinder. "
|
|
360
|
+
f"Available: {sorted(_ANNOY_METRICS)}. "
|
|
361
|
+
f"For chebyshev or minkowski use preset='exact' (KNNNeighborFinder)."
|
|
362
|
+
)
|
|
165
363
|
self.k = k
|
|
166
364
|
self.n_trees = n_trees
|
|
167
|
-
self.
|
|
168
|
-
# Annoy's recommended default
|
|
365
|
+
self.distance_metric = metric
|
|
366
|
+
# Annoy's recommended default
|
|
169
367
|
self.search_k = n_trees * k if search_k == -1 else search_k
|
|
170
368
|
self.index_ = None
|
|
171
369
|
self.n_samples_ = None
|
|
@@ -195,12 +393,7 @@ class AnnoyNeighborFinder(NeighborFinder):
|
|
|
195
393
|
UserWarning
|
|
196
394
|
)
|
|
197
395
|
|
|
198
|
-
|
|
199
|
-
'euclidean': 'euclidean', 'l2': 'euclidean',
|
|
200
|
-
'angular': 'angular', 'cosine': 'angular',
|
|
201
|
-
'manhattan': 'manhattan', 'hamming': 'hamming', 'dot': 'dot',
|
|
202
|
-
}
|
|
203
|
-
self.index_ = self.AnnoyIndex(dim, metric_map.get(self.metric.lower(), 'euclidean'))
|
|
396
|
+
self.index_ = self.AnnoyIndex(dim, _ANNOY_METRIC_MAP[self.distance_metric])
|
|
204
397
|
for i, vec in enumerate(X):
|
|
205
398
|
self.index_.add_item(i, vec.tolist())
|
|
206
399
|
self.index_.build(self.n_trees)
|
|
@@ -244,14 +437,33 @@ class AnnoyNeighborFinder(NeighborFinder):
|
|
|
244
437
|
|
|
245
438
|
|
|
246
439
|
class HNSWNeighborFinder(NeighborFinder):
|
|
247
|
-
"""
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
440
|
+
"""
|
|
441
|
+
Approximate nearest neighbors via HNSW (hnswlib or nmslib backend).
|
|
442
|
+
|
|
443
|
+
Natively supports 'euclidean' and 'cosine'. Manhattan, chebyshev, and
|
|
444
|
+
minkowski are not available in hnswlib/nmslib — use preset='exact' for those.
|
|
445
|
+
"""
|
|
446
|
+
|
|
447
|
+
def __init__(self, k=10, M=32, ef_construction=400,
|
|
448
|
+
ef_search=200, backend='hnswlib', distance_metric='euclidean'):
|
|
449
|
+
"""
|
|
450
|
+
Parameters
|
|
451
|
+
----------
|
|
452
|
+
distance_metric : str
|
|
453
|
+
'euclidean' (default) or 'cosine'. Other metrics are not supported
|
|
454
|
+
natively and will raise an error — use preset='exact' instead.
|
|
455
|
+
"""
|
|
251
456
|
if k <= 0:
|
|
252
457
|
raise ValueError(f"k must be positive, got k={k}")
|
|
458
|
+
metric = distance_metric.lower()
|
|
459
|
+
if metric not in _HNSW_METRIC_MAP:
|
|
460
|
+
raise ValueError(
|
|
461
|
+
f"distance_metric='{distance_metric}' is not natively supported by "
|
|
462
|
+
f"HNSWNeighborFinder. Available: {sorted(_HNSW_METRIC_MAP)}. "
|
|
463
|
+
f"For other metrics use preset='exact' (KNNNeighborFinder)."
|
|
464
|
+
)
|
|
253
465
|
self.n_neighbors = k
|
|
254
|
-
self.
|
|
466
|
+
self.distance_metric = metric
|
|
255
467
|
self.M = M
|
|
256
468
|
self.ef_construction = ef_construction
|
|
257
469
|
self.ef_search = ef_search
|
|
@@ -292,7 +504,8 @@ class HNSWNeighborFinder(NeighborFinder):
|
|
|
292
504
|
)
|
|
293
505
|
|
|
294
506
|
if self.backend == 'hnswlib':
|
|
295
|
-
|
|
507
|
+
space = _HNSW_METRIC_MAP[self.distance_metric]
|
|
508
|
+
self.index_ = self.hnswlib.Index(space=space, dim=dim)
|
|
296
509
|
self.index_.init_index(
|
|
297
510
|
max_elements=n_samples, M=self.M, ef_construction=self.ef_construction
|
|
298
511
|
)
|
|
@@ -300,10 +513,13 @@ class HNSWNeighborFinder(NeighborFinder):
|
|
|
300
513
|
self.index_.add_items(X, np.arange(n_samples))
|
|
301
514
|
|
|
302
515
|
else: # nmslib
|
|
303
|
-
|
|
516
|
+
nmslib_space_map = {
|
|
517
|
+
'euclidean': 'l2',
|
|
518
|
+
'cosine': 'cosinesimil',
|
|
519
|
+
}
|
|
304
520
|
self.index_ = self.nmslib.init(
|
|
305
521
|
method='hnsw',
|
|
306
|
-
space=
|
|
522
|
+
space=nmslib_space_map.get(self.distance_metric, 'l2'),
|
|
307
523
|
data_type=self.nmslib.DataType.DENSE_VECTOR
|
|
308
524
|
)
|
|
309
525
|
self.index_.addDataPointBatch(X)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|