deskit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deskit/__init__.py +43 -0
- deskit/_config.py +186 -0
- deskit/analysis.py +377 -0
- deskit/base/__init__.py +4 -0
- deskit/base/base.py +11 -0
- deskit/base/knnbase.py +54 -0
- deskit/des/__init__.py +7 -0
- deskit/des/knndws.py +122 -0
- deskit/des/knorae.py +120 -0
- deskit/des/knoraiu.py +113 -0
- deskit/des/knorau.py +107 -0
- deskit/des/ola.py +85 -0
- deskit/metrics.py +75 -0
- deskit/neighbors.py +335 -0
- deskit/router.py +184 -0
- deskit/utils.py +27 -0
- deskit-0.1.0.dist-info/METADATA +289 -0
- deskit-0.1.0.dist-info/RECORD +21 -0
- deskit-0.1.0.dist-info/WHEEL +5 -0
- deskit-0.1.0.dist-info/licenses/LICENSE +21 -0
- deskit-0.1.0.dist-info/top_level.txt +1 -0
deskit/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""
|
|
2
|
+
deskit — Dynamic Ensemble Selection library.
|
|
3
|
+
|
|
4
|
+
Metrics
|
|
5
|
+
-------
|
|
6
|
+
Pass a metric name string:
|
|
7
|
+
|
|
8
|
+
KNNDWS(task='classification', metric='log_loss', mode='min')
|
|
9
|
+
|
|
10
|
+
Or import a metric function directly:
|
|
11
|
+
|
|
12
|
+
from deskit.metrics import log_loss, mae
|
|
13
|
+
|
|
14
|
+
KNNDWS(task='classification', metric=log_loss, mode='min')
|
|
15
|
+
|
|
16
|
+
Available built-in metrics:
|
|
17
|
+
Scalar predictions (pass predict() output):
|
|
18
|
+
'mae', 'mse', 'rmse', 'accuracy'
|
|
19
|
+
|
|
20
|
+
Probability predictions (pass predict_proba() output):
|
|
21
|
+
'log_loss', 'prob_correct'
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
from deskit.des.knndws import KNNDWS
|
|
25
|
+
from deskit.des.ola import OLA
|
|
26
|
+
from deskit.des.knorau import KNORAU
|
|
27
|
+
from deskit.des.knorae import KNORAE
|
|
28
|
+
from deskit.des.knoraiu import KNORAIU
|
|
29
|
+
from deskit.router import DynamicRouter
|
|
30
|
+
from deskit._config import SPEED_PRESETS, list_presets
|
|
31
|
+
from deskit.analysis import analyze
|
|
32
|
+
|
|
33
|
+
# Names exported by `from deskit import *` — the package's public API surface.
__all__ = [
    'KNNDWS',
    'OLA',
    'KNORAU',
    'KNORAE',
    'KNORAIU',
    'DynamicRouter',
    'SPEED_PRESETS',
    'list_presets',
    'analyze',
]
|
deskit/_config.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Internal helpers shared across algorithm classes.
|
|
3
|
+
Not part of the public API.
|
|
4
|
+
"""
|
|
5
|
+
import numpy as np
|
|
6
|
+
from deskit.metrics import _METRICS, _PROB_METRICS, _SCALAR_METRICS
|
|
7
|
+
from deskit.utils import to_numpy
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# ---------------------------------------------------------------------------
|
|
11
|
+
# Speed / accuracy presets
|
|
12
|
+
# ---------------------------------------------------------------------------
|
|
13
|
+
|
|
14
|
+
# Preset name -> neighbor-finder configuration. Each entry picks a backend
# ('finder': 'knn', 'faiss', or 'hnsw') plus the keyword arguments for its
# constructor; consumed by make_finder() below. Recall figures in the
# descriptions are approximate.
SPEED_PRESETS = {
    'exact': {
        'description': 'Exact nearest neighbors — slowest but 100% accurate',
        'finder': 'knn',
        'kwargs': {},
    },
    'balanced': {
        'description': 'Good balance of speed and accuracy (~98% recall)',
        'finder': 'faiss',
        'kwargs': {'index_type': 'ivf', 'n_probes': 50},
    },
    'fast': {
        'description': 'Fast queries with good accuracy (~95% recall)',
        'finder': 'faiss',
        'kwargs': {'index_type': 'ivf', 'n_probes': 30},
    },
    'turbo': {
        'description': 'Maximum speed, exact results — FAISS flat index',
        'finder': 'faiss',
        'kwargs': {'index_type': 'flat'},
    },
    'high_dim_balanced': {
        'description': 'High-dimensional data (>100D), balanced',
        'finder': 'hnsw',
        'kwargs': {'backend': 'hnswlib', 'M': 32, 'ef_construction': 400, 'ef_search': 200},
    },
    'high_dim_fast': {
        'description': 'High-dimensional data (>100D), fast',
        'finder': 'hnsw',
        'kwargs': {'backend': 'hnswlib', 'M': 16, 'ef_construction': 200, 'ef_search': 100},
    },
}
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def list_presets():
    """Print every available preset with its description and parameters."""
    rule = "=" * 70
    print("\nAvailable Speed/Accuracy Presets:")
    print(rule)
    for preset_name, cfg in SPEED_PRESETS.items():
        entry = "\n".join([
            f"\n{preset_name.upper()}",
            f" {cfg['description']}",
            f" Finder: {cfg['finder']}",
        ])
        print(entry)
        if cfg['kwargs']:
            print(f" Parameters: {cfg['kwargs']}")
    print("\n" + rule)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# ---------------------------------------------------------------------------
|
|
60
|
+
# Metric resolution
|
|
61
|
+
# ---------------------------------------------------------------------------
|
|
62
|
+
|
|
63
|
+
def resolve_metric(metric):
    """
    Normalise *metric* (string name or callable) into a (name, function) pair.

    Returns
    -------
    metric_name : str or None
        Lower-cased registry name when a string was given; None for
        callables. Downstream prediction-shape validation keys off this.
    metric_fn : callable
        The per-sample scoring function to use.
    """
    if not isinstance(metric, str):
        # Custom callables pass straight through with no registered name.
        return None, metric
    key = metric.lower()
    if key in _METRICS:
        return key, _METRICS[key]
    raise ValueError(
        f"Unknown metric '{metric}'. "
        f"Built-in options: {sorted(_METRICS)}. "
        f"Pass a callable for custom metrics."
    )
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
# ---------------------------------------------------------------------------
|
|
88
|
+
# Neighbor finder construction
|
|
89
|
+
# ---------------------------------------------------------------------------
|
|
90
|
+
|
|
91
|
+
def make_finder(preset, k, finder=None, **kwargs):
    """
    Build a NeighborFinder from a named preset or a custom backend choice.

    Parameters
    ----------
    preset : str
        A key of SPEED_PRESETS, or 'custom' to select the backend yourself.
    k : int
        Number of neighbors the finder should return.
    finder : str, optional
        Backend name ('knn', 'faiss', 'annoy', 'hnsw'). Required when
        preset='custom'.
    **kwargs
        Extra constructor arguments for the backend; these override any
        preset defaults (e.g. index_type, n_probes).
    """
    if preset == 'custom':
        if finder is None:
            raise ValueError("Must specify 'finder' when using preset='custom'.")
        finder_type = finder.lower()
        finder_kwargs = {'k': k, **kwargs}
    elif preset in SPEED_PRESETS:
        config = SPEED_PRESETS[preset]
        finder_type = config['finder']
        # Caller kwargs intentionally win over preset defaults.
        finder_kwargs = {**config['kwargs'], 'k': k, **kwargs}
        print(f"Using preset '{preset}': {config['description']}")
    else:
        raise ValueError(
            f"Unknown preset '{preset}'. "
            f"Available: {sorted(SPEED_PRESETS)}. "
            f"Or use preset='custom' with an explicit finder."
        )

    # Backends are imported lazily so optional dependencies are only pulled
    # in when actually requested.
    if finder_type == 'knn':
        from deskit.neighbors import KNNNeighborFinder as finder_cls
    elif finder_type == 'faiss':
        from deskit.neighbors import FaissNeighborFinder as finder_cls
    elif finder_type == 'annoy':
        from deskit.neighbors import AnnoyNeighborFinder as finder_cls
    elif finder_type == 'hnsw':
        from deskit.neighbors import HNSWNeighborFinder as finder_cls
    else:
        raise ValueError(f"Unknown finder '{finder_type}'.")
    return finder_cls(**finder_kwargs)
|
|
137
|
+
|
|
138
|
+
# fit() input validation
|
|
139
|
+
|
|
140
|
+
def prep_fit_inputs(features, y, preds_dict, metric_name):
    """
    Convert all fit() inputs to numpy arrays and validate consistency.

    Parameters
    ----------
    features : array-like, shape (n_samples, n_features)
        Validation features.
    y : array-like, shape (n_samples,)
        Ground-truth targets.
    preds_dict : dict[str, array-like]
        Per-model predictions, each aligned row-for-row with y.
    metric_name : str or None
        Built-in metric name (enables prediction-shape checks), or None
        when a custom callable metric is in use.

    Returns
    -------
    features, y, preds_dict — all as numpy arrays, ready for KNNBase.fit().

    Raises
    ------
    ValueError
        If preds_dict is empty, if any array's length disagrees with y, or
        if prediction dimensionality does not match the named metric.
    """
    features = to_numpy(features)
    y = to_numpy(y)
    preds_dict = {name: to_numpy(p) for name, p in preds_dict.items()}

    # Fail early with a clear message: an empty pool would otherwise surface
    # as a confusing bare StopIteration from next(iter(...)) further down.
    if not preds_dict:
        raise ValueError(
            "preds_dict is empty — predictions from at least one model are required."
        )

    n = len(y)

    if len(features) != n:
        raise ValueError(
            f"features has {len(features)} rows but y has {n} samples. "
            f"All arrays passed to fit() must have the same number of rows."
        )

    for name, preds in preds_dict.items():
        if len(preds) != n:
            raise ValueError(
                f"preds_dict['{name}'] has {len(preds)} samples but y has {n}. "
                f"Every prediction array must align row-for-row with y."
            )

    if metric_name is not None:
        # Only the first model's array is shape-checked; all arrays were
        # length-checked above and are assumed to share dimensionality.
        sample_preds = next(iter(preds_dict.values()))
        pred_ndim = np.asarray(sample_preds).ndim

        if metric_name in _PROB_METRICS and pred_ndim != 2:
            raise ValueError(
                f"Metric '{metric_name}' expects probability arrays of shape "
                f"(n_samples, n_classes), but received a {pred_ndim}D array. "
                f"Pass predict_proba() output instead of predict() output."
            )

        if metric_name in _SCALAR_METRICS and pred_ndim != 1:
            raise ValueError(
                f"Metric '{metric_name}' expects scalar predictions of shape "
                f"(n_samples,), but received a {pred_ndim}D array. "
                f"Pass predict() output, or switch to 'log_loss' / 'prob_correct' "
                f"if you intended to use class probabilities."
            )

    return features, y, preds_dict
|
deskit/analysis.py
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DES suitability analysis.
|
|
3
|
+
|
|
4
|
+
Analyses a validation set and model pool to estimate whether Dynamic Ensemble
|
|
5
|
+
Selection is likely to help, and by how much.
|
|
6
|
+
|
|
7
|
+
The function prints a formatted report and returns a dict of raw metrics for
|
|
8
|
+
programmatic use.
|
|
9
|
+
"""
|
|
10
|
+
import numpy as np
|
|
11
|
+
from deskit._config import resolve_metric, prep_fit_inputs
|
|
12
|
+
from deskit.base.knnbase import KNNBase
|
|
13
|
+
from deskit.neighbors import KNNNeighborFinder
|
|
14
|
+
|
|
15
|
+
# Internal helpers
|
|
16
|
+
|
|
17
|
+
class _AnalysisFitter(KNNBase):
    """Bare KNNBase subclass used by analyze() purely for its fit() — it
    builds the per-sample score matrix; prediction is deliberately
    unsupported."""

    def predict(self, x, temperature=None, threshold=None):
        # analyze() never routes; only the fitted score matrix is consumed.
        raise NotImplementedError("_AnalysisFitter is not used for prediction.")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _entropy(probs):
|
|
24
|
+
"""Shannon entropy of a probability distribution, normalised to [0, 1]."""
|
|
25
|
+
probs = np.asarray(probs, dtype=float)
|
|
26
|
+
probs = probs[probs > 0]
|
|
27
|
+
n = len(probs) + (1 - len(probs)) % 1 # guard for single-element
|
|
28
|
+
raw = -np.sum(probs * np.log(probs))
|
|
29
|
+
max_e = np.log(max(len(probs), 2))
|
|
30
|
+
return raw / max_e if max_e > 0 else 0.0
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _bar(value, width=30, fill='█', empty='░'):
|
|
34
|
+
"""Render a simple ASCII bar for a value in [0, 1]."""
|
|
35
|
+
n = round(value * width)
|
|
36
|
+
return fill * n + empty * (width - n)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# Public API
|
|
40
|
+
|
|
41
|
+
def analyze(features, y, preds_dict, metric, mode, k=20, verbose=True):
    """
    Analyse a validation set and model pool for DES suitability.

    Computes four complementary signals and prints an interpreted report.

    Parameters
    ----------
    features : array-like, shape (n_val, n_features)
        Validation features. Should be the same set you would pass to fit().
    y : array-like, shape (n_val,)
        Validation ground-truth labels or values.
    preds_dict : dict[str, array-like]
        Validation predictions keyed by model name.
        Shape (n_val,) for scalar metrics; (n_val, n_classes) for probability
        metrics (log_loss, prob_correct).
    metric : str or callable
        Same metric you intend to use for the DES router.
    mode : str
        'min' if lower scores are better (mae, log_loss), 'max' if higher
        (accuracy, prob_correct).
    k : int
        Neighbourhood size. Should match the k you intend to use for routing.
        Default: 20.
    verbose : bool
        If True (default), print the formatted report. If False, return the
        metrics dict silently.

    Returns
    -------
    dict with keys:
        n_val                 int    Number of validation samples.
        n_features            int    Number of features.
        n_models              int    Number of models in the pool.
        k                     int    Neighbourhood size used.
        model_scores          dict   Mean score per model (higher-is-better
                                     scale).
        best_model            str    Name of the globally best model.
        oracle_gain           float  Fractional improvement of the local oracle
                                     over the best single model. 0.0 = no gain
                                     possible; 0.15 = 15% headroom.
        estimated_gain        float  Conservative estimate of realised DES gain
                                     (2% of oracle_gain; empirical range 1-5%
                                     across benchmarks).
        local_uplift          float  Mean neighbourhood-level advantage of the
                                     locally best model over the global best,
                                     relative to the global best's mean score.
        regional_diversity    float  Entropy of model win shares across
                                     neighbourhoods, normalised to [0, 1].
                                     0 = one model wins everywhere;
                                     1 = wins perfectly distributed.
        global_best_win_share float  Fraction of neighbourhoods best_model wins.
        model_win_shares      dict   Fraction of neighbourhoods each model wins.
        disagreement          float  1 - sum(win_share^2): probability that two
                                     random neighbourhoods have different
                                     winning models. Computed for scalar (1-D)
                                     predictions only; None for probability
                                     inputs.
        val_quality           float  n_val / k, clamped to [0, 1] against a
                                     target of 100. Proxy for neighbourhood
                                     estimate reliability.
        feature_score         float  (n_features - 2) / 10, clipped to [0, 1].
        knn_learnability      float  val_quality * feature_score.
        recommendation        str    'USE DES', 'MAYBE — try Global Ensemble
                                     first', 'SKIP DES', or 'UNRELIABLE'.
        reason                str    Plain-English explanation of recommendation.
    """
    metric_name, metric_fn = resolve_metric(metric)
    features, y, preds_dict = prep_fit_inputs(features, y, preds_dict, metric_name)

    n_val, n_features = features.shape
    n_models = len(preds_dict)
    model_names = list(preds_dict.keys())

    # Build score matrix (k + 1 neighbours so the self-match can be dropped)
    finder = KNNNeighborFinder(k=k + 1)
    fitter = _AnalysisFitter(metric=metric_fn, mode=mode, neighbor_finder=finder)
    fitter.fit(features, y, preds_dict)

    matrix = fitter.matrix

    # Per-model global mean scores
    model_scores = {name: float(matrix[:, j].mean())
                    for j, name in enumerate(model_names)}
    best_model = max(model_scores, key=model_scores.get)
    best_score = model_scores[best_model]

    # Oracle gain: per-sample best model, i.e. the routing ceiling
    oracle_scores = matrix.max(axis=1)  # (n_val,)
    oracle_mean = oracle_scores.mean()

    # Gain is relative to the best single model's mean score.
    if abs(best_score) > 1e-12:
        oracle_gain = float((oracle_mean - best_score) / abs(best_score))
    else:
        oracle_gain = 0.0

    # Empirically, DES captures roughly 1-5% of oracle headroom; 2% is used
    # here as a conservative point estimate.
    estimated_gain = oracle_gain * 0.02

    # Regional diversity
    _, indices = fitter.model.kneighbors(features, k=k + 1)
    indices = indices[:, 1:]  # (n_val, k) — drop self

    neighborhood_scores = matrix[indices].mean(axis=1)  # (n_val, n_models)
    local_winners = np.argmax(neighborhood_scores, axis=1)  # (n_val,)

    # Win share: fraction of neighbourhoods each model wins
    win_counts = np.bincount(local_winners, minlength=n_models)
    win_shares = win_counts / n_val
    model_wins = {name: float(win_shares[j]) for j, name in enumerate(model_names)}

    # Regional diversity: normalised entropy of the win-share distribution
    regional_diversity = float(_entropy(win_shares))

    # Local uplift: neighbourhood-level advantage over the global best model
    global_best_idx = model_names.index(best_model)
    nbhd_best = neighborhood_scores.max(axis=1)  # (n_val,)
    nbhd_global_best = neighborhood_scores[:, global_best_idx]  # (n_val,)
    nbhd_improvement = nbhd_best - nbhd_global_best

    denom = abs(nbhd_global_best.mean())
    local_uplift = float(nbhd_improvement.mean() / denom) if denom > 1e-12 else 0.0

    # Model disagreement
    first_preds = next(iter(preds_dict.values()))
    if np.asarray(first_preds).ndim == 1:
        # 1 - Simpson index over neighbourhood winners: probability two
        # random neighbourhoods are won by different models.
        disagreement = float(1.0 - np.sum(win_shares ** 2))
    else:
        disagreement = None

    # Validation set quality
    # n_val / k >= 100 = stable neighbourhood estimates.
    val_quality = float(min(1.0, (n_val / k) / 100.0))

    # KNN learnability: density and feature-space richness combined
    feature_score = float(np.clip((n_features - 2) / 10.0, 0.0, 1.0))
    knn_learnability = val_quality * feature_score

    # Recommendation

    if val_quality < 0.2:
        recommendation = 'UNRELIABLE'
        reason = (
            f"Validation set too small relative to k "
            f"(n_val/k = {n_val/k:.0f}, recommended \u2265 100). "
            f"Reduce k or increase the validation set before drawing conclusions."
        )
    elif local_uplift < 0.01 and regional_diversity < 0.4:
        recommendation = 'SKIP DES'
        reason = (
            f"Local uplift is negligible ({local_uplift*100:.1f}%) and regional "
            f"diversity is low ({regional_diversity:.2f}). The routing signal is "
            f"too weak to improve on the best single model. Use it directly."
        )
    elif knn_learnability < 0.25 or local_uplift < 0.05 or regional_diversity < 0.45:
        weak_reasons = []
        if knn_learnability < 0.25:
            weak_reasons.append(
                f"KNN learnability is low ({knn_learnability:.2f}) — "
                f"{'the feature space is too small' if feature_score < 0.5 else 'the val set is too sparse'} "
                f"for stable competence regions"
            )
        if local_uplift < 0.05:
            weak_reasons.append(f"local uplift is modest ({local_uplift*100:.1f}%)")
        if regional_diversity < 0.45:
            weak_reasons.append(f"regional diversity is low ({regional_diversity:.2f})")
        recommendation = 'MAYBE \u2014 try Global Ensemble first'
        reason = (
            f"{'; '.join(weak_reasons).capitalize()}. "
            f"A fixed-weight global ensemble may capture most of the gain "
            f"with less variance. Run both and compare on held-out test data."
        )
    else:
        recommendation = 'USE DES'
        reason = (
            f"Local uplift ({local_uplift*100:.1f}%), regional diversity "
            f"({regional_diversity:.2f}), and KNN learnability ({knn_learnability:.2f}) "
            f"are all strong. DES is likely to improve on both the best single "
            f"model and the global ensemble."
        )

    # Assemble results
    global_best_win_share = float(model_wins[best_model])

    result = {
        'n_val': n_val,
        'n_features': n_features,
        'n_models': n_models,
        'k': k,
        'model_scores': model_scores,
        'best_model': best_model,
        'oracle_gain': oracle_gain,
        'estimated_gain': estimated_gain,
        'local_uplift': local_uplift,
        'regional_diversity': regional_diversity,
        'global_best_win_share': global_best_win_share,
        'model_win_shares': model_wins,
        'disagreement': disagreement,
        'val_quality': val_quality,
        'feature_score': feature_score,
        'knn_learnability': knn_learnability,
        'recommendation': recommendation,
        'reason': reason,
    }

    if verbose:
        _print_report(result)

    return result
|
|
241
|
+
|
|
242
|
+
# Report formatting
|
|
243
|
+
|
|
244
|
+
def _print_report(r):
    """Print the formatted DES-suitability report for an analyze() result dict.

    Fix: the oracle-gain interpretation `note` was computed by the if/elif
    chain but never printed (it was silently overwritten by the regional
    diversity section); it is now printed like every other section's note.
    """
    W = 72

    def _section(title):
        # Section header: title plus a light horizontal rule.
        print(f"\n {title}")
        print(f" {'─' * (W - 4)}")

    print(f"\n {'━' * W}")
    print(f" DES Suitability Analysis")
    print(f" {'━' * W}")
    print(f" {r['n_val']:,} val samples · {r['n_features']} features · "
          f"{r['n_models']} models · k = {r['k']}")

    # Model scores
    _section("Model scores (higher-is-better scale)")
    scores = list(r['model_scores'].values())
    s_min = min(scores)
    s_max = max(scores)
    s_range = s_max - s_min
    for name, score in r['model_scores'].items():
        marker = ' ← best' if name == r['best_model'] else ''
        # Bar shows each model's relative position in the pool score range.
        bar_val = (score - s_min) / s_range if s_range > 0 else 1.0
        bar = _bar(bar_val)
        print(f" {name:<22} {score:+.4f} {bar}{marker}")

    # Oracle gain
    _section("Oracle / local uplift (headroom and realisable gain)")
    og = r['oracle_gain']
    lu = r['local_uplift']
    eg = r['estimated_gain']
    # Bars are scaled so 20% oracle gain / 10% local uplift fill the bar.
    bar_og = _bar(min(og / 0.20, 1.0))
    bar_lu = _bar(min(lu / 0.10, 1.0))
    print(f" Oracle gain {og*100:+6.2f}% {bar_og}")
    print(f" Local uplift {lu*100:+6.2f}% {bar_lu}")
    print(f" Estimated DES {eg*100:+6.2f}% (≤2% of oracle; empirical range 1–5%)")
    print()
    if og < 0.02:
        note = "Very little headroom — models already agree on most samples."
    elif og < 0.05:
        note = "Moderate headroom — DES may help, depends on regional structure."
    elif og < 0.12:
        note = "Good headroom — DES has meaningful potential."
    else:
        note = "Large headroom — strong case for DES."
    # FIX: previously computed but never printed.
    print(f" {note}")
    print()
    print(f" Oracle gain is the per-sample ceiling. Local uplift is the")
    print(f" neighbourhood-level routing advantage that must generalise to")
    print(f" test data. Low local uplift with high oracle gain usually means")
    print(f" the routing signal is noisy (val set too small or too few features).")

    # Regional diversity
    _section("Regional diversity (do different models win in different areas?)")
    rd = r['regional_diversity']
    bar_rd = _bar(rd)
    print(f" Diversity score {rd:.3f} {bar_rd}")
    print()
    print(f" {'Model':<22} {'Win share':>10} {'Neighbourhoods'}")
    print(f" {'─'*22} {'─'*10} {'─'*14}")
    for name, share in sorted(r['model_win_shares'].items(),
                              key=lambda x: -x[1]):
        bar = _bar(share, width=20)
        print(f" {name:<22} {share*100:>9.1f}% {bar}")
    print()
    if rd < 0.35:
        note = "Low diversity — one model dominates most regions."
    elif rd < 0.65:
        note = "Moderate diversity — some regional structure present."
    else:
        note = "High diversity — models have distinct regional strengths."
    print(f" {note}")
    # Warn when the local neighbourhood winner differs from the global best.
    gbws = r['global_best_win_share']
    local_leader = max(r['model_win_shares'], key=r['model_win_shares'].get)
    if local_leader != r['best_model']:
        share_leader = r['model_win_shares'][local_leader] * 100
        print()
        print(f" ⚠ Local winner '{local_leader}' ({share_leader:.0f}% of regions)")
        print(f" differs from global best '{r['best_model']}' ({gbws*100:.0f}%).")
        print(f" A weaker model dominating locally often signals noisy")
        print(f" neighbourhood estimates. Cross-check local uplift.")

    # Disagreement (only available for scalar predictions)
    if r['disagreement'] is not None:
        _section("Model disagreement (how often does the locally-best model vary?)")
        d = r['disagreement']
        bar_d = _bar(d)
        print(f" Disagreement {d:.3f} {bar_d}")
        print()
        if d < 0.3:
            note = "Low disagreement — models make similar errors."
        elif d < 0.6:
            note = "Moderate disagreement — routing has something to work with."
        else:
            note = "High disagreement — strong routing signal available."
        print(f" {note}")

    # Validation set quality & KNN learnability
    _section("KNN learnability (can the router learn stable competence regions?)")
    vq = r['val_quality']
    fs = r['feature_score']
    kl = r['knn_learnability']
    ratio = r['n_val'] / r['k']
    bar_vq = _bar(vq)
    bar_fs = _bar(fs)
    bar_kl = _bar(kl)
    vq_status = '✓' if ratio >= 100 else ('~' if ratio >= 50 else '✗')
    fs_status = '✓' if r['n_features'] >= 12 else ('~' if r['n_features'] >= 6 else '✗')
    kl_status = '✓' if kl >= 0.5 else ('~' if kl >= 0.25 else '✗')
    print(f" n_val / k = {ratio:>5.0f} {bar_vq} {vq_status} (sample density)")
    print(f" n_features = {r['n_features']:>5d} {bar_fs} {fs_status} (distance structure)")
    print(f" learnability = {kl:.2f} {bar_kl} {kl_status} (combined)")
    print()
    if kl < 0.25:
        note = ("Low — KNN is unlikely to find stable competence regions. "
                "Consider Global Ensemble instead.")
    elif kl < 0.5:
        note = "Moderate — competence regions may be noisy. Treat results with caution."
    else:
        note = "Good — KNN has enough data and feature structure to route reliably."
    print(f" {note}")

    # Recommendation
    print(f"\n {'━' * W}")
    rec = r['recommendation']
    symbols = {
        'USE DES': '✓',
        'MAYBE — try Global Ensemble first': '~',
        'SKIP DES': '✗',
        'UNRELIABLE': '?',
    }
    symbol = symbols.get(rec, ' ')
    print(f" Recommendation: [{symbol}] {rec}")
    print(f"\n {r['reason']}")
    print(f" {'━' * W}\n")
|
deskit/base/__init__.py
ADDED
deskit/base/base.py
ADDED
deskit/base/knnbase.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from deskit.base.base import BaseRouter
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class KNNBase(BaseRouter):
    """
    Shared machinery for KNN-based DES algorithms: scores every validation
    sample under every model and fits the neighbor-search backend over the
    validation features.
    """

    def __init__(self, metric, mode='max', neighbor_finder=None):
        """
        Parameters
        ----------
        metric : callable
            Per-sample scoring function: (y_true, y_pred) -> float.
        mode : str
            'max' if higher scores are better, 'min' if lower.
        neighbor_finder : NeighborFinder
            Backend used for neighborhood queries.
        """
        self.metric = metric
        self.mode = mode
        self.model = neighbor_finder
        # Populated by fit():
        self.matrix = None  # (n_val, n_models); stored so higher is always better
        self.models = None  # model names, in preds_dict insertion order

    def _compute_scores(self, y, preds):
        """
        Apply self.metric per sample and return a 1D score array.

        preds may be 1D (scalar predictions) or 2D (one probability row per
        sample).
        """
        preds = np.asarray(preds)
        if preds.ndim != 2:
            # Scalar predictions: vectorize the metric elementwise.
            return np.vectorize(self.metric)(y, preds)
        # Probability rows must be passed whole; score one sample at a time.
        return np.array([self.metric(y[pos], preds[pos]) for pos in range(len(y))])

    def fit(self, features, y, preds_dict):
        """
        Build the score matrix and fit the neighbor index.

        Expects pre-validated numpy arrays.
        """
        self.models = list(preds_dict.keys())
        # Negate scores in 'min' mode so the matrix is uniformly
        # higher-is-better.
        sign = 1.0 if self.mode == 'max' else -1.0
        self.matrix = np.zeros((len(y), len(self.models)))
        for col, name in enumerate(self.models):
            self.matrix[:, col] = sign * self._compute_scores(y, preds_dict[name])
        self.model.fit(features)
|
deskit/des/__init__.py
ADDED