deskit 1.1.0 (tar.gz) → 1.2.0 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. {deskit-1.1.0/src/deskit.egg-info → deskit-1.2.0}/PKG-INFO +1 -1
  2. {deskit-1.1.0 → deskit-1.2.0}/pyproject.toml +1 -1
  3. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/_config.py +14 -3
  4. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/des/dewsi.py +5 -2
  5. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/des/dewsiv.py +5 -2
  6. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/des/dewst.py +4 -1
  7. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/des/dewsu.py +4 -1
  8. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/des/dewsv.py +5 -2
  9. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/des/knorae.py +4 -1
  10. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/des/knoraiu.py +4 -1
  11. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/des/knorau.py +4 -1
  12. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/des/lwsei.py +4 -1
  13. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/des/lwseu.py +4 -1
  14. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/des/ola.py +5 -2
  15. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/neighbors.py +248 -32
  16. {deskit-1.1.0 → deskit-1.2.0/src/deskit.egg-info}/PKG-INFO +1 -1
  17. {deskit-1.1.0 → deskit-1.2.0}/LICENSE +0 -0
  18. {deskit-1.1.0 → deskit-1.2.0}/README.md +0 -0
  19. {deskit-1.1.0 → deskit-1.2.0}/setup.cfg +0 -0
  20. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/__init__.py +0 -0
  21. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/base/__init__.py +0 -0
  22. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/base/base.py +0 -0
  23. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/base/knnbase.py +0 -0
  24. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/base/predictbase.py +0 -0
  25. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/des/__init__.py +0 -0
  26. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/metrics.py +0 -0
  27. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/router.py +0 -0
  28. {deskit-1.1.0 → deskit-1.2.0}/src/deskit/utils.py +0 -0
  29. {deskit-1.1.0 → deskit-1.2.0}/src/deskit.egg-info/SOURCES.txt +0 -0
  30. {deskit-1.1.0 → deskit-1.2.0}/src/deskit.egg-info/dependency_links.txt +0 -0
  31. {deskit-1.1.0 → deskit-1.2.0}/src/deskit.egg-info/requires.txt +0 -0
  32. {deskit-1.1.0 → deskit-1.2.0}/src/deskit.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: deskit
3
- Version: 1.1.0
3
+ Version: 1.2.0
4
4
  Summary: A Python library for Dynamic Ensemble Selection
5
5
  Author: Tikhon Vodyanov
6
6
  License-Expression: MIT
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "deskit"
7
- version = "1.1.0"
7
+ version = "1.2.0"
8
8
  description = "A Python library for Dynamic Ensemble Selection"
9
9
  readme = "README.md"
10
10
  license = "MIT"
@@ -88,7 +88,7 @@ def resolve_metric(metric):
88
88
  # Neighbor finder construction
89
89
  # ---------------------------------------------------------------------------
90
90
 
91
- def make_finder(preset, k, finder=None, **kwargs):
91
+ def make_finder(preset, k, finder=None, distance_metric='euclidean', **kwargs):
92
92
  """
93
93
  Create a NeighborFinder from a preset name or custom finder string.
94
94
 
@@ -100,6 +100,17 @@ def make_finder(preset, k, finder=None, **kwargs):
100
100
  Number of neighbors.
101
101
  finder : str, optional
102
102
  Required when preset='custom'. One of 'knn', 'faiss', 'annoy', 'hnsw'.
103
+ distance_metric : str
104
+ Distance function to use. Default: 'euclidean'. See
105
+ neighbors.list_distance_metrics() for all options and per-backend
106
+ availability.
107
+
108
+ Quick guide:
109
+ 'euclidean' — Best default for most tabular data.
110
+ 'manhattan' — More robust to outliers; good for moderate-to-high dims.
111
+ 'chebyshev' — Max abs diff; use preset='exact' (KNN only).
112
+ 'minkowski' — L1/L2 generalisation; use preset='exact' (KNN only).
113
+ 'cosine' — Direction-based; ideal for embeddings (NLP, vision).
103
114
  **kwargs
104
115
  Forwarded to the finder constructor (e.g. index_type, n_probes).
105
116
  """
@@ -107,7 +118,7 @@ def make_finder(preset, k, finder=None, **kwargs):
107
118
  if finder is None:
108
119
  raise ValueError("Must specify 'finder' when using preset='custom'.")
109
120
  finder_type = finder.lower()
110
- finder_kwargs = {'k': k, **kwargs}
121
+ finder_kwargs = {'k': k, 'distance_metric': distance_metric, **kwargs}
111
122
  else:
112
123
  if preset not in SPEED_PRESETS:
113
124
  raise ValueError(
@@ -117,7 +128,7 @@ def make_finder(preset, k, finder=None, **kwargs):
117
128
  )
118
129
  config = SPEED_PRESETS[preset]
119
130
  finder_type = config['finder']
120
- finder_kwargs = {**config['kwargs'], 'k': k, **kwargs}
131
+ finder_kwargs = {**config['kwargs'], 'k': k, 'distance_metric': distance_metric, **kwargs}
121
132
 
122
133
  if finder_type == 'knn':
123
134
  from deskit.neighbors import KNNNeighborFinder
@@ -35,12 +35,15 @@ class DEWSI(KNNBase):
35
35
  (min-metrics) and 1.0 for classification (max-metrics) at predict time.
36
36
  preset : str
37
37
  Neighbor search preset. Default: 'balanced'. See list_presets().
38
+ distance_metric : str
39
+ Distance function to use for neighbor search. Default: 'euclidean'. See
40
+ neighbors.list_distance_metrics() for all options and per-backend availability.
38
41
  """
39
42
 
40
43
  def __init__(self, task, metric='mae', mode='min', k=10,
41
- threshold=0.5, temperature=None, preset='balanced', **kwargs):
44
+ threshold=0.5, temperature=None, preset='balanced', distance_metric='euclidian', **kwargs):
42
45
  metric_name, metric_fn = resolve_metric(metric)
43
- finder = make_finder(preset, k, **kwargs)
46
+ finder = make_finder(preset, k, distance_metric=distance_metric, **kwargs)
44
47
  super().__init__(metric=metric_fn, mode=mode, neighbor_finder=finder, task=task)
45
48
  self.task = task
46
49
  self.threshold = threshold
@@ -37,12 +37,15 @@ class DEWSIV(KNNBase):
37
37
  Defaults to 0.5 for min-metrics, 1.0 otherwise.
38
38
  preset : str
39
39
  Neighbour search preset. Default: 'balanced'. See list_presets().
40
+ distance_metric : str
41
+ Distance function to use for neighbor search. Default: 'euclidean'. See
42
+ neighbors.list_distance_metrics() for all options and per-backend availability.
40
43
  """
41
44
 
42
45
  def __init__(self, task, metric='mae', mode='min', k=10,
43
- threshold=0.5, temperature=None, preset='balanced', **kwargs):
46
+ threshold=0.5, temperature=None, preset='balanced', distance_metric='euclidian', **kwargs):
44
47
  metric_name, metric_fn = resolve_metric(metric)
45
- finder = make_finder(preset, k, **kwargs)
48
+ finder = make_finder(preset, k, distance_metric=distance_metric, **kwargs)
46
49
 
47
50
  self._use_signed = metric_name in _SIGNED_METRICS
48
51
  self._metric_name = metric_name
@@ -40,11 +40,14 @@ class DEWST(KNNBase):
40
40
  the sample falls back to DEWS-I scoring for that model. Default: 0.7.
41
41
  preset : str
42
42
  Neighbour search preset. Default: 'balanced'. See list_presets().
43
+ distance_metric : str
44
+ Distance function to use for neighbor search. Default: 'euclidean'. See
45
+ neighbors.list_distance_metrics() for all options and per-backend availability.
43
46
  """
44
47
 
45
48
  def __init__(self, task, metric='mae', mode='min', k=10,
46
49
  threshold=0.5, temperature=None, r2_threshold=0.7,
47
- preset='balanced', **kwargs):
50
+ preset='balanced', distance_metric='euclidean', **kwargs):
48
51
  metric_name, metric_fn = resolve_metric(metric)
49
52
  finder = make_finder(preset, k, **kwargs)
50
53
 
@@ -31,10 +31,13 @@ class DEWSU(KNNBase):
31
31
  (min-metrics) and 1.0 for classification (max-metrics) at predict time.
32
32
  preset : str
33
33
  Neighbor search preset. Default: 'balanced'. See list_presets().
34
+ distance_metric : str
35
+ Distance function to use for neighbor search. Default: 'euclidean'. See
36
+ neighbors.list_distance_metrics() for all options and per-backend availability.
34
37
  """
35
38
 
36
39
  def __init__(self, task, metric='mae', mode='min', k=10,
37
- threshold=0.5, temperature=None, preset='balanced', **kwargs):
40
+ threshold=0.5, temperature=None, preset='balanced', distance_metric='euclidean', **kwargs):
38
41
  metric_name, metric_fn = resolve_metric(metric)
39
42
  finder = make_finder(preset, k, **kwargs)
40
43
  super().__init__(metric=metric_fn, mode=mode, neighbor_finder=finder, task=task)
@@ -37,12 +37,15 @@ class DEWSV(KNNBase):
37
37
  Defaults to 0.5 for min-metrics, 1.0 otherwise.
38
38
  preset : str
39
39
  Neighbour search preset. Default: 'balanced'. See list_presets().
40
+ distance_metric : str
41
+ Distance function to use for neighbor search. Default: 'euclidean'. See
42
+ neighbors.list_distance_metrics() for all options and per-backend availability.
40
43
  """
41
44
 
42
45
  def __init__(self, task, metric='mae', mode='min', k=10,
43
- threshold=0.5, temperature=None, preset='balanced', **kwargs):
46
+ threshold=0.5, temperature=None, preset='balanced', distance_metric='euclidean', **kwargs):
44
47
  metric_name, metric_fn = resolve_metric(metric)
45
- finder = make_finder(preset, k, **kwargs)
48
+ finder = make_finder(preset, k, distance_metric=distance_metric, **kwargs)
46
49
 
47
50
  self._use_signed = metric_name in _SIGNED_METRICS
48
51
  self._metric_name = metric_name
@@ -26,10 +26,13 @@ class KNORAE(KNNBase):
26
26
  Regression: use 1.0.
27
27
  preset : str
28
28
  Neighbor search preset. Default: 'balanced'. See list_presets().
29
+ distance_metric : str
30
+ Distance function to use for neighbor search. Default: 'euclidean'. See
31
+ neighbors.list_distance_metrics() for all options and per-backend availability.
29
32
  """
30
33
 
31
34
  def __init__(self, task, metric='mae', mode='min', k=10,
32
- threshold=0.5, preset='balanced', **kwargs):
35
+ threshold=0.5, preset='balanced', distance_metric='euclidean', **kwargs):
33
36
  metric_name, metric_fn = resolve_metric(metric)
34
37
  finder = make_finder(preset, k, **kwargs)
35
38
  super().__init__(metric=metric_fn, mode=mode, neighbor_finder=finder, task=task)
@@ -27,10 +27,13 @@ class KNORAIU(KNNBase):
27
27
  Regression: use 1.0.
28
28
  preset : str
29
29
  Neighbor search preset. Default: 'balanced'. See list_presets().
30
+ distance_metric : str
31
+ Distance function to use for neighbor search. Default: 'euclidean'. See
32
+ neighbors.list_distance_metrics() for all options and per-backend availability.
30
33
  """
31
34
 
32
35
  def __init__(self, task, metric='mae', mode='min', k=10,
33
- threshold=0.5, preset='balanced', **kwargs):
36
+ threshold=0.5, preset='balanced', distance_metric='euclidean', **kwargs):
34
37
  metric_name, metric_fn = resolve_metric(metric)
35
38
  finder = make_finder(preset, k, **kwargs)
36
39
  super().__init__(metric=metric_fn, mode=mode, neighbor_finder=finder, task=task)
@@ -27,10 +27,13 @@ class KNORAU(KNNBase):
27
27
  Regression: use 1.0.
28
28
  preset : str
29
29
  Neighbor search preset. Default: 'balanced'. See list_presets().
30
+ distance_metric : str
31
+ Distance function to use for neighbor search. Default: 'euclidean'. See
32
+ neighbors.list_distance_metrics() for all options and per-backend availability.
30
33
  """
31
34
 
32
35
  def __init__(self, task, metric='mae', mode='min', k=10,
33
- threshold=0.5, preset='balanced', **kwargs):
36
+ threshold=0.5, preset='balanced', distance_metric='euclidean', **kwargs):
34
37
  metric_name, metric_fn = resolve_metric(metric)
35
38
  finder = make_finder(preset, k, **kwargs)
36
39
  super().__init__(metric=metric_fn, mode=mode, neighbor_finder=finder, task=task)
@@ -19,9 +19,12 @@ class LWSEI(PredictBase):
19
19
  Neighbourhood size. Default: 10.
20
20
  preset : str
21
21
  Neighbour search preset. Default: 'balanced'. See list_presets().
22
+ distance_metric : str
23
+ Distance function to use for neighbor search. Default: 'euclidean'. See
24
+ neighbors.list_distance_metrics() for all options and per-backend availability.
22
25
  """
23
26
 
24
- def __init__(self, task, k=10, preset='balanced', **kwargs):
27
+ def __init__(self, task, k=10, preset='balanced', distance_metric='euclidean', **kwargs):
25
28
  self.task = task
26
29
  self.k = k
27
30
  self._finder = make_finder(preset, k, **kwargs)
@@ -19,9 +19,12 @@ class LWSEU(PredictBase):
19
19
  Neighbourhood size. Default: 10.
20
20
  preset : str
21
21
  Neighbour search preset. Default: 'balanced'. See list_presets().
22
+ distance_metric : str
23
+ Distance function to use for neighbor search. Default: 'euclidean'. See
24
+ neighbors.list_distance_metrics() for all options and per-backend availability.
22
25
  """
23
26
 
24
- def __init__(self, task, k=10, preset='balanced', **kwargs):
27
+ def __init__(self, task, k=10, preset='balanced', distance_metric='euclidean', **kwargs):
25
28
  self.task = task
26
29
  self.k = k
27
30
  self._finder = make_finder(preset, k, **kwargs)
@@ -22,13 +22,16 @@ class OLA(KNNBase):
22
22
  Neighborhood size. Default: 10.
23
23
  preset : str
24
24
  Neighbor search preset. Default: 'balanced'. See list_presets().
25
+ distance_metric : str
26
+ Distance function to use for neighbor search. Default: 'euclidean'. See
27
+ neighbors.list_distance_metrics() for all options and per-backend availability.
25
28
  """
26
29
 
27
30
  def __init__(self, task, metric='mae', mode='min', k=10,
28
- preset='balanced', threshold=None, **kwargs):
31
+ preset='balanced', threshold=None, distance_metric='euclidean', **kwargs):
29
32
  metric_name, metric_fn = resolve_metric(metric)
30
33
  finder = make_finder(preset, k, **kwargs)
31
- super().__init__(metric=metric_fn, mode=mode, neighbor_finder=finder, , task=task)
34
+ super().__init__(metric=metric_fn, mode=mode, neighbor_finder=finder, task=task)
32
35
  self.task = task
33
36
  self._metric_name = metric_name
34
37
 
@@ -5,6 +5,82 @@ import warnings
5
5
  _FAISS_MIN_SAMPLES_PER_CELL = 40
6
6
 
7
7
 
8
+ # ---------------------------------------------------------------------------
9
+ # Distance metric registry
10
+ # ---------------------------------------------------------------------------
11
+
12
+ # Metrics supported by each backend.
13
+ # 'euclidean' is the universal default and always available.
14
+ #
15
+ # Choosing a distance metric:
16
+ # euclidean – The standard L2 norm. Best default for most tabular data.
17
+ # manhattan – L1 norm (sum of absolute differences). More robust to outliers
18
+ # and tends to work better in moderately high-dimensional spaces
19
+ # because it doesn't square large differences.
20
+ # chebyshev – L∞ norm (maximum absolute difference across features). Useful
21
+ # when a single feature dominating the distance is acceptable;
22
+ # common in game-grid / chess-style distance problems.
23
+ # minkowski – Generalisation of L1/L2 (controlled by p). p=1 → manhattan,
24
+ # p=2 → euclidean. Use when you want to tune between them.
25
+ # cosine – Angle between vectors, ignoring magnitude. Excellent for
26
+ # embeddings (text, image, audio) where direction matters more
27
+ # than raw scale.
28
+
29
+ # Metrics that every backend supports natively.
30
+ _UNIVERSAL_METRICS = {'euclidean', 'manhattan', 'chebyshev', 'minkowski', 'cosine'}
31
+
32
+ # Per-backend metric availability.
33
+ # KNN (sklearn) supports all scipy metrics — this is the complete curated list.
34
+ _KNN_METRICS = _UNIVERSAL_METRICS | {'correlation', 'hamming', 'canberra', 'braycurtis'}
35
+
36
+ # FAISS only has built-in L2 and inner-product (cosine via normalization).
37
+ # All others fall back to a manual compute-then-search path.
38
+ _FAISS_NATIVE_METRICS = {'euclidean', 'cosine'}
39
+ _FAISS_METRICS = _UNIVERSAL_METRICS # remainder handled via sklearn fallback
40
+
41
+ # Annoy metric names (library-specific).
42
+ _ANNOY_METRIC_MAP = {
43
+ 'euclidean': 'euclidean',
44
+ 'manhattan': 'manhattan',
45
+ 'cosine': 'angular',
46
+ 'hamming': 'hamming',
47
+ # chebyshev and minkowski are not natively supported; we warn the user.
48
+ }
49
+ _ANNOY_METRICS = set(_ANNOY_METRIC_MAP)
50
+
51
+ # HNSW (hnswlib) space names.
52
+ _HNSW_METRIC_MAP = {
53
+ 'euclidean': 'l2',
54
+ 'cosine': 'cosine',
55
+ # Others not natively supported in hnswlib; we warn and fall back to l2.
56
+ }
57
+ _HNSW_METRICS = _UNIVERSAL_METRICS # partial — see fit() for fallback note
58
+
59
+ # All metrics callable from the public API.
60
+ ALL_METRICS = _KNN_METRICS
61
+
62
+
63
+ def list_distance_metrics():
64
+ """Print all available distance metrics with per-backend availability."""
65
+ print("\nAvailable Distance Metrics:")
66
+ print("=" * 70)
67
+ rows = [
68
+ ("euclidean", "Default. L2 norm. Best for most tabular data.", "all"),
69
+ ("manhattan", "L1 norm. More robust to outliers; good for high-dim data.", "all"),
70
+ ("chebyshev", "L∞ norm. Max absolute diff across features.", "KNN only (exact preset)"),
71
+ ("minkowski", "Generalises L1/L2 via p-param. Set minkowski_p=<float>.", "KNN only (exact preset)"),
72
+ ("cosine", "Angle between vectors. Ideal for embeddings (NLP, vision).", "all"),
73
+ ("correlation","Pearson correlation distance. Good for time series.", "KNN only (exact preset)"),
74
+ ("hamming", "Fraction of differing components. For binary/categorical data.", "KNN, Annoy"),
75
+ ("canberra", "Weighted L1. Sensitive to small values near zero.", "KNN only (exact preset)"),
76
+ ("braycurtis", "Normalised L1 bounded to [0,1]. Ecological data.", "KNN only (exact preset)"),
77
+ ]
78
+ for name, desc, backends in rows:
79
+ print(f"\n {name:<14} {desc}")
80
+ print(f" {'':14} Backends: {backends}")
81
+ print("\n" + "=" * 70)
82
+
83
+
8
84
  class NeighborFinder:
9
85
  """Base class for neighbor search backends."""
10
86
 
@@ -16,12 +92,37 @@ class NeighborFinder:
16
92
 
17
93
 
18
94
  class KNNNeighborFinder(NeighborFinder):
19
- """Exact nearest neighbors via sklearn NearestNeighbors."""
20
-
21
- def __init__(self, k=10, **kwargs):
95
+ """
96
+ Exact nearest neighbors via sklearn NearestNeighbors.
97
+
98
+ Supports all distance metrics in deskit (euclidean, manhattan, chebyshev,
99
+ minkowski, cosine, correlation, hamming, canberra, braycurtis).
100
+ """
101
+
102
+ def __init__(self, k=10, distance_metric='euclidean', minkowski_p=2, **kwargs):
103
+ """
104
+ Parameters
105
+ ----------
106
+ k : int
107
+ Number of neighbors.
108
+ distance_metric : str
109
+ Distance function to use. One of the metrics returned by
110
+ list_distance_metrics(). Default: 'euclidean'.
111
+ minkowski_p : float
112
+ The p-parameter for the Minkowski metric (p=1 → manhattan,
113
+ p=2 → euclidean). Ignored for all other metrics.
114
+ """
22
115
  if k <= 0:
23
116
  raise ValueError(f"k must be positive, got k={k}")
117
+ metric = distance_metric.lower()
118
+ if metric not in _KNN_METRICS:
119
+ raise ValueError(
120
+ f"distance_metric='{distance_metric}' is not supported by KNNNeighborFinder. "
121
+ f"Available: {sorted(_KNN_METRICS)}."
122
+ )
24
123
  self.n_neighbors = k
124
+ self.distance_metric = metric
125
+ self.minkowski_p = minkowski_p
25
126
  self.kwargs = kwargs
26
127
  self.model = None
27
128
 
@@ -33,7 +134,15 @@ class KNNNeighborFinder(NeighborFinder):
33
134
  f"Cannot find {self.n_neighbors} neighbors in a dataset with only "
34
135
  f"{X.shape[0]} samples. Reduce k to at most {X.shape[0]}."
35
136
  )
36
- self.model = NearestNeighbors(n_neighbors=self.n_neighbors, **self.kwargs)
137
+ metric_kwargs = {}
138
+ if self.distance_metric == 'minkowski':
139
+ metric_kwargs['p'] = self.minkowski_p
140
+ self.model = NearestNeighbors(
141
+ n_neighbors=self.n_neighbors,
142
+ metric=self.distance_metric,
143
+ metric_params=metric_kwargs if metric_kwargs else None,
144
+ **self.kwargs,
145
+ )
37
146
  self.model.fit(X)
38
147
  return self
39
148
 
@@ -48,12 +157,34 @@ class KNNNeighborFinder(NeighborFinder):
48
157
 
49
158
 
50
159
  class FaissNeighborFinder(NeighborFinder):
51
- """Approximate nearest neighbors via FAISS (flat, IVF, or HNSW index)."""
160
+ """
161
+ Approximate nearest neighbors via FAISS (flat, IVF, or HNSW index).
162
+
163
+ Natively supports 'euclidean' and 'cosine'. All other metrics in
164
+ _UNIVERSAL_METRICS fall back to a sklearn-based exact search with a
165
+ warning, so you can still use them without switching presets.
166
+ """
52
167
 
53
168
  def __init__(self, k=10, index_type='flat', n_cells=None, n_probes=50,
54
- hnsw_M=32, hnsw_efConstruction=400, hnsw_efSearch=200):
169
+ hnsw_M=32, hnsw_efConstruction=400, hnsw_efSearch=200,
170
+ distance_metric='euclidean'):
171
+ """
172
+ Parameters
173
+ ----------
174
+ distance_metric : str
175
+ 'euclidean' (default) or 'cosine'. Other metrics fall back to
176
+ exact sklearn search with a warning — use preset='exact' to avoid
177
+ the overhead.
178
+ """
55
179
  if k <= 0:
56
180
  raise ValueError(f"k must be positive, got k={k}")
181
+ metric = distance_metric.lower()
182
+ if metric not in _FAISS_METRICS:
183
+ raise ValueError(
184
+ f"distance_metric='{distance_metric}' is not supported by FaissNeighborFinder. "
185
+ f"Available: {sorted(_FAISS_METRICS)}. "
186
+ f"For other metrics use preset='exact' (KNNNeighborFinder)."
187
+ )
57
188
  self.n_neighbors = k
58
189
  self.index_type = index_type.lower()
59
190
  self.n_cells = n_cells
@@ -61,7 +192,9 @@ class FaissNeighborFinder(NeighborFinder):
61
192
  self.hnsw_M = hnsw_M
62
193
  self.hnsw_efConstruction = hnsw_efConstruction
63
194
  self.hnsw_efSearch = hnsw_efSearch
195
+ self.distance_metric = metric
64
196
  self.index_ = None
197
+ self._fallback_finder = None # used for non-native metrics
65
198
  self._check_availability()
66
199
 
67
200
  def _check_availability(self):
@@ -71,6 +204,13 @@ class FaissNeighborFinder(NeighborFinder):
71
204
  except ImportError:
72
205
  raise ImportError("FAISS not found. Install with: pip install faiss-cpu")
73
206
 
207
+ @staticmethod
208
+ def _l2_normalize(X):
209
+ """Row-wise L2 normalisation in-place (for cosine similarity)."""
210
+ norms = np.linalg.norm(X, axis=1, keepdims=True)
211
+ norms = np.where(norms == 0, 1.0, norms)
212
+ return X / norms.astype(np.float32)
213
+
74
214
  def fit(self, X):
75
215
  X = np.atleast_2d(X).astype(np.float32)
76
216
  n_samples, dim = X.shape
@@ -81,6 +221,24 @@ class FaissNeighborFinder(NeighborFinder):
81
221
  f"{n_samples} samples. Reduce k to at most {n_samples}."
82
222
  )
83
223
 
224
+ # Non-native metrics: delegate entirely to KNNNeighborFinder.
225
+ if self.distance_metric not in _FAISS_NATIVE_METRICS:
226
+ warnings.warn(
227
+ f"distance_metric='{self.distance_metric}' is not natively supported by "
228
+ f"FAISS. Falling back to exact sklearn KNN for this metric. "
229
+ f"Use preset='exact' to avoid this overhead.",
230
+ UserWarning,
231
+ )
232
+ self._fallback_finder = KNNNeighborFinder(
233
+ k=self.n_neighbors, distance_metric=self.distance_metric
234
+ )
235
+ self._fallback_finder.fit(X)
236
+ return self
237
+
238
+ # Cosine similarity: normalise all vectors, then use inner-product index.
239
+ if self.distance_metric == 'cosine':
240
+ X = self._l2_normalize(X)
241
+
84
242
  if self.index_type == 'flat':
85
243
  if dim <= 2:
86
244
  warnings.warn(
@@ -88,14 +246,16 @@ class FaissNeighborFinder(NeighborFinder):
88
246
  f"Consider KNNNeighborFinder for low-dimensional data.",
89
247
  UserWarning
90
248
  )
91
- self.index_ = self.faiss.IndexFlatL2(dim)
249
+ if self.distance_metric == 'cosine':
250
+ self.index_ = self.faiss.IndexFlatIP(dim) # inner product on normalised vecs
251
+ else:
252
+ self.index_ = self.faiss.IndexFlatL2(dim)
92
253
  self.index_.add(X)
93
254
 
94
255
  elif self.index_type == 'ivf':
95
256
  if self.n_cells is None:
96
257
  self.n_cells = min(int(np.sqrt(n_samples)), 4096)
97
258
 
98
- # Reduce n_cells if dataset is too small
99
259
  min_required = self.n_cells * _FAISS_MIN_SAMPLES_PER_CELL
100
260
  if n_samples < min_required:
101
261
  safe_cells = max(1, n_samples // _FAISS_MIN_SAMPLES_PER_CELL)
@@ -121,8 +281,14 @@ class FaissNeighborFinder(NeighborFinder):
121
281
  UserWarning
122
282
  )
123
283
 
124
- quantizer = self.faiss.IndexFlatL2(dim)
125
- self.index_ = self.faiss.IndexIVFFlat(quantizer, dim, self.n_cells)
284
+ if self.distance_metric == 'cosine':
285
+ quantizer = self.faiss.IndexFlatIP(dim)
286
+ self.index_ = self.faiss.IndexIVFFlat(
287
+ quantizer, dim, self.n_cells, self.faiss.METRIC_INNER_PRODUCT
288
+ )
289
+ else:
290
+ quantizer = self.faiss.IndexFlatL2(dim)
291
+ self.index_ = self.faiss.IndexIVFFlat(quantizer, dim, self.n_cells)
126
292
  self.index_.train(X)
127
293
  self.index_.add(X)
128
294
  self.index_.nprobe = effective_probes
@@ -151,21 +317,53 @@ class FaissNeighborFinder(NeighborFinder):
151
317
  X = np.atleast_2d(X).astype(np.float32)
152
318
  if X.shape[0] == 0:
153
319
  return np.empty((0, k), dtype=np.float32), np.empty((0, k), dtype=np.int64)
154
- distances, indices = self.index_.search(X, k)
155
- # FAISS returns squared L2; clamp to 0 before sqrt
156
- return np.sqrt(np.maximum(distances, 0)), indices
157
320
 
321
+ # Non-native metric fallback.
322
+ if self._fallback_finder is not None:
323
+ return self._fallback_finder.kneighbors(X, k=k)
324
+
325
+ if self.distance_metric == 'cosine':
326
+ X = self._l2_normalize(X)
327
+ scores, indices = self.index_.search(X, k)
328
+ # Inner product on normalised vectors: similarity ∈ [-1, 1].
329
+ # Convert to a proper distance (0 = identical, 2 = opposite).
330
+ distances = 1.0 - scores
331
+ else:
332
+ distances, indices = self.index_.search(X, k)
333
+ # FAISS returns squared L2; clamp to 0 before sqrt.
334
+ distances = np.sqrt(np.maximum(distances, 0))
335
+
336
+ return distances.astype(np.float32), indices
158
337
 
159
- class AnnoyNeighborFinder(NeighborFinder):
160
- """Approximate nearest neighbors via Annoy."""
161
338
 
162
- def __init__(self, k=10, n_trees=100, metric='euclidean', search_k=-1):
339
+ class AnnoyNeighborFinder(NeighborFinder):
340
+ """
341
+ Approximate nearest neighbors via Annoy.
342
+
343
+ Supports: euclidean, manhattan, cosine, hamming.
344
+ chebyshev and minkowski are not available in Annoy — use preset='exact' for those.
345
+ """
346
+
347
+ def __init__(self, k=10, n_trees=100, distance_metric='euclidean', search_k=-1):
348
+ """
349
+ Parameters
350
+ ----------
351
+ distance_metric : str
352
+ One of 'euclidean', 'manhattan', 'cosine', 'hamming'. Default: 'euclidean'.
353
+ """
163
354
  if k <= 0:
164
355
  raise ValueError(f"k must be positive, got k={k}")
356
+ metric = distance_metric.lower()
357
+ if metric not in _ANNOY_METRICS:
358
+ raise ValueError(
359
+ f"distance_metric='{distance_metric}' is not supported by AnnoyNeighborFinder. "
360
+ f"Available: {sorted(_ANNOY_METRICS)}. "
361
+ f"For chebyshev or minkowski use preset='exact' (KNNNeighborFinder)."
362
+ )
165
363
  self.k = k
166
364
  self.n_trees = n_trees
167
- self.metric = metric
168
- # Annoy's recommended default; the previous value (n_trees * k * 50)
365
+ self.distance_metric = metric
366
+ # Annoy's recommended default
169
367
  self.search_k = n_trees * k if search_k == -1 else search_k
170
368
  self.index_ = None
171
369
  self.n_samples_ = None
@@ -195,12 +393,7 @@ class AnnoyNeighborFinder(NeighborFinder):
195
393
  UserWarning
196
394
  )
197
395
 
198
- metric_map = {
199
- 'euclidean': 'euclidean', 'l2': 'euclidean',
200
- 'angular': 'angular', 'cosine': 'angular',
201
- 'manhattan': 'manhattan', 'hamming': 'hamming', 'dot': 'dot',
202
- }
203
- self.index_ = self.AnnoyIndex(dim, metric_map.get(self.metric.lower(), 'euclidean'))
396
+ self.index_ = self.AnnoyIndex(dim, _ANNOY_METRIC_MAP[self.distance_metric])
204
397
  for i, vec in enumerate(X):
205
398
  self.index_.add_item(i, vec.tolist())
206
399
  self.index_.build(self.n_trees)
@@ -244,14 +437,33 @@ class AnnoyNeighborFinder(NeighborFinder):
244
437
 
245
438
 
246
439
  class HNSWNeighborFinder(NeighborFinder):
247
- """Approximate nearest neighbors via HNSW (hnswlib or nmslib backend)."""
248
-
249
- def __init__(self, k=10, space='l2', M=32, ef_construction=400,
250
- ef_search=200, backend='hnswlib'):
440
+ """
441
+ Approximate nearest neighbors via HNSW (hnswlib or nmslib backend).
442
+
443
+ Natively supports 'euclidean' and 'cosine'. Manhattan, chebyshev, and
444
+ minkowski are not available in hnswlib/nmslib — use preset='exact' for those.
445
+ """
446
+
447
+ def __init__(self, k=10, M=32, ef_construction=400,
448
+ ef_search=200, backend='hnswlib', distance_metric='euclidean'):
449
+ """
450
+ Parameters
451
+ ----------
452
+ distance_metric : str
453
+ 'euclidean' (default) or 'cosine'. Other metrics are not supported
454
+ natively and will raise an error — use preset='exact' instead.
455
+ """
251
456
  if k <= 0:
252
457
  raise ValueError(f"k must be positive, got k={k}")
458
+ metric = distance_metric.lower()
459
+ if metric not in _HNSW_METRIC_MAP:
460
+ raise ValueError(
461
+ f"distance_metric='{distance_metric}' is not natively supported by "
462
+ f"HNSWNeighborFinder. Available: {sorted(_HNSW_METRIC_MAP)}. "
463
+ f"For other metrics use preset='exact' (KNNNeighborFinder)."
464
+ )
253
465
  self.n_neighbors = k
254
- self.space = space
466
+ self.distance_metric = metric
255
467
  self.M = M
256
468
  self.ef_construction = ef_construction
257
469
  self.ef_search = ef_search
@@ -292,7 +504,8 @@ class HNSWNeighborFinder(NeighborFinder):
292
504
  )
293
505
 
294
506
  if self.backend == 'hnswlib':
295
- self.index_ = self.hnswlib.Index(space=self.space, dim=dim)
507
+ space = _HNSW_METRIC_MAP[self.distance_metric]
508
+ self.index_ = self.hnswlib.Index(space=space, dim=dim)
296
509
  self.index_.init_index(
297
510
  max_elements=n_samples, M=self.M, ef_construction=self.ef_construction
298
511
  )
@@ -300,10 +513,13 @@ class HNSWNeighborFinder(NeighborFinder):
300
513
  self.index_.add_items(X, np.arange(n_samples))
301
514
 
302
515
  else: # nmslib
303
- space_map = {'l2': 'l2', 'cosine': 'cosinesimil', 'ip': 'negdotprod'}
516
+ nmslib_space_map = {
517
+ 'euclidean': 'l2',
518
+ 'cosine': 'cosinesimil',
519
+ }
304
520
  self.index_ = self.nmslib.init(
305
521
  method='hnsw',
306
- space=space_map.get(self.space, 'l2'),
522
+ space=nmslib_space_map.get(self.distance_metric, 'l2'),
307
523
  data_type=self.nmslib.DataType.DENSE_VECTOR
308
524
  )
309
525
  self.index_.addDataPointBatch(X)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: deskit
3
- Version: 1.1.0
3
+ Version: 1.2.0
4
4
  Summary: A Python library for Dynamic Ensemble Selection
5
5
  Author: Tikhon Vodyanov
6
6
  License-Expression: MIT
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes