deskit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
deskit/__init__.py ADDED
@@ -0,0 +1,43 @@
1
+ """
2
+ deskit — Dynamic Ensemble Selection library.
3
+
4
+ Metrics
5
+ -------
6
+ Pass a metric name string:
7
+
8
+ KNNDWS(task='classification', metric='log_loss', mode='min')
9
+
10
+ Or import a metric function directly:
11
+
12
+ from deskit.metrics import log_loss, mae
13
+
14
+ KNNDWS(task='classification', metric=log_loss, mode='min')
15
+
16
+ Available built-in metrics:
17
+ Scalar predictions (pass predict() output):
18
+ 'mae', 'mse', 'rmse', 'accuracy'
19
+
20
+ Probability predictions (pass predict_proba() output):
21
+ 'log_loss', 'prob_correct'
22
+ """
23
+
24
+ from deskit.des.knndws import KNNDWS
25
+ from deskit.des.ola import OLA
26
+ from deskit.des.knorau import KNORAU
27
+ from deskit.des.knorae import KNORAE
28
+ from deskit.des.knoraiu import KNORAIU
29
+ from deskit.router import DynamicRouter
30
+ from deskit._config import SPEED_PRESETS, list_presets
31
+ from deskit.analysis import analyze
32
+
33
+ __all__ = [
34
+ 'KNNDWS',
35
+ 'OLA',
36
+ 'KNORAU',
37
+ 'KNORAE',
38
+ 'KNORAIU',
39
+ 'DynamicRouter',
40
+ 'SPEED_PRESETS',
41
+ 'list_presets',
42
+ 'analyze',
43
+ ]
deskit/_config.py ADDED
@@ -0,0 +1,186 @@
1
+ """
2
+ Internal helpers shared across algorithm classes.
3
+ Not part of the public API.
4
+ """
5
+ import numpy as np
6
+ from deskit.metrics import _METRICS, _PROB_METRICS, _SCALAR_METRICS
7
+ from deskit.utils import to_numpy
8
+
9
+
10
+ # ---------------------------------------------------------------------------
11
+ # Speed / accuracy presets
12
+ # ---------------------------------------------------------------------------
13
+
14
# Named speed/accuracy trade-off configurations. Each entry picks a
# neighbour-search backend ('finder') plus the kwargs for its constructor.
SPEED_PRESETS = {
    'exact': {
        'description': 'Exact nearest neighbors — slowest but 100% accurate',
        'finder': 'knn',
        'kwargs': {},
    },
    'balanced': {
        'description': 'Good balance of speed and accuracy (~98% recall)',
        'finder': 'faiss',
        'kwargs': {'index_type': 'ivf', 'n_probes': 50},
    },
    'fast': {
        'description': 'Fast queries with good accuracy (~95% recall)',
        'finder': 'faiss',
        'kwargs': {'index_type': 'ivf', 'n_probes': 30},
    },
    'turbo': {
        'description': 'Maximum speed, exact results — FAISS flat index',
        'finder': 'faiss',
        'kwargs': {'index_type': 'flat'},
    },
    'high_dim_balanced': {
        'description': 'High-dimensional data (>100D), balanced',
        'finder': 'hnsw',
        'kwargs': {'backend': 'hnswlib', 'M': 32, 'ef_construction': 400, 'ef_search': 200},
    },
    'high_dim_fast': {
        'description': 'High-dimensional data (>100D), fast',
        'finder': 'hnsw',
        'kwargs': {'backend': 'hnswlib', 'M': 16, 'ef_construction': 200, 'ef_search': 100},
    },
}


def list_presets():
    """Print all available presets with descriptions and parameters."""
    rule = "=" * 70
    print("\nAvailable Speed/Accuracy Presets:")
    print(rule)
    for preset_name, cfg in SPEED_PRESETS.items():
        print(f"\n{preset_name.upper()}\n {cfg['description']}\n Finder: {cfg['finder']}")
        if cfg['kwargs']:
            # Presets with an empty kwargs dict (e.g. 'exact') print no line here.
            print(f" Parameters: {cfg['kwargs']}")
    print("\n" + rule)
57
+
58
+
59
+ # ---------------------------------------------------------------------------
60
+ # Metric resolution
61
+ # ---------------------------------------------------------------------------
62
+
63
def resolve_metric(metric):
    """
    Normalise a metric given as a string name or a callable.

    Returns
    -------
    metric_name : str or None
        Lower-cased built-in name when a string was passed; None for
        callables. Used later in validate_fit_inputs to check prediction shape.
    metric_fn : callable
        The actual scoring function.

    Raises
    ------
    ValueError
        If a string is passed that does not name a built-in metric.
    """
    if not isinstance(metric, str):
        # Custom callable: no name, so downstream shape checks are skipped.
        return None, metric
    name = metric.lower()
    if name not in _METRICS:
        raise ValueError(
            f"Unknown metric '{metric}'. "
            f"Built-in options: {sorted(_METRICS)}. "
            f"Pass a callable for custom metrics."
        )
    return name, _METRICS[name]
85
+
86
+
87
+ # ---------------------------------------------------------------------------
88
+ # Neighbor finder construction
89
+ # ---------------------------------------------------------------------------
90
+
91
def make_finder(preset, k, finder=None, **kwargs):
    """
    Create a NeighborFinder from a preset name or custom finder string.

    Parameters
    ----------
    preset : str
        One of the keys in SPEED_PRESETS, or 'custom'.
    k : int
        Number of neighbors.
    finder : str, optional
        Required when preset='custom'. One of 'knn', 'faiss', 'annoy', 'hnsw'.
    **kwargs
        Forwarded to the finder constructor (e.g. index_type, n_probes).

    Raises
    ------
    ValueError
        For an unknown preset, a missing finder with preset='custom',
        or an unknown finder type.
    """
    if preset == 'custom':
        if finder is None:
            raise ValueError("Must specify 'finder' when using preset='custom'.")
        finder_type = finder.lower()
        finder_kwargs = {'k': k, **kwargs}
    elif preset in SPEED_PRESETS:
        config = SPEED_PRESETS[preset]
        finder_type = config['finder']
        # Explicit kwargs override the preset's defaults.
        finder_kwargs = {**config['kwargs'], 'k': k, **kwargs}
        print(f"Using preset '{preset}': {config['description']}")
    else:
        raise ValueError(
            f"Unknown preset '{preset}'. "
            f"Available: {sorted(SPEED_PRESETS)}. "
            f"Or use preset='custom' with an explicit finder."
        )

    # Imports stay inside the branches so optional backends (faiss, annoy,
    # hnswlib) are only required when actually selected.
    if finder_type == 'knn':
        from deskit.neighbors import KNNNeighborFinder as finder_cls
    elif finder_type == 'faiss':
        from deskit.neighbors import FaissNeighborFinder as finder_cls
    elif finder_type == 'annoy':
        from deskit.neighbors import AnnoyNeighborFinder as finder_cls
    elif finder_type == 'hnsw':
        from deskit.neighbors import HNSWNeighborFinder as finder_cls
    else:
        raise ValueError(f"Unknown finder '{finder_type}'.")
    return finder_cls(**finder_kwargs)
137
+
138
+ # fit() input validation
139
+
140
def prep_fit_inputs(features, y, preds_dict, metric_name):
    """
    Convert all fit() inputs to numpy arrays and validate consistency.

    Parameters
    ----------
    features, y : array-like
        Validation features and targets; must have the same number of rows.
    preds_dict : dict[str, array-like]
        Per-model predictions, each aligned row-for-row with y.
    metric_name : str or None
        Built-in metric name (from resolve_metric), or None for a custom
        callable — in which case the dimensionality checks are skipped.

    Returns
    -------
    features, y, preds_dict — all as numpy arrays, ready for KNNBase.fit().

    Raises
    ------
    ValueError
        On any row-count mismatch, or when prediction dimensionality does
        not match what the named built-in metric expects.
    """
    features = to_numpy(features)
    y = to_numpy(y)
    preds_dict = {model: to_numpy(arr) for model, arr in preds_dict.items()}

    n_samples = len(y)
    if len(features) != n_samples:
        raise ValueError(
            f"features has {len(features)} rows but y has {n_samples} samples. "
            f"All arrays passed to fit() must have the same number of rows."
        )
    for model, arr in preds_dict.items():
        if len(arr) != n_samples:
            raise ValueError(
                f"preds_dict['{model}'] has {len(arr)} samples but y has {n_samples}. "
                f"Every prediction array must align row-for-row with y."
            )

    if metric_name is not None:
        # Dimensionality check uses one representative array; all entries
        # were already length-checked above.
        pred_ndim = np.asarray(next(iter(preds_dict.values()))).ndim
        if metric_name in _PROB_METRICS and pred_ndim != 2:
            raise ValueError(
                f"Metric '{metric_name}' expects probability arrays of shape "
                f"(n_samples, n_classes), but received a {pred_ndim}D array. "
                f"Pass predict_proba() output instead of predict() output."
            )
        if metric_name in _SCALAR_METRICS and pred_ndim != 1:
            raise ValueError(
                f"Metric '{metric_name}' expects scalar predictions of shape "
                f"(n_samples,), but received a {pred_ndim}D array. "
                f"Pass predict() output, or switch to 'log_loss' / 'prob_correct' "
                f"if you intended to use class probabilities."
            )

    return features, y, preds_dict
deskit/analysis.py ADDED
@@ -0,0 +1,377 @@
1
+ """
2
+ DES suitability analysis.
3
+
4
+ Analyses a validation set and model pool to estimate whether Dynamic Ensemble
5
+ Selection is likely to help, and by how much.
6
+
7
+ The function prints a formatted report and returns a dict of raw metrics for
8
+ programmatic use.
9
+ """
10
+ import numpy as np
11
+ from deskit._config import resolve_metric, prep_fit_inputs
12
+ from deskit.base.knnbase import KNNBase
13
+ from deskit.neighbors import KNNNeighborFinder
14
+
15
+ # Internal helpers
16
+
17
class _AnalysisFitter(KNNBase):
    """Minimal KNNBase subclass used only to build the score matrix.

    analyze() needs KNNBase.fit() to compute the per-sample score matrix and
    fit the neighbour index, but never routes predictions, so the predict()
    entry point is stubbed out.
    """
    def predict(self, x, temperature=None, threshold=None):
        # Deliberately unsupported: this class exists only for fit()'s side
        # effects (self.matrix, self.model).
        raise NotImplementedError("_AnalysisFitter is not used for prediction.")
21
+
22
+
23
+ def _entropy(probs):
24
+ """Shannon entropy of a probability distribution, normalised to [0, 1]."""
25
+ probs = np.asarray(probs, dtype=float)
26
+ probs = probs[probs > 0]
27
+ n = len(probs) + (1 - len(probs)) % 1 # guard for single-element
28
+ raw = -np.sum(probs * np.log(probs))
29
+ max_e = np.log(max(len(probs), 2))
30
+ return raw / max_e if max_e > 0 else 0.0
31
+
32
+
33
+ def _bar(value, width=30, fill='█', empty='░'):
34
+ """Render a simple ASCII bar for a value in [0, 1]."""
35
+ n = round(value * width)
36
+ return fill * n + empty * (width - n)
37
+
38
+
39
+ # Public API
40
+
41
def analyze(features, y, preds_dict, metric, mode, k=20, verbose=True):
    """
    Analyse a validation set and model pool for DES suitability.

    Computes four complementary signals (oracle headroom, local uplift,
    regional diversity, KNN learnability) and prints an interpreted report.

    Parameters
    ----------
    features : array-like, shape (n_val, n_features)
        Validation features. Should be the same set you would pass to fit().
    y : array-like, shape (n_val,)
        Validation ground-truth labels or values.
    preds_dict : dict[str, array-like]
        Validation predictions keyed by model name.
        Shape (n_val,) for scalar metrics; (n_val, n_classes) for probability
        metrics (log_loss, prob_correct).
    metric : str or callable
        Same metric you intend to use for the DES router.
    mode : str
        'min' if lower scores are better (mae, log_loss), 'max' if higher
        (accuracy, prob_correct).
    k : int
        Neighbourhood size. Should match the k you intend to use for routing.
        Default: 20.
    verbose : bool
        If True (default), print the formatted report. If False, return the
        metrics dict silently.

    Returns
    -------
    dict with keys:
        n_val                 int   Number of validation samples.
        n_features            int   Number of features.
        n_models              int   Number of models in the pool.
        k                     int   Neighbourhood size used.
        model_scores          dict  Mean score per model (higher-is-better scale).
        best_model            str   Name of the globally best model.
        oracle_gain           float Fractional improvement of the local oracle
                                    over the best single model. 0.0 = no gain
                                    possible; 0.15 = 15% headroom.
        estimated_gain        float Conservative estimate of realised DES gain
                                    (2% of oracle_gain; empirical range 1-5%
                                    across benchmarks).
        local_uplift          float Mean neighbourhood-level advantage of the
                                    locally best model over the globally best,
                                    relative to the latter's mean score.
        regional_diversity    float Entropy of model win shares across
                                    neighbourhoods, normalised to [0, 1].
                                    0 = one model wins everywhere;
                                    1 = wins perfectly distributed.
        global_best_win_share float Fraction of neighbourhoods won by the
                                    globally best model.
        model_win_shares      dict  Fraction of neighbourhoods each model wins.
        disagreement          float 1 - sum of squared win shares: probability
                                    that two random neighbourhoods have
                                    different winners. Computed only for 1D
                                    (scalar) predictions; None for probability
                                    inputs.
        val_quality           float n_val / k, clamped to [0, 1] against a
                                    target of 100. Proxy for neighbourhood
                                    estimate reliability.
        feature_score         float Feature-count factor in [0, 1]
                                    ((n_features - 2) / 10, clipped).
        knn_learnability      float val_quality * feature_score.
        recommendation        str   'USE DES', 'MAYBE — try Global Ensemble
                                    first', 'SKIP DES', or 'UNRELIABLE'.
        reason                str   Plain-English explanation of recommendation.
    """
    metric_name, metric_fn = resolve_metric(metric)
    features, y, preds_dict = prep_fit_inputs(features, y, preds_dict, metric_name)

    n_val, n_features = features.shape
    n_models = len(preds_dict)
    model_names = list(preds_dict.keys())

    # Build score matrix. k + 1 because each sample's own row is included
    # by the index and dropped below.
    finder = KNNNeighborFinder(k=k + 1)
    fitter = _AnalysisFitter(metric=metric_fn, mode=mode, neighbor_finder=finder)
    fitter.fit(features, y, preds_dict)

    matrix = fitter.matrix  # (n_val, n_models); higher is always better

    # Per-model global mean scores
    model_scores = {name: float(matrix[:, j].mean())
                    for j, name in enumerate(model_names)}
    best_model = max(model_scores, key=model_scores.get)
    best_score = model_scores[best_model]

    # Oracle gain: a perfect per-sample router would score the row-wise max.
    oracle_scores = matrix.max(axis=1)  # (n_val,)
    oracle_mean = oracle_scores.mean()

    # Gain is relative to the best single model's mean score.
    if abs(best_score) > 1e-12:
        oracle_gain = float((oracle_mean - best_score) / abs(best_score))
    else:
        oracle_gain = 0.0

    # Empirically, DES captures roughly 1-5% of oracle headroom; 2% is the
    # conservative point estimate reported here.
    estimated_gain = oracle_gain * 0.02

    # Neighbourhood structure: which model wins each sample's neighbourhood?
    _, indices = fitter.model.kneighbors(features, k=k + 1)
    indices = indices[:, 1:]  # (n_val, k) — drop self

    neighborhood_scores = matrix[indices].mean(axis=1)  # (n_val, n_models)
    local_winners = np.argmax(neighborhood_scores, axis=1)  # (n_val,)

    # Win share per model across all neighbourhoods.
    win_counts = np.bincount(local_winners, minlength=n_models)
    win_shares = win_counts / n_val
    model_wins = {name: float(win_shares[j]) for j, name in enumerate(model_names)}

    # Regional diversity: normalised entropy of the win-share distribution.
    regional_diversity = float(_entropy(win_shares))

    # Local uplift: how much better the locally best model is than the
    # globally best one, averaged over neighbourhoods.
    global_best_idx = model_names.index(best_model)
    nbhd_best = neighborhood_scores.max(axis=1)  # (n_val,)
    nbhd_global_best = neighborhood_scores[:, global_best_idx]  # (n_val,)
    nbhd_improvement = nbhd_best - nbhd_global_best

    denom = abs(nbhd_global_best.mean())
    local_uplift = float(nbhd_improvement.mean() / denom) if denom > 1e-12 else 0.0

    # Winner disagreement: 1 - sum(share^2) is the probability that two
    # independently drawn neighbourhoods have different winning models.
    # Only computed for 1D (scalar) predictions.
    first_preds = next(iter(preds_dict.values()))
    if np.asarray(first_preds).ndim == 1:
        disagreement = float(1.0 - np.sum(win_shares ** 2))
    else:
        disagreement = None

    # Validation set quality: n_val / k >= 100 = stable neighbourhood estimates.
    val_quality = float(min(1.0, (n_val / k) / 100.0))

    # KNN learnability: density and feature-count factors combined.
    feature_score = float(np.clip((n_features - 2) / 10.0, 0.0, 1.0))
    knn_learnability = val_quality * feature_score

    # Recommendation — thresholds below are heuristics, checked in order of
    # severity: unreliable data first, then clear-skip, then weak signals.
    if val_quality < 0.2:
        recommendation = 'UNRELIABLE'
        reason = (
            f"Validation set too small relative to k "
            f"(n_val/k = {n_val/k:.0f}, recommended \u2265 100). "
            f"Reduce k or increase the validation set before drawing conclusions."
        )
    elif local_uplift < 0.01 and regional_diversity < 0.4:
        recommendation = 'SKIP DES'
        reason = (
            f"Local uplift is negligible ({local_uplift*100:.1f}%) and regional "
            f"diversity is low ({regional_diversity:.2f}). The routing signal is "
            f"too weak to improve on the best single model. Use it directly."
        )
    elif knn_learnability < 0.25 or local_uplift < 0.05 or regional_diversity < 0.45:
        weak_reasons = []
        if knn_learnability < 0.25:
            weak_reasons.append(
                f"KNN learnability is low ({knn_learnability:.2f}) — "
                f"{'the feature space is too small' if feature_score < 0.5 else 'the val set is too sparse'} "
                f"for stable competence regions"
            )
        if local_uplift < 0.05:
            weak_reasons.append(f"local uplift is modest ({local_uplift*100:.1f}%)")
        if regional_diversity < 0.45:
            weak_reasons.append(f"regional diversity is low ({regional_diversity:.2f})")
        recommendation = 'MAYBE \u2014 try Global Ensemble first'
        reason = (
            f"{'; '.join(weak_reasons).capitalize()}. "
            f"A fixed-weight global ensemble may capture most of the gain "
            f"with less variance. Run both and compare on held-out test data."
        )
    else:
        recommendation = 'USE DES'
        reason = (
            f"Local uplift ({local_uplift*100:.1f}%), regional diversity "
            f"({regional_diversity:.2f}), and KNN learnability ({knn_learnability:.2f}) "
            f"are all strong. DES is likely to improve on both the best single "
            f"model and the global ensemble."
        )

    # Assemble results
    global_best_win_share = float(model_wins[best_model])

    result = {
        'n_val': n_val,
        'n_features': n_features,
        'n_models': n_models,
        'k': k,
        'model_scores': model_scores,
        'best_model': best_model,
        'oracle_gain': oracle_gain,
        'estimated_gain': estimated_gain,
        'local_uplift': local_uplift,
        'regional_diversity': regional_diversity,
        'global_best_win_share': global_best_win_share,
        'model_win_shares': model_wins,
        'disagreement': disagreement,
        'val_quality': val_quality,
        'feature_score': feature_score,
        'knn_learnability': knn_learnability,
        'recommendation': recommendation,
        'reason': reason,
    }

    if verbose:
        _print_report(result)

    return result
241
+
242
+ # Report formatting
243
+
244
def _print_report(r):
    """Pretty-print the analysis dict produced by analyze().

    Expects the full metrics dict (see analyze() for keys). Writes the
    report to stdout and returns nothing.
    """
    W = 72  # total rule width

    def _section(title):
        # Section header followed by a light horizontal rule.
        print(f"\n {title}")
        print(f" {'─' * (W - 4)}")

    print(f"\n {'━' * W}")
    print(f" DES Suitability Analysis")
    print(f" {'━' * W}")
    print(f" {r['n_val']:,} val samples · {r['n_features']} features · "
          f"{r['n_models']} models · k = {r['k']}")

    # Model scores
    _section("Model scores (higher-is-better scale)")
    scores = list(r['model_scores'].values())
    s_min = min(scores)
    s_max = max(scores)
    s_range = s_max - s_min
    for name, score in r['model_scores'].items():
        marker = ' ← best' if name == r['best_model'] else ''
        # Bar shows each model's relative position in the pool score range.
        bar_val = (score - s_min) / s_range if s_range > 0 else 1.0
        bar = _bar(bar_val)
        print(f" {name:<22} {score:+.4f} {bar}{marker}")

    # Oracle gain / local uplift
    _section("Oracle / local uplift (headroom and realisable gain)")
    og = r['oracle_gain']
    lu = r['local_uplift']
    eg = r['estimated_gain']
    bar_og = _bar(min(og / 0.20, 1.0))  # bars saturate at 20% / 10%
    bar_lu = _bar(min(lu / 0.10, 1.0))
    print(f" Oracle gain {og*100:+6.2f}% {bar_og}")
    print(f" Local uplift {lu*100:+6.2f}% {bar_lu}")
    print(f" Estimated DES {eg*100:+6.2f}% (≤2% of oracle; empirical range 1–5%)")
    print()
    if og < 0.02:
        note = "Very little headroom — models already agree on most samples."
    elif og < 0.05:
        note = "Moderate headroom — DES may help, depends on regional structure."
    elif og < 0.12:
        note = "Good headroom — DES has meaningful potential."
    else:
        note = "Large headroom — strong case for DES."
    # FIX: this interpretation note was computed but never printed — every
    # other section prints its note; this one now does too.
    print(f" {note}")
    print(f" Oracle gain is the per-sample ceiling. Local uplift is the")
    print(f" neighbourhood-level routing advantage that must generalise to")
    print(f" test data. Low local uplift with high oracle gain usually means")
    print(f" the routing signal is noisy (val set too small or too few features).")

    # Regional diversity
    _section("Regional diversity (do different models win in different areas?)")
    rd = r['regional_diversity']
    bar_rd = _bar(rd)
    print(f" Diversity score {rd:.3f} {bar_rd}")
    print()
    print(f" {'Model':<22} {'Win share':>10} {'Neighbourhoods'}")
    print(f" {'─'*22} {'─'*10} {'─'*14}")
    for name, share in sorted(r['model_win_shares'].items(),
                              key=lambda x: -x[1]):
        bar = _bar(share, width=20)
        print(f" {name:<22} {share*100:>9.1f}% {bar}")
    print()
    if rd < 0.35:
        note = "Low diversity — one model dominates most regions."
    elif rd < 0.65:
        note = "Moderate diversity — some regional structure present."
    else:
        note = "High diversity — models have distinct regional strengths."
    print(f" {note}")
    # Warn when the local neighbourhood winner differs from the global best.
    gbws = r['global_best_win_share']
    local_leader = max(r['model_win_shares'], key=r['model_win_shares'].get)
    if local_leader != r['best_model']:
        share_leader = r['model_win_shares'][local_leader] * 100
        print()
        print(f" ⚠ Local winner '{local_leader}' ({share_leader:.0f}% of regions)")
        print(f" differs from global best '{r['best_model']}' ({gbws*100:.0f}%).")
        print(f" A weaker model dominating locally often signals noisy")
        print(f" neighbourhood estimates. Cross-check local uplift.")

    # Disagreement (only available for scalar predictions)
    if r['disagreement'] is not None:
        _section("Model disagreement (how often does the locally-best model vary?)")
        d = r['disagreement']
        bar_d = _bar(d)
        print(f" Disagreement {d:.3f} {bar_d}")
        print()
        if d < 0.3:
            note = "Low disagreement — models make similar errors."
        elif d < 0.6:
            note = "Moderate disagreement — routing has something to work with."
        else:
            note = "High disagreement — strong routing signal available."
        print(f" {note}")

    # Validation set quality & KNN learnability
    _section("KNN learnability (can the router learn stable competence regions?)")
    vq = r['val_quality']
    fs = r['feature_score']
    kl = r['knn_learnability']
    ratio = r['n_val'] / r['k']
    bar_vq = _bar(vq)
    bar_fs = _bar(fs)
    bar_kl = _bar(kl)
    vq_status = '✓' if ratio >= 100 else ('~' if ratio >= 50 else '✗')
    fs_status = '✓' if r['n_features'] >= 12 else ('~' if r['n_features'] >= 6 else '✗')
    kl_status = '✓' if kl >= 0.5 else ('~' if kl >= 0.25 else '✗')
    print(f" n_val / k = {ratio:>5.0f} {bar_vq} {vq_status} (sample density)")
    print(f" n_features = {r['n_features']:>5d} {bar_fs} {fs_status} (distance structure)")
    print(f" learnability = {kl:.2f} {bar_kl} {kl_status} (combined)")
    print()
    if kl < 0.25:
        note = ("Low — KNN is unlikely to find stable competence regions. "
                "Consider Global Ensemble instead.")
    elif kl < 0.5:
        note = "Moderate — competence regions may be noisy. Treat results with caution."
    else:
        note = "Good — KNN has enough data and feature structure to route reliably."
    print(f" {note}")

    # Recommendation
    print(f"\n {'━' * W}")
    rec = r['recommendation']
    symbols = {
        'USE DES': '✓',
        'MAYBE — try Global Ensemble first': '~',
        'SKIP DES': '✗',
        'UNRELIABLE': '?',
    }
    symbol = symbols.get(rec, ' ')
    print(f" Recommendation: [{symbol}] {rec}")
    print(f"\n {r['reason']}")
    print(f" {'━' * W}\n")
@@ -0,0 +1,4 @@
1
+ from deskit.base.base import BaseRouter
2
+ from deskit.base.knnbase import KNNBase
3
+
4
+ __all__ = ['BaseRouter', 'KNNBase']
deskit/base/base.py ADDED
@@ -0,0 +1,11 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+
4
class BaseRouter(ABC):
    """Abstract interface shared by all deskit routers.

    Subclasses implement fit() to learn from a validation set and
    predict() to produce routed predictions for new samples.
    """

    @abstractmethod
    def fit(self, features, y, preds_dict):
        """Fit on validation features, targets, and per-model predictions."""
        pass

    @abstractmethod
    def predict(self, x):
        """Produce predictions for samples x."""
        pass
deskit/base/knnbase.py ADDED
@@ -0,0 +1,54 @@
1
+ from deskit.base.base import BaseRouter
2
+ import numpy as np
3
+
4
+
5
class KNNBase(BaseRouter):
    """
    Base for KNN-based DES algorithms.

    Holds a per-sample, per-model score matrix over the validation set plus
    a fitted neighbour index; subclasses implement predict().
    """

    def __init__(self, metric, mode='max', neighbor_finder=None):
        """
        Parameters
        ----------
        metric : callable
            Per-sample scoring function: (y_true, y_pred) -> float.
        mode : str
            'max' if higher scores are better, 'min' if lower.
        neighbor_finder : NeighborFinder
            Backend used for neighborhood queries.
        """
        self.metric = metric
        self.mode = mode
        self.model = neighbor_finder
        self.matrix = None   # (n_val, n_models); higher is always better
        self.models = None   # ordered list of model names

    def _compute_scores(self, y, preds):
        """
        Return a 1D array of per-sample metric scores.

        preds may be 1D (scalar predictions) or 2D (probability arrays,
        one row per sample).
        """
        preds = np.asarray(preds)
        if preds.ndim == 2:
            # One probability row per sample: score each pair individually.
            return np.array([self.metric(truth, row) for truth, row in zip(y, preds)])
        return np.vectorize(self.metric)(y, preds)

    def fit(self, features, y, preds_dict):
        """
        Build the score matrix and fit the neighbor index.

        This method expects pre-validated numpy arrays.
        """
        self.models = list(preds_dict.keys())
        self.matrix = np.zeros((len(y), len(self.models)))
        # Flip 'min' metrics so that higher is always better in the matrix.
        sign = 1.0 if self.mode == 'max' else -1.0
        for col, name in enumerate(self.models):
            self.matrix[:, col] = sign * self._compute_scores(y, preds_dict[name])
        self.model.fit(features)
deskit/des/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from deskit.des.knndws import KNNDWS
2
+ from deskit.des.ola import OLA
3
+ from deskit.des.knorau import KNORAU
4
+ from deskit.des.knorae import KNORAE
5
+ from deskit.des.knoraiu import KNORAIU
6
+
7
+ __all__ = ['KNNDWS', 'OLA', 'KNORAU', 'KNORAE', 'KNORAIU']