datasci-toolkit 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. datasci_toolkit-0.1.0/.github/workflows/docs.yml +48 -0
  2. datasci_toolkit-0.1.0/.github/workflows/publish.yml +24 -0
  3. datasci_toolkit-0.1.0/.gitignore +83 -0
  4. datasci_toolkit-0.1.0/.python-version +1 -0
  5. datasci_toolkit-0.1.0/PKG-INFO +13 -0
  6. datasci_toolkit-0.1.0/README.md +67 -0
  7. datasci_toolkit-0.1.0/datasci_toolkit/__init__.py +33 -0
  8. datasci_toolkit-0.1.0/datasci_toolkit/bin_editor.py +373 -0
  9. datasci_toolkit-0.1.0/datasci_toolkit/bin_editor_widget.py +309 -0
  10. datasci_toolkit-0.1.0/datasci_toolkit/grouping.py +474 -0
  11. datasci_toolkit-0.1.0/datasci_toolkit/label_imputation.py +145 -0
  12. datasci_toolkit-0.1.0/datasci_toolkit/metrics.py +214 -0
  13. datasci_toolkit-0.1.0/datasci_toolkit/model_selection.py +418 -0
  14. datasci_toolkit-0.1.0/datasci_toolkit/stability.py +263 -0
  15. datasci_toolkit-0.1.0/datasci_toolkit/variable_clustering.py +109 -0
  16. datasci_toolkit-0.1.0/docs/api/bin_editor.md +9 -0
  17. datasci_toolkit-0.1.0/docs/api/grouping.md +9 -0
  18. datasci_toolkit-0.1.0/docs/api/label_imputation.md +9 -0
  19. datasci_toolkit-0.1.0/docs/api/metrics.md +39 -0
  20. datasci_toolkit-0.1.0/docs/api/model_selection.md +5 -0
  21. datasci_toolkit-0.1.0/docs/api/stability.md +21 -0
  22. datasci_toolkit-0.1.0/docs/api/variable_clustering.md +5 -0
  23. datasci_toolkit-0.1.0/docs/index.md +45 -0
  24. datasci_toolkit-0.1.0/docs/tutorials/bin_editor.md +123 -0
  25. datasci_toolkit-0.1.0/docs/tutorials/grouping.md +128 -0
  26. datasci_toolkit-0.1.0/docs/tutorials/label_imputation.md +103 -0
  27. datasci_toolkit-0.1.0/docs/tutorials/metrics.md +108 -0
  28. datasci_toolkit-0.1.0/docs/tutorials/model_selection.md +106 -0
  29. datasci_toolkit-0.1.0/docs/tutorials/stability.md +136 -0
  30. datasci_toolkit-0.1.0/docs/tutorials/variable_clustering.md +113 -0
  31. datasci_toolkit-0.1.0/examples/01_stability.py +162 -0
  32. datasci_toolkit-0.1.0/examples/02_grouping.py +160 -0
  33. datasci_toolkit-0.1.0/examples/03_metrics.py +161 -0
  34. datasci_toolkit-0.1.0/examples/04_model_selection.py +184 -0
  35. datasci_toolkit-0.1.0/examples/05_label_imputation.py +169 -0
  36. datasci_toolkit-0.1.0/examples/06_bin_editor.py +205 -0
  37. datasci_toolkit-0.1.0/examples/07_variable_clustering.py +186 -0
  38. datasci_toolkit-0.1.0/mkdocs.yml +62 -0
  39. datasci_toolkit-0.1.0/pyproject.toml +43 -0
  40. datasci_toolkit-0.1.0/scripts/gen_index.py +44 -0
  41. datasci_toolkit-0.1.0/tests/__init__.py +0 -0
  42. datasci_toolkit-0.1.0/tests/test_bin_editor.py +440 -0
  43. datasci_toolkit-0.1.0/tests/test_bin_editor_widget.py +255 -0
  44. datasci_toolkit-0.1.0/tests/test_grouping.py +377 -0
  45. datasci_toolkit-0.1.0/tests/test_label_imputation.py +280 -0
  46. datasci_toolkit-0.1.0/tests/test_metrics.py +314 -0
  47. datasci_toolkit-0.1.0/tests/test_model_selection.py +258 -0
  48. datasci_toolkit-0.1.0/tests/test_stability.py +258 -0
  49. datasci_toolkit-0.1.0/tests/test_variable_clustering.py +161 -0
  50. datasci_toolkit-0.1.0/uv.lock +2263 -0
@@ -0,0 +1,48 @@
1
+ name: Build and deploy docs
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ workflow_dispatch:
7
+
8
+ permissions:
9
+ contents: read
10
+ pages: write
11
+ id-token: write
12
+
13
+ concurrency:
14
+ group: pages
15
+ cancel-in-progress: false
16
+
17
+ jobs:
18
+ build:
19
+ runs-on: ubuntu-latest
20
+ steps:
21
+ - uses: actions/checkout@v4
22
+
23
+ - uses: actions/setup-python@v5
24
+ with:
25
+ python-version: "3.12"
26
+
27
+ - name: Install uv
28
+ run: pip install uv
29
+
30
+ - name: Install dependencies
31
+ run: uv sync --group dev
32
+
33
+ - name: Build docs
34
+ run: uv run mkdocs build --strict --site-dir _site
35
+
36
+ - uses: actions/upload-pages-artifact@v3
37
+ with:
38
+ path: _site
39
+
40
+ deploy:
41
+ needs: build
42
+ runs-on: ubuntu-latest
43
+ environment:
44
+ name: github-pages
45
+ url: ${{ steps.deployment.outputs.page_url }}
46
+ steps:
47
+ - uses: actions/deploy-pages@v4
48
+ id: deployment
@@ -0,0 +1,24 @@
1
+ name: Publish to PyPI
2
+
3
+ on:
4
+ push:
5
+ tags:
6
+ - "v*"
7
+
8
+ jobs:
9
+ publish:
10
+ runs-on: ubuntu-latest
11
+ environment: pypi
12
+ permissions:
13
+ id-token: write
14
+
15
+ steps:
16
+ - uses: actions/checkout@v4
17
+
18
+ - uses: astral-sh/setup-uv@v5
19
+
20
+ - name: Build
21
+ run: uv build
22
+
23
+ - name: Publish
24
+ uses: pypa/gh-action-pypi-publish@release/v1
@@ -0,0 +1,83 @@
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *.pyo
5
+ *.pyd
6
+ *.so
7
+ *.egg
8
+ *.egg-info/
9
+ dist/
10
+ build/
11
+ eggs/
12
+ parts/
13
+ var/
14
+ sdist/
15
+ wheels/
16
+ *.egg-link
17
+ .installed.cfg
18
+ lib/
19
+ lib64/
20
+
21
+ # Virtual environments
22
+ .venv/
23
+ venv/
24
+ env/
25
+ ENV/
26
+
27
+ # Jupyter
28
+ .ipynb_checkpoints/
29
+ *.ipynb_checkpoints
30
+
31
+ # Distribution / packaging
32
+ MANIFEST
33
+ pip-log.txt
34
+ pip-delete-this-directory.txt
35
+
36
+ # Testing
37
+ .tox/
38
+ .coverage
39
+ .coverage.*
40
+ .cache
41
+ .pytest_cache/
42
+ htmlcov/
43
+ nosetests.xml
44
+ coverage.xml
45
+ *.cover
46
+
47
+ # Type checking
48
+ .mypy_cache/
49
+ .dmypy.json
50
+ dmypy.json
51
+ .pytype/
52
+
53
+ # IDEs
54
+ .idea/
55
+ .vscode/
56
+ *.swp
57
+ *.swo
58
+ *~
59
+
60
+ # OS
61
+ .DS_Store
62
+ Thumbs.db
63
+
64
+ # Environment variables
65
+ .env
66
+ .env.*
67
+
68
+ # Data
69
+ *.csv
70
+ *.parquet
71
+ *.feather
72
+ *.pkl
73
+ *.pickle
74
+ *.h5
75
+ *.hdf5
76
+ data/
77
+
78
+ # Source originals and extraction notes
79
+ src/
80
+ EXTRACTABLE_COMPONENTS.md
81
+ CLAUDE.md
82
+ examples/__marimo__
83
+ _site/
@@ -0,0 +1 @@
1
+ 3.12
@@ -0,0 +1,13 @@
1
+ Metadata-Version: 2.4
2
+ Name: datasci-toolkit
3
+ Version: 0.1.0
4
+ Summary: Add your description here
5
+ Requires-Python: >=3.12
6
+ Requires-Dist: anywidget>=0.9.21
7
+ Requires-Dist: lightgbm>=4.6.0
8
+ Requires-Dist: matplotlib>=3.10.8
9
+ Requires-Dist: numpy>=2.4.3
10
+ Requires-Dist: optbinning>=0.21.0
11
+ Requires-Dist: polars>=1.39.3
12
+ Requires-Dist: scikit-learn>=1.8.0
13
+ Requires-Dist: scipy>=1.17.1
@@ -0,0 +1,67 @@
1
+ # datasci-toolkit
2
+
3
+ My personal Python toolkit for data science — a clean rewrite of tools I use day-to-day for binary classification, scorecard development, and model validation.
4
+
5
+ Polars-native, sklearn-compatible, zero external state.
6
+
7
+ ## Modules
8
+
9
+ | Module | Classes / Functions | Description |
10
+ |---|---|---|
11
+ | `stability` | `PSI`, `ESI`, `StabilityMonitor` | Population and event stability indices |
12
+ | `grouping` | `StabilityGrouping`, `WOETransformer` | Stability-constrained optimal binning and WOE encoding |
13
+ | `metrics` | `gini`, `ks`, `lift`, `iv`, `BootstrapGini`, `feature_power`, `gini_by_period`, `lift_by_period`, `plot_metric_by_period` | Binary classification metrics with period breakdowns |
14
+ | `model_selection` | `AUCStepwiseLogit` | Gini-based stepwise logistic regression |
15
+ | `label_imputation` | `KNNLabelImputer`, `TargetImputer` | KNN imputation for records with missing labels |
16
+ | `bin_editor` | `BinEditor`, `BinEditorWidget` | Headless and interactive bin boundary editor |
17
+ | `variable_clustering` | `CorrVarClus` | Hierarchical correlation clustering for variable reduction |
18
+
19
+ ## Quick start
20
+
21
+ ```python
22
+ import polars as pl
23
+ from datasci_toolkit import StabilityGrouping, AUCStepwiseLogit, CorrVarClus
24
+
25
+ # 1. Stability-constrained binning
26
+ sg = StabilityGrouping(stability_threshold=0.1).fit(
27
+ X_train, y_train, t_train=month_train,
28
+ X_val=X_val, y_val=y_val, t_val=month_val,
29
+ )
30
+ X_woe = sg.transform(X_test)
31
+
32
+ # 2. Remove correlated features
33
+ cc = CorrVarClus(max_correlation=0.5).fit(X_woe, y_train)
34
+ features = cc.best_features()
35
+
36
+ # 3. Stepwise selection
37
+ model = AUCStepwiseLogit(max_predictors=10, max_correlation=0.8).fit(
38
+ X_woe.select(features), y_train,
39
+ X_val=X_val_woe.select(features), y_val=y_val,
40
+ )
41
+ ```
42
+
43
+ ## Install
44
+
45
+ ```bash
46
+ pip install datasci-toolkit
47
+ ```
48
+
49
+ ## Documentation
50
+
51
+ **[detrin.github.io/datasci-toolkit](https://detrin.github.io/datasci-toolkit)**
52
+
53
+ | Notebook | Topic |
54
+ |---|---|
55
+ | [01 Stability](https://detrin.github.io/datasci-toolkit/tutorials/stability/) | PSI drift detection, StabilityMonitor, ESI |
56
+ | [02 Grouping](https://detrin.github.io/datasci-toolkit/tutorials/grouping/) | StabilityGrouping, WOETransformer |
57
+ | [03 Metrics](https://detrin.github.io/datasci-toolkit/tutorials/metrics/) | Gini, KS, lift, IV, bootstrap CI, period breakdowns |
58
+ | [04 Model selection](https://detrin.github.io/datasci-toolkit/tutorials/model_selection/) | AUCStepwiseLogit, correlation filter, CV mode |
59
+ | [05 Label imputation](https://detrin.github.io/datasci-toolkit/tutorials/label_imputation/) | KNNLabelImputer, TargetImputer |
60
+ | [06 Bin editor](https://detrin.github.io/datasci-toolkit/tutorials/bin_editor/) | BinEditor headless API, BinEditorWidget |
61
+ | [07 Variable clustering](https://detrin.github.io/datasci-toolkit/tutorials/variable_clustering/) | CorrVarClus dendrogram, best_features |
62
+
63
+ ## Stack
64
+
65
+ - Python 3.12, `polars` — no pandas
66
+ - `scikit-learn` for estimator conventions
67
+ - `matplotlib` for standalone plot functions
@@ -0,0 +1,33 @@
1
+ from datasci_toolkit.bin_editor import BinEditor
2
+ from datasci_toolkit.bin_editor_widget import BinEditorWidget
3
+ from datasci_toolkit.grouping import StabilityGrouping, WOETransformer
4
+ from datasci_toolkit.metrics import BootstrapGini, feature_power, gini, gini_by_period, iv, ks, lift, lift_by_period, plot_metric_by_period
5
+ from datasci_toolkit.model_selection import AUCStepwiseLogit
6
+ from datasci_toolkit.variable_clustering import CorrVarClus
7
+ from datasci_toolkit.label_imputation import KNNLabelImputer, TargetImputer
8
+ from datasci_toolkit.stability import ESI, PSI, StabilityMonitor, plot_psi_comparison, psi_hist
9
+
10
+ # Explicit public API: the names re-exported by `from datasci_toolkit import *`.
+ __all__ = [
11
+ "PSI",
12
+ "ESI",
13
+ "StabilityMonitor",
14
+ "plot_psi_comparison",
15
+ "psi_hist",
16
+ "WOETransformer",
17
+ "StabilityGrouping",
18
+ "AUCStepwiseLogit",
19
+ "gini",
20
+ "ks",
21
+ "lift",
22
+ "iv",
23
+ "BootstrapGini",
24
+ "feature_power",
25
+ "TargetImputer",
26
+ "KNNLabelImputer",
27
+ "BinEditor",
28
+ "BinEditorWidget",
29
+ "CorrVarClus",
30
+ "gini_by_period",
31
+ "lift_by_period",
32
+ "plot_metric_by_period",
33
+ ]
@@ -0,0 +1,373 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ from dataclasses import dataclass
5
+ from enum import Enum
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ import polars as pl
10
+
11
+ from datasci_toolkit.grouping import _rsi
12
+
13
+ # Additive smoothing constant applied to event/non-event counts in _bin_stats.
+ _SMOOTH = 0.5
14
+
15
+
16
class FeatureDtype(str, Enum):
    """Kind of feature handled by the bin editor.

    The values match the ``dtype`` strings stored in bin-spec dicts, and the
    ``str`` mixin lets members compare equal to those raw strings.
    """

    NUMERIC = "float"
    CATEGORICAL = "category"
19
+
20
+
21
@dataclass(frozen=True)
class BinStats:
    """Immutable per-bin statistics for one binning of a feature.

    Arrays have length ``n_bins + 1``; the trailing entry is the NaN/unknown bin.
    """

    counts: np.ndarray       # weighted observation count per bin
    event_rates: np.ndarray  # weighted event rate per bin (NaN for empty bins)
    woe: np.ndarray          # smoothed weight-of-evidence per bin
    iv: float                # information value over the regular (non-NaN) bins
27
+
28
+
29
+ @dataclass(frozen=True)
30
+ class TemporalStats:
31
+ months: list[Any]
32
+ rsi: float
33
+ event_rates: list[list[float | None]]
34
+ pop_shares: list[list[float]]
35
+
36
+
37
@dataclass
class FeatureState:
    """Mutable snapshot of one feature's current binning inside the editor."""

    feature: str                              # feature (column) name
    dtype: FeatureDtype                       # NUMERIC or CATEGORICAL
    n_bins: int                               # number of regular (non-NaN) bins
    bins: list[str] | dict[str, int]          # labels (numeric) or category->group map
    counts: list[float]                       # weighted count per bin, incl. NaN bin
    event_rates: list[float | None]           # event rate per bin; None where empty
    woe: list[float]                          # weight-of-evidence per bin
    iv: float                                 # information value of this binning
    splits: list[float] | None = None         # numeric split points (numeric only)
    groups: dict[int, list[str]] | None = None  # group index -> member categories
    temporal: TemporalStats | None = None     # set when time periods are available
50
+
51
+
52
def _bin_stats(target: np.ndarray, weights: np.ndarray, assignments: np.ndarray, n_bins: int) -> BinStats:
    """Compute weighted counts, event rates, WOE, and IV for one bin assignment.

    Args:
        target: Binary target values as floats (0/1).
        weights: Sample weights, same length as ``target``.
        assignments: Bin index per row; index ``n_bins`` is the NaN/unknown bin.
        n_bins: Number of regular bins (the NaN bin is one extra slot).

    Returns:
        ``BinStats`` whose ``counts``/``event_rates``/``woe`` arrays have length
        ``n_bins + 1`` (last entry is the NaN bin); ``iv`` covers the regular
        bins only.
    """
    weighted_events = target * weights
    total_events = float(weighted_events.sum())
    total_nonevents = float(((1.0 - target) * weights).sum())

    counts = np.bincount(assignments, weights=weights, minlength=n_bins + 1).astype(float)
    events = np.bincount(assignments, weights=weighted_events, minlength=n_bins + 1).astype(float)
    nonevents = counts - events
    # Fix: divide only where counts > 0 so empty bins don't raise a
    # RuntimeWarning (np.where evaluated `events / counts` eagerly).
    event_rates = np.divide(events, counts, out=np.full_like(counts, np.nan), where=counts > 0)

    # Additively smoothed event/non-event distributions over the regular bins.
    event_dist = (events[:n_bins] + _SMOOTH) / (total_events + _SMOOTH * n_bins)
    nonevent_dist = (nonevents[:n_bins] + _SMOOTH) / (total_nonevents + _SMOOTH * n_bins)
    woe_regular = np.log(event_dist / nonevent_dist)
    iv = float(((event_dist - nonevent_dist) * woe_regular).sum())

    # The NaN bin gets its own smoothed WOE but does not contribute to IV.
    nan_event_dist = (events[n_bins] + _SMOOTH) / (total_events + _SMOOTH)
    nan_nonevent_dist = (nonevents[n_bins] + _SMOOTH) / (total_nonevents + _SMOOTH)
    nan_woe = float(np.log(nan_event_dist / nan_nonevent_dist))

    return BinStats(counts=counts, event_rates=event_rates, woe=np.append(woe_regular, nan_woe), iv=iv)
72
+
73
+
74
def _temporal_stats(
    target: np.ndarray,
    weights: np.ndarray,
    assignments: np.ndarray,
    n_bins: int,
    time_periods: np.ndarray,
    threshold: float,
) -> TemporalStats:
    """Compute per-period event rates, population shares, and RSI for one binning.

    Args:
        target: Binary target values as floats.
        weights: Sample weights.
        assignments: Bin index per row (NaN bin at index ``n_bins``).
        n_bins: Number of regular bins.
        time_periods: Period label per row (e.g. month).
        threshold: RSI threshold forwarded to ``_rsi``.
    """
    months = np.sort(np.unique(time_periods))
    rates_by_bin: list[list[float | None]] = [[] for _ in range(n_bins)]
    shares_by_bin: list[list[float]] = [[] for _ in range(n_bins)]

    for month in months:
        in_month = time_periods == month
        month_stats = _bin_stats(target[in_month], weights[in_month], assignments[in_month], n_bins)
        # Shares are relative to the regular bins only; guard an all-empty month.
        month_total = float(month_stats.counts[:n_bins].sum()) or 1.0
        for b in range(n_bins):
            rate = month_stats.event_rates[b]
            rates_by_bin[b].append(None if np.isnan(rate) else round(float(rate), 6))
            shares_by_bin[b].append(round(float(month_stats.counts[b] / month_total), 6))

    # Flatten every observed (bin index, rate, month) triple for the RSI input.
    flat_scores: list[float] = []
    flat_rates: list[float] = []
    flat_months: list[Any] = []
    for month_pos, month in enumerate(months):
        for b in range(n_bins):
            rate = rates_by_bin[b][month_pos]
            if rate is not None:
                flat_scores.append(float(b))
                flat_rates.append(rate)
                flat_months.append(month)

    if len(flat_scores) > 1:
        rsi = _rsi(np.array(flat_scores), np.array(flat_rates), np.array(flat_months), threshold)
    else:
        rsi = 1.0

    return TemporalStats(
        months=months.tolist(),
        rsi=round(rsi, 4),
        event_rates=rates_by_bin,
        pop_shares=shares_by_bin,
    )
114
+
115
+
116
+ def _num_assign(values: np.ndarray, splits: list[float]) -> np.ndarray:
117
+ missing_mask = np.isnan(values)
118
+ assignments = np.digitize(values, splits)
119
+ assignments[missing_mask] = len(splits) + 1
120
+ return assignments
121
+
122
+
123
+ def _cat_assign(values: np.ndarray, category_bins: dict[str, int]) -> np.ndarray:
124
+ n_groups = max(category_bins.values()) + 1 if category_bins else 0
125
+ assignments = np.full(len(values), n_groups, dtype=np.intp)
126
+ for category, group in category_bins.items():
127
+ assignments[values == category] = group
128
+ return assignments
129
+
130
+
131
+ def _num_labels(splits: list[float]) -> list[str]:
132
+ if not splits:
133
+ return ["-inf to inf", "NaN"]
134
+ split_strs = [f"{v:.4g}" for v in splits]
135
+ return [f"-inf to {split_strs[0]}"] + [f"{split_strs[i]} to {split_strs[i+1]}" for i in range(len(split_strs) - 1)] + [f"{split_strs[-1]} to inf", "NaN"]
136
+
137
+
138
def _num_state(feat: str, splits: list[float], values: np.ndarray, target: np.ndarray, weights: np.ndarray) -> FeatureState:
    """Build the full ``FeatureState`` for a numeric feature under ``splits``."""
    n_bins = len(splits) + 1
    stats = _bin_stats(target, weights, _num_assign(values, splits), n_bins)
    rounded_rates = [None if np.isnan(rate) else round(float(rate), 6) for rate in stats.event_rates]
    return FeatureState(
        feature=feat,
        dtype=FeatureDtype.NUMERIC,
        n_bins=n_bins,
        splits=list(splits),
        bins=_num_labels(splits),
        counts=stats.counts.tolist(),
        event_rates=rounded_rates,
        woe=[round(float(w), 6) for w in stats.woe],
        iv=round(stats.iv, 6),
    )
152
+
153
+
154
def _cat_state(feat: str, category_bins: dict[str, int], values: np.ndarray, target: np.ndarray, weights: np.ndarray) -> FeatureState:
    """Build the full ``FeatureState`` for a categorical feature under ``category_bins``."""
    n_groups = max(category_bins.values()) + 1 if category_bins else 0
    stats = _bin_stats(target, weights, _cat_assign(values, category_bins), n_groups)
    # Invert the category -> group map into group -> sorted member categories.
    members: dict[int, list[str]] = {}
    for category, group_idx in category_bins.items():
        members.setdefault(group_idx, []).append(str(category))
    return FeatureState(
        feature=feat,
        dtype=FeatureDtype.CATEGORICAL,
        n_bins=n_groups,
        groups={g: sorted(cats) for g, cats in members.items()},
        bins=dict(category_bins),
        counts=stats.counts.tolist(),
        event_rates=[None if np.isnan(rate) else round(float(rate), 6) for rate in stats.event_rates],
        woe=[round(float(w), 6) for w in stats.woe],
        iv=round(stats.iv, 6),
    )
171
+
172
+
173
class BinEditor:
    """Headless editor for bin boundaries with per-feature undo history.

    Behaves identically in plain scripts, notebooks, and agents: every edit
    method returns the refreshed `FeatureState`, edits are recorded per feature
    and reversible via `undo()`, and `accept()` exports a bin-specs dict
    consumable by `WOETransformer`.

    Args:
        bin_specs: Initial bin specifications — a dict produced by
            `StabilityGrouping.bin_specs_` or built manually.
        features: Feature DataFrame containing the columns named in ``bin_specs``.
        target: Binary target series (0/1 or float).
        time_periods: Optional time series enabling temporal stability metrics.
        weights: Optional sample weight series (defaults to all ones).
        stability_threshold: RSI threshold used to flag unstable bins in the
            state dict; it never blocks an edit.

    Note:
        `state(feat)` returns a `FeatureState` dataclass exposing ``bins``,
        ``n_bins``, ``counts``, ``event_rates``, ``woe``, ``iv``, ``dtype``,
        ``groups``, and ``temporal``.
    """

    def __init__(
        self,
        bin_specs: dict[str, dict[str, Any]],
        features: pl.DataFrame,
        target: pl.Series,
        time_periods: pl.Series | None = None,
        weights: pl.Series | None = None,
        stability_threshold: float = 0.1,
    ) -> None:
        self._targets = target.cast(pl.Float64).to_numpy()
        if weights is not None:
            self._weights = weights.cast(pl.Float64).to_numpy()
        else:
            self._weights = np.ones(len(self._targets))
        self._time: np.ndarray | None = None if time_periods is None else time_periods.to_numpy()
        self._threshold = stability_threshold
        self._x: dict[str, np.ndarray] = {}
        self._splits: dict[str, list[float]] = {}
        self._cat_bins: dict[str, dict[str, int]] = {}
        self._history: dict[str, list[tuple[str, Any]]] = {}
        self._orig: dict[str, dict[str, Any]] = {}

        for feat, spec in bin_specs.items():
            # Specs for columns absent from the DataFrame are silently skipped.
            if feat not in features.columns:
                continue
            self._orig[feat] = spec
            self._history[feat] = []
            if spec["dtype"] == FeatureDtype.NUMERIC:
                self._x[feat] = features[feat].cast(pl.Float64).to_numpy()
                # Keep only the finite interior edges (drop the +/-inf sentinels).
                self._splits[feat] = [float(edge) for edge in spec["bins"][1:-1] if np.isfinite(edge)]
            else:
                self._x[feat] = features[feat].cast(pl.Utf8).to_numpy().astype(str)
                self._cat_bins[feat] = {str(cat): int(grp) for cat, grp in spec["bins"].items()}

    def features(self) -> list[str]:
        """Return all editable feature names (numeric features first)."""
        return [*self._splits, *self._cat_bins]

    def _base_state(self, feat: str) -> FeatureState:
        # Dispatch on feature kind; numeric features live in ``_splits``.
        if feat in self._splits:
            return _num_state(feat, self._splits[feat], self._x[feat], self._targets, self._weights)
        return _cat_state(feat, self._cat_bins[feat], self._x[feat], self._targets, self._weights)

    def _assignments(self, feat: str) -> np.ndarray:
        # Row-level bin indices under the feature's current binning.
        if feat in self._splits:
            return _num_assign(self._x[feat], self._splits[feat])
        return _cat_assign(self._x[feat], self._cat_bins[feat])

    def state(self, feat: str) -> FeatureState:
        """Return the feature's current state, with temporal stats when time data exists."""
        snapshot = self._base_state(feat)
        if self._time is not None:
            snapshot.temporal = _temporal_stats(
                self._targets, self._weights, self._assignments(feat), snapshot.n_bins, self._time, self._threshold
            )
        return snapshot

    def _push(self, feat: str) -> None:
        # Record the pre-edit binning so ``undo`` can restore it.
        if feat in self._splits:
            self._history[feat].append(("splits", list(self._splits[feat])))
        else:
            self._history[feat].append(("cat", copy.deepcopy(self._cat_bins[feat])))

    def split(self, feat: str, value: float) -> FeatureState:
        """Insert a new split point for a numeric feature (no-op if already present)."""
        if value not in self._splits[feat]:
            self._push(feat)
            self._splits[feat] = sorted(self._splits[feat] + [value])
        return self.state(feat)

    def merge(self, feat: str, bin_idx: int) -> FeatureState:
        """Merge bin ``bin_idx`` with its right neighbour (no-op when out of range)."""
        if feat in self._splits:
            current = self._splits[feat]
            if bin_idx < len(current):
                self._push(feat)
                # Dropping split ``bin_idx`` fuses the bins on either side of it.
                self._splits[feat] = current[:bin_idx] + current[bin_idx + 1:]
        else:
            mapping = self._cat_bins[feat]
            n_groups = max(mapping.values()) + 1 if mapping else 0
            if bin_idx < n_groups - 1:
                self._push(feat)
                # Fold group ``bin_idx + 1`` into ``bin_idx`` and close the index gap.
                self._cat_bins[feat] = {
                    cat: (bin_idx if grp == bin_idx + 1 else (grp - 1 if grp > bin_idx + 1 else grp))
                    for cat, grp in mapping.items()
                }
        return self.state(feat)

    def move_boundary(self, feat: str, bin_idx: int, new_value: float) -> FeatureState:
        """Move split ``bin_idx`` of a numeric feature to ``new_value``."""
        current = self._splits[feat]
        if bin_idx < len(current):
            self._push(feat)
            moved = list(current)
            moved[bin_idx] = new_value
            # ``set`` collapses a split moved onto an already-existing one.
            self._splits[feat] = sorted(set(moved))
        return self.state(feat)

    def reset(self, feat: str) -> FeatureState:
        """Restore the feature's original spec and clear its edit history."""
        self._history[feat] = []
        spec = self._orig[feat]
        if spec["dtype"] == FeatureDtype.NUMERIC:
            self._splits[feat] = [float(edge) for edge in spec["bins"][1:-1] if np.isfinite(edge)]
        else:
            self._cat_bins[feat] = {str(cat): int(grp) for cat, grp in spec["bins"].items()}
        return self.state(feat)

    def undo(self, feat: str) -> FeatureState:
        """Revert the feature's most recent edit (no-op when history is empty)."""
        if self._history[feat]:
            kind, previous = self._history[feat].pop()
            if kind == "splits":
                self._splits[feat] = previous
            else:
                self._cat_bins[feat] = previous
        return self.state(feat)

    def history(self, feat: str) -> list[dict[str, Any]]:
        """Return the feature's edit log as ``{"type": ..., "value": ...}`` dicts."""
        return [{"type": kind, "value": payload} for kind, payload in self._history[feat]]

    def _suggest_num(self, feat: str, n_suggestions: int) -> list[float]:
        # Rank percentile-based candidate splits by the IV gain each would add.
        values = self._x[feat]
        observed = values[~np.isnan(values)]
        if len(observed) == 0:
            return []
        current = self._splits[feat]
        # Ignore candidates within 1% of the data span of an existing split.
        min_gap = float(observed.max() - observed.min()) * 0.01
        candidates = [
            float(c)
            for c in np.unique(np.percentile(observed, np.linspace(5, 95, 40)))
            if all(abs(c - existing) > min_gap for existing in current)
        ]
        base_iv = self._base_state(feat).iv
        scored: list[tuple[float, float]] = []
        for candidate in candidates:
            trial_iv = _bin_stats(
                self._targets,
                self._weights,
                _num_assign(values, sorted(current + [candidate])),
                len(current) + 2,
            ).iv
            scored.append((trial_iv - base_iv, candidate))
        scored.sort(reverse=True)
        return [candidate for _, candidate in scored[:n_suggestions]]

    def _suggest_cat(self, feat: str, n_suggestions: int) -> list[tuple[int, int]]:
        # Rank adjacent-group merges by how little IV they would sacrifice.
        mapping = self._cat_bins[feat]
        n_groups = max(mapping.values()) + 1 if mapping else 0
        if n_groups <= 1:
            return []
        values = self._x[feat]
        base_iv = self._base_state(feat).iv
        scored: list[tuple[float, tuple[int, int]]] = []
        for left in range(n_groups - 1):
            merged = {
                cat: (left if grp == left + 1 else (grp - 1 if grp > left + 1 else grp))
                for cat, grp in mapping.items()
            }
            merged_iv = _bin_stats(self._targets, self._weights, _cat_assign(values, merged), n_groups - 1).iv
            scored.append((base_iv - merged_iv, (left, left + 1)))
        scored.sort()
        return [pair for _, pair in scored[:n_suggestions]]

    def suggest_splits(self, feat: str, n: int = 5) -> list:  # type: ignore[type-arg]
        """Suggest the top ``n`` edits: new split values (numeric) or merge pairs (categorical)."""
        if feat in self._splits:
            return self._suggest_num(feat, n)
        return self._suggest_cat(feat, n)

    def accept(self) -> dict[str, dict[str, Any]]:
        """Export the current bin spec for every feature."""
        return {feat: self.accept_feature(feat) for feat in self.features()}

    def accept_feature(self, feat: str) -> dict[str, Any]:
        """Export one feature's current spec in `WOETransformer` format."""
        if feat in self._splits:
            return {"dtype": "float", "bins": [-np.inf, *self._splits[feat], np.inf]}
        return {"dtype": "category", "bins": dict(self._cat_bins[feat])}