dataforge-ml 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataforge_ml-0.1.0.dist-info/METADATA +34 -0
- dataforge_ml-0.1.0.dist-info/RECORD +54 -0
- dataforge_ml-0.1.0.dist-info/WHEEL +5 -0
- dataforge_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
- dataforge_ml-0.1.0.dist-info/top_level.txt +5 -0
- models/__init__.py +0 -0
- models/_data_structure.py +7 -0
- models/_data_types.py +12 -0
- profiling/__init__.py +35 -0
- profiling/_base.py +101 -0
- profiling/_boolean_config.py +37 -0
- profiling/_boolean_profiler.py +191 -0
- profiling/_categorical.py +315 -0
- profiling/_categorical_config.py +87 -0
- profiling/_correlation_config.py +225 -0
- profiling/_correlation_profiler.py +544 -0
- profiling/_datetime_config.py +98 -0
- profiling/_datetime_profiler.py +406 -0
- profiling/_missingness_config.py +137 -0
- profiling/_missingness_profiler.py +252 -0
- profiling/_numeric_config.py +116 -0
- profiling/_numeric_profiler.py +403 -0
- profiling/_tabular.py +249 -0
- profiling/_target_config.py +74 -0
- profiling/_target_profiler.py +156 -0
- profiling/_text_config.py +40 -0
- profiling/_text_profiler.py +194 -0
- profiling/_type_detector.py +463 -0
- profiling/config.py +236 -0
- profiling/structural.py +280 -0
- splitting/__init__.py +4 -0
- splitting/_config.py +56 -0
- splitting/_splitter.py +202 -0
- tests/__init__.py +0 -0
- tests/conftest.py +7 -0
- tests/integration/__init__.py +0 -0
- tests/integration/conftest.py +82 -0
- tests/integration/test_structural_end_to_end.py +219 -0
- tests/unit/__init__.py +0 -0
- tests/unit/profiling/__init__.py +0 -0
- tests/unit/profiling/conftest.py +81 -0
- tests/unit/profiling/test_boolean_profiler.py +91 -0
- tests/unit/profiling/test_categorical_profiler.py +182 -0
- tests/unit/profiling/test_correlation_profiler.py +124 -0
- tests/unit/profiling/test_datetime_profiler.py +133 -0
- tests/unit/profiling/test_missingness_profiler.py +51 -0
- tests/unit/profiling/test_numeric_profiler.py +212 -0
- tests/unit/profiling/test_target_profiler.py +44 -0
- tests/unit/profiling/test_text_profiler.py +61 -0
- tests/unit/profiling/test_type_detector.py +32 -0
- tests/unit/splitting/__init__.py +0 -0
- tests/unit/splitting/test_data_splitter.py +417 -0
- utils/__init__.py +0 -0
- utils/data_loader.py +110 -0
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CategoricalProfiler – Phase 1 extension: Categorical Column Profiling.
|
|
3
|
+
|
|
4
|
+
Per-column metrics (opt-in via ProfileConfig.categorical_columns):
|
|
5
|
+
1. Cardinality & unique ratio
|
|
6
|
+
2. Ordinal vs nominal detection
|
|
7
|
+
3. Top-10 value counts with percentages
|
|
8
|
+
4. Rare category analysis (<1 % frequency threshold)
|
|
9
|
+
5. Whitespace-only value count
|
|
10
|
+
6. Mixed-type flag (some values numeric, some not)
|
|
11
|
+
7. Free-text / natural-language flag
|
|
12
|
+
(avg word count >5 OR avg char length >50 OR avg token count >10)
|
|
13
|
+
8. Imbalance metrics
|
|
14
|
+
– class ratio (max_freq / min_freq)
|
|
15
|
+
– Shannon entropy
|
|
16
|
+
– Gini impurity
|
|
17
|
+
|
|
18
|
+
Integration
|
|
19
|
+
-----------
|
|
20
|
+
Add `categorical_columns: list[str] | None` to ProfileConfig, then call::
|
|
21
|
+
|
|
22
|
+
from profiling._categorical import CategoricalProfiler
|
|
23
|
+
|
|
24
|
+
cat_profiler = CategoricalProfiler(config=cfg)
|
|
25
|
+
cat_result = cat_profiler.profile(
|
|
26
|
+
df,
|
|
27
|
+
columns=["status", "country", "product_type"],
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
The result is a CategoricalProfileResult; attach it to TabularProfileResult
|
|
31
|
+
however suits your downstream pipeline.
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
from __future__ import annotations
|
|
35
|
+
|
|
36
|
+
import math
|
|
37
|
+
|
|
38
|
+
import polars as pl
|
|
39
|
+
from ._base import ColumnBatchProfiler
|
|
40
|
+
from ._categorical_config import (
|
|
41
|
+
CategoricalProfileResult,
|
|
42
|
+
CategoricalStats,
|
|
43
|
+
TopValueEntry,
|
|
44
|
+
CategoricalFlag,
|
|
45
|
+
RareCategoryStats,
|
|
46
|
+
ImbalanceMetrics,
|
|
47
|
+
)
|
|
48
|
+
from .config import (
|
|
49
|
+
ProfileConfig,
|
|
50
|
+
SemanticType,
|
|
51
|
+
)
|
|
52
|
+
from ..models._data_types import _CAT_DTYPES
|
|
53
|
+
|
|
54
|
+
# ---------------------------------------------------------------------------
|
|
55
|
+
# Module-level thresholds (documented so callers can see what drives flags)
|
|
56
|
+
# ---------------------------------------------------------------------------
|
|
57
|
+
|
|
58
|
+
_RARE_THRESHOLD_PCT: float = 0.01  # <1 % of rows → rare
# Minimum share the minority kind of value (numeric vs non-numeric) must reach
# before a column is flagged MixedType; it is compared against the Wilson
# lower bound of the observed minority share.
_MIXED_TYPE_MIN_MINOR_PCT: float = 0.05
# z-value used for that Wilson lower bound (1.96 ≈ 95 % two-sided confidence).
_MIXED_TYPE_Z_SCORE: float = 1.96

# A column is flagged NearConstant when its most frequent value covers more
# than this fraction of the rows.
_NEAR_CONSTANT_THRESHOLD: float = 0.90
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class CategoricalProfiler(ColumnBatchProfiler[CategoricalProfileResult]):
    """
    Categorical profiler for Polars DataFrames.

    For every eligible column the profiler records cardinality, the most
    frequent values, rare-category statistics, imbalance metrics and the
    ``NearConstant`` / ``MixedType`` flags.

    Parameters
    ----------
    config : ProfileConfig | None
        Shared profiling configuration, forwarded to the base class.
        ``config.column_overrides`` is consulted when deciding whether a
        column is eligible.

    Usage
    -----
    >>> profiler = CategoricalProfiler()
    >>> result = profiler.profile(
    ...     df,
    ...     columns=["status", "country", "product_type"],
    ... )
    >>> print(result)
    """

    def __init__(
        self,
        config: ProfileConfig | None = None,
    ) -> None:
        super().__init__(config)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def profile(
        self,
        data: pl.DataFrame,
        columns: list[str],
    ) -> CategoricalProfileResult:
        """
        Profile the requested columns of *data*.

        Parameters
        ----------
        data : pl.DataFrame
            Frame to analyse.
        columns : list[str]
            Requested column names; they are resolved against the frame's
            schema and filtered down to eligible columns before profiling.

        Returns
        -------
        CategoricalProfileResult
            One ``CategoricalStats`` entry per analysed column.
        """
        return self._run(data, columns)

    # ------------------------------------------------------------------
    # Orchestration
    # ------------------------------------------------------------------

    def _eligible(
        self,
        series: pl.Series,
    ) -> bool:
        """
        Decide whether *series* should be profiled as categorical.

        An explicit override wins: ``SemanticType.Categorical`` forces
        inclusion, any other override forces exclusion.  Without an
        override the column's dtype must be one of ``_CAT_DTYPES``.
        """
        override = self.config.column_overrides.get(series.name)
        if override == SemanticType.Categorical:
            return True

        if override is not None:
            return False

        return series.dtype in _CAT_DTYPES

    def _run(
        self,
        df: pl.DataFrame,
        columns: list[str],
    ) -> CategoricalProfileResult:
        """Resolve the column list, then profile each eligible column."""
        result = CategoricalProfileResult()

        # Resolve columns against actual schema
        available = [
            c
            for c in self._resolve_columns(df.columns, columns)
            if self._eligible(df[c])
        ]
        result.analysed_columns = available

        n_rows = df.height

        for col_name in available:
            result.columns[col_name] = self._profile_column(
                df[col_name], col_name, n_rows
            )

        return result

    # ------------------------------------------------------------------
    # Per-column driver
    # ------------------------------------------------------------------

    def _profile_column(
        self,
        series: pl.Series,
        col_name: str,
        n_rows: int,
    ) -> CategoricalStats:
        """
        Build the ``CategoricalStats`` for one column.

        ``col_name`` is accepted for symmetry with the caller but is not
        read here; the series already carries its name.
        """
        profile = CategoricalStats()

        # Cast to String for uniform downstream treatment
        str_series = series.cast(pl.Utf8, strict=False)

        # Step 1: cardinality and unique ratio
        self._compute_cardinality(str_series, profile, n_rows)

        # Step 2: value distribution (top values, rare categories, imbalance).
        # The returned value-count frame is currently not reused.
        self._compute_value_distribution(str_series, profile, n_rows)

        # Step 3: mixed-type flag — detects columns that are *partly*
        # numeric and partly not.
        self._check_mixed_type(str_series, profile)

        return profile

    # ------------------------------------------------------------------
    # Step 1: Cardinality
    # ------------------------------------------------------------------

    @staticmethod
    def _compute_cardinality(
        series: pl.Series,
        profile: CategoricalStats,
        n_rows: int,
    ) -> None:
        """Store the distinct non-null value count and its ratio to *n_rows*."""
        cardinality = series.drop_nulls().n_unique()
        profile.cardinality = cardinality
        profile.unique_ratio = cardinality / n_rows if n_rows > 0 else 0.0

    # ------------------------------------------------------------------
    # Step 2: Value distribution
    # ------------------------------------------------------------------

    def _compute_value_distribution(
        self,
        series: pl.Series,
        profile: CategoricalStats,
        n_rows: int,
    ) -> pl.DataFrame:
        """
        Build the value-count frame and populate distribution statistics.

        Fills ``top_values`` (up to 10 entries), ``mode_frequency``, the
        ``NearConstant`` flag, ``rare_categories`` and ``imbalance`` on
        *profile*.  Percentages are fractions of *n_rows* (all rows, nulls
        included), not of the cleaned values.

        Returns
        -------
        pl.DataFrame
            The full value-count frame, sorted by descending count; an empty
            ``value`` / ``count`` frame when nothing survives the cleaning.
        """
        # Exclude nulls, whitespace-only values and missing-value sentinels.
        # The sentinel comparison runs on the stripped text so that padded
        # tokens such as " NA " are treated like "NA".
        missing_tokens = ["NA", "NAN", "NULL", "NONE", "?"]
        stripped = series.str.strip_chars()
        clean = series.filter(
            ~series.is_null()
            & (stripped != "")
            & ~stripped.str.to_uppercase().is_in(missing_tokens)
        )

        if clean.len() == 0:
            return pl.DataFrame({"value": [], "count": []})

        # Sorted descending by count; polars names the columns
        # <series name> and "count".
        vc = clean.value_counts(sort=True)
        value_col = series.name

        # --- Top-10 ---
        top = vc.head(10)
        profile.top_values = [
            TopValueEntry(
                value=value,
                count=int(count),
                percentage=int(count) / n_rows if n_rows > 0 else 0.0,
            )
            for value, count in zip(
                top[value_col].to_list(), top["count"].to_list()
            )
        ]

        profile.mode_frequency = profile.top_values[0].percentage
        if profile.mode_frequency > _NEAR_CONSTANT_THRESHOLD:
            profile.flags.append(CategoricalFlag.NearConstant)

        # --- Rare category analysis ---
        # Compare each category's share of rows with the threshold directly.
        # Flooring the threshold to a whole row count would miss categories
        # whose share is below the threshold whenever threshold * n_rows is
        # not an integer (e.g. 1 row out of 150).
        if n_rows > 0:
            shares = vc["count"].cast(pl.Float64) / n_rows
            rare_rows = vc.filter(shares < _RARE_THRESHOLD_PCT)
        else:
            rare_rows = vc.head(0)

        profile.rare_categories = RareCategoryStats(
            threshold_pct=_RARE_THRESHOLD_PCT,
            rare_category_count=rare_rows.height,
            total_rare_rows=(
                int(rare_rows["count"].sum()) if rare_rows.height > 0 else 0
            ),
        )
        profile.rare_categories.rare_row_percentage = (
            profile.rare_categories.total_rare_rows / n_rows if n_rows > 0 else 0.0
        )

        # --- Imbalance metrics ---
        # Class Ratio -> raw distribution
        # Entropy     -> randomness / information content
        # Gini        -> impurity / misclassification risk
        counts = vc["count"].cast(pl.Float64)
        total = float(counts.sum())
        if total > 0:
            probs = counts / total
            max_freq = float(probs.max())  # type: ignore[arg-type]
            min_freq = float(probs.min())  # type: ignore[arg-type]

            class_ratio = max_freq / min_freq if min_freq > 0 else float("inf")
            # Entropy is measured in bits (log base 2).
            entropy = float(-(probs * probs.log(base=2)).fill_nan(0.0).sum())
            gini = float(1.0 - (probs**2).sum())

            profile.imbalance = ImbalanceMetrics(
                class_ratio=class_ratio,
                shannon_entropy=entropy,
                gini_impurity=gini,
            )

        return vc

    # ------------------------------------------------------------------
    # Step 3: Mixed-type flag
    # ------------------------------------------------------------------

    @staticmethod
    def _check_mixed_type(
        series: pl.Series,
        profile: CategoricalStats,
    ) -> None:
        """
        Flag the column when it mixes numeric-looking and other values.

        A value counts as numeric when a non-strict cast to ``Float64``
        succeeds.  The share of the minority kind is then tested with the
        lower bound of a Wilson score interval (z = ``_MIXED_TYPE_Z_SCORE``):
        ``MixedType`` is appended to ``profile.flags`` only when that lower
        bound reaches ``_MIXED_TYPE_MIN_MINOR_PCT``, so a handful of stray
        values in a small sample does not trigger the flag.
        """
        non_null = series.drop_nulls()
        n_total = non_null.len()

        if n_total == 0:
            return

        # Values that fail the cast become null, so the null count of the
        # cast result is the number of non-numeric values.
        numeric_cast = non_null.cast(pl.Float64, strict=False)

        n_numeric = n_total - numeric_cast.null_count()
        n_non_numeric = n_total - n_numeric

        # A column that is entirely one kind is not mixed.
        if n_numeric == 0 or n_non_numeric == 0:
            return

        n_minority = min(n_numeric, n_non_numeric)
        p_minority = n_minority / n_total

        # Wilson score interval, lower bound.
        z = _MIXED_TYPE_Z_SCORE
        denominator = 1 + (z**2) / n_total
        center = p_minority + (z**2) / (2 * n_total)
        spread = z * math.sqrt(
            (p_minority * (1 - p_minority) + (z**2) / (4 * n_total)) / n_total
        )

        lower_bound = (center - spread) / denominator

        if lower_bound >= _MIXED_TYPE_MIN_MINOR_PCT:
            profile.flags.append(CategoricalFlag.MixedType)
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Result dataclasses for categorical column profiling.
|
|
3
|
+
|
|
4
|
+
These complement TabularProfileResult and are populated by
|
|
5
|
+
CategoricalProfiler, which is opt-in via ProfileConfig.categorical_columns.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from enum import StrEnum
|
|
12
|
+
|
|
13
|
+
# ---------------------------------------------------------------------------
|
|
14
|
+
# Categorical stats dataclasses (canonical home — config.py re-exports these)
|
|
15
|
+
# ---------------------------------------------------------------------------
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class CategoricalFlag(StrEnum):
    """Flags that can be attached to a categorical column's profile."""

    # Column holds both numeric-looking and non-numeric values.
    MixedType = "mixed_type"
    # NOTE(review): presumably marks free-text / natural-language columns;
    # no code in view sets this flag — confirm against the profiler.
    FreeText = "free_text"
    # A single value dominates the column.
    NearConstant = "near_constant"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class TopValueEntry:
    """
    One entry of a column's most-frequent-values list.

    Attributes
    ----------
    value : object
        The category value.
    count : int
        Number of rows holding this value.
    percentage : float
        ``count`` as a fraction of the frame's row count (0–1, not 0–100).
    """

    value: object
    count: int
    percentage: float
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class RareCategoryStats:
    """
    Summary of the rare categories found in one column.

    Attributes
    ----------
    threshold_pct : float
        Frequency threshold, as a fraction of rows, below which a category
        counts as rare.
    rare_category_count : int
        Number of distinct rare categories.
    total_rare_rows : int
        Number of rows that belong to a rare category.
    rare_row_percentage : float
        ``total_rare_rows`` as a fraction of the frame's row count.
    """

    threshold_pct: float
    rare_category_count: int = 0
    total_rare_rows: int = 0
    rare_row_percentage: float = 0.0
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class ImbalanceMetrics:
    """
    Class-imbalance measures for one categorical column.

    Attributes
    ----------
    class_ratio : float
        Frequency of the most common category divided by that of the
        least common one.
    shannon_entropy : float
        Shannon entropy of the category distribution (log base 2).
    gini_impurity : float
        One minus the sum of squared category probabilities.
    """

    class_ratio: float = 0.0
    shannon_entropy: float = 0.0
    gini_impurity: float = 0.0
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class CategoricalStats:
    """
    Profile of a single categorical column.

    Attributes
    ----------
    cardinality : int
        Number of distinct non-null values.
    unique_ratio : float
        ``cardinality`` divided by the frame's row count.
    mode_frequency : float
        Fraction of rows taken by the most frequent value.
    top_values : list[TopValueEntry]
        Most frequent values, in descending count order.
    rare_categories : RareCategoryStats
        Rare-category summary; defaults to an empty summary with a 1 %
        threshold.
    imbalance : ImbalanceMetrics
        Class ratio, entropy and Gini impurity of the distribution.
    flags : list[CategoricalFlag]
        Flags raised for this column.
    """

    cardinality: int = 0
    unique_ratio: float = 0.0
    mode_frequency: float = 0.0
    top_values: list[TopValueEntry] = field(default_factory=list)
    rare_categories: RareCategoryStats = field(
        default_factory=lambda: RareCategoryStats(threshold_pct=0.01),
    )
    imbalance: ImbalanceMetrics = field(default_factory=ImbalanceMetrics)
    flags: list[CategoricalFlag] = field(default_factory=list)


# Alternative name for CategoricalStats.
CategoricalColumnProfile = CategoricalStats
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# ---------------------------------------------------------------------------
|
|
63
|
+
# Top-level result
|
|
64
|
+
# ---------------------------------------------------------------------------
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@dataclass
class CategoricalProfileResult:
    """
    Container for the categorical profiles of every analysed column.

    Attributes
    ----------
    columns : dict[str, CategoricalStats]
        Per-column profiles, keyed by column name.
    analysed_columns : list[str]
        Names of the columns that were actually profiled.
    """

    columns: dict[str, CategoricalStats] = field(default_factory=dict)
    analysed_columns: list[str] = field(default_factory=list)

    def __str__(self) -> str:  # pragma: no cover
        """Return a header line followed by one line per column profile."""
        rendered = (str(entry) for entry in self.columns.values())
        return "\n".join(("=== Categorical Profile ===", *rendered))
|
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Result dataclasses for correlation and information-structure profiling.
|
|
3
|
+
|
|
4
|
+
Populated by CorrelationProfiler, which is opt-in via
|
|
5
|
+
ProfileConfig.correlation_target_column (and implicitly by passing
|
|
6
|
+
numeric/categorical column lists that are already resolved upstream).
|
|
7
|
+
|
|
8
|
+
Design notes
|
|
9
|
+
------------
|
|
10
|
+
- Pearson matrix : linear relationships between numeric columns.
|
|
11
|
+
- Spearman matrix : monotonic (rank-based) relationships; robust to
|
|
12
|
+
outliers and non-linearity.
|
|
13
|
+
- Near-redundancy : any pair with |r| > 0.95 flagged — identical signal,
|
|
14
|
+
one should be dropped before modelling.
|
|
15
|
+
- Feature–target : Pearson for numeric target, ANOVA F / eta² for
|
|
16
|
+
categorical target. Top-10 reported.
|
|
17
|
+
- Mutual information: MI for all features vs target (classif or regression).
|
|
18
|
+
Captures non-linear dependencies correlation misses.
|
|
19
|
+
"""
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
from dataclasses import dataclass, field
|
|
23
|
+
from enum import StrEnum
|
|
24
|
+
from typing import Optional
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# ---------------------------------------------------------------------------
|
|
28
|
+
# Enums
|
|
29
|
+
# ---------------------------------------------------------------------------
|
|
30
|
+
|
|
31
|
+
class CorrelationMethod(StrEnum):
    """Correlation coefficients reported by the correlation profile."""

    Pearson = "pearson"    # linear relationship
    Spearman = "spearman"  # monotonic, rank-based relationship
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class TargetType(StrEnum):
    """Kind of target column; it selects the feature–target statistics."""

    Numeric = "numeric"  # numeric target → Pearson + MI regression
    Categorical = "categorical"  # categorical target → ANOVA/eta² + MI classif
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# ---------------------------------------------------------------------------
|
|
42
|
+
# Pairwise correlation result
|
|
43
|
+
# ---------------------------------------------------------------------------
|
|
44
|
+
|
|
45
|
+
@dataclass
class CorrelationPair:
    """
    A single entry in the pairwise correlation results.

    Attributes
    ----------
    col_a, col_b : str
        The two column names (col_a < col_b lexicographically,
        so each pair appears exactly once).
    pearson_r : float | None
        Pearson r.  None when fewer than 3 non-null paired observations.
    spearman_r : float | None
        Spearman r.  None under the same condition.
    near_redundant : bool
        True when max(|pearson_r|, |spearman_r|) > threshold (default 0.95).
    """

    col_a: str
    col_b: str
    # Both coefficients start as None and the flag as False until populated.
    pearson_r: Optional[float] = None
    spearman_r: Optional[float] = None
    near_redundant: bool = False
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# ---------------------------------------------------------------------------
|
|
71
|
+
# Feature–target entries
|
|
72
|
+
# ---------------------------------------------------------------------------
|
|
73
|
+
|
|
74
|
+
@dataclass
class NumericTargetCorrelation:
    """
    Pearson r between one numeric feature and a numeric target.

    Attributes
    ----------
    feature : str
        Name of the feature column.
    pearson_r : float | None
        Pearson r between the feature and the target; None by default,
        i.e. when no coefficient has been stored.
    """

    feature: str
    pearson_r: Optional[float] = None
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@dataclass
class CategoricalTargetCorrelation:
    """
    ANOVA-based association between a feature and the target.

    Describes one categorical feature against a numeric target, or a
    numeric feature against a categorical target when the roles are
    reversed — see the CorrelationProfiler docs.

    Attributes
    ----------
    feature : str
        Name of the feature column.
    f_statistic : float | None
        One-way ANOVA F-statistic.  Higher F → stronger group separation.
    p_value : float | None
        p-value for the F-test.
    eta_squared : float | None
        Effect size: SS_between / SS_total.  Ranges [0, 1].
        Rule of thumb: 0.01 small, 0.06 medium, 0.14 large.
    """

    feature: str
    # All three statistics default to None until they are populated.
    f_statistic: Optional[float] = None
    p_value: Optional[float] = None
    eta_squared: Optional[float] = None
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
# ---------------------------------------------------------------------------
|
|
113
|
+
# Mutual information
|
|
114
|
+
# ---------------------------------------------------------------------------
|
|
115
|
+
|
|
116
|
+
@dataclass
class MutualInformationEntry:
    """
    MI score for one feature vs the target.

    Attributes
    ----------
    feature : str
        Name of the feature column.
    mi_score : float
        Raw MI value (nats, sklearn default).  Not directly comparable
        across datasets — use rank ordering within this dataset.
    rank : int
        1 = highest MI (most informative).  The default of 0 therefore
        means "not ranked yet".
    """

    feature: str
    mi_score: float = 0.0
    rank: int = 0
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
# ---------------------------------------------------------------------------
|
|
136
|
+
# Near-redundancy summary
|
|
137
|
+
# ---------------------------------------------------------------------------
|
|
138
|
+
|
|
139
|
+
@dataclass
class NearRedundancyGroup:
    """
    A cluster of mutually near-redundant columns.

    All pairs within the group exceed the |r| > 0.95 threshold.
    The suggested_drop list contains every column except the first
    alphabetically — a simple, deterministic heuristic.

    Attributes
    ----------
    columns : list[str]
        Every column that belongs to the group.
    suggested_drop : list[str]
        Columns proposed for removal.
    """

    columns: list[str] = field(default_factory=list)
    suggested_drop: list[str] = field(default_factory=list)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# ---------------------------------------------------------------------------
|
|
153
|
+
# Top-level result
|
|
154
|
+
# ---------------------------------------------------------------------------
|
|
155
|
+
|
|
156
|
+
@dataclass
class CorrelationProfileResult:
    """
    Full correlation and information-structure profile.

    Attributes
    ----------
    analysed_numeric_columns : list[str]
        Numeric columns actually included in the pairwise matrices.
    pairwise : list[CorrelationPair]
        All (col_a, col_b) pairs, each carrying Pearson and Spearman r.
    near_redundant_pairs : list[CorrelationPair]
        Subset of *pairwise* where near_redundant is True.
    near_redundancy_groups : list[NearRedundancyGroup]
        Union-find clusters of near-redundant columns.

    target_column : str | None
        The target column supplied by the caller (may be None when no
        target is provided — only pairwise matrices are then computed).
    target_type : TargetType | None
        Whether the target is numeric or categorical; None when no target
        is provided.

    feature_target_numeric : list[NumericTargetCorrelation]
        Populated when target is numeric.  Top-10 by |Pearson r|.
    feature_target_categorical : list[CategoricalTargetCorrelation]
        Populated when target is categorical.  Top-10 by eta².
    mutual_information : list[MutualInformationEntry]
        All features ranked by MI vs target.  Empty when no target.

    pearson_matrix : dict[str, dict[str, float]]
        Full symmetric Pearson matrix (numeric columns only).
    spearman_matrix : dict[str, dict[str, float]]
        Full symmetric Spearman matrix (numeric columns only).
    """

    # Column scope
    analysed_numeric_columns: list[str] = field(default_factory=list)

    # Pairwise matrices
    pearson_matrix: dict[str, dict[str, float]] = field(default_factory=dict)
    spearman_matrix: dict[str, dict[str, float]] = field(default_factory=dict)

    # Pairwise summaries
    pairwise: list[CorrelationPair] = field(default_factory=list)
    near_redundant_pairs: list[CorrelationPair] = field(default_factory=list)
    near_redundancy_groups: list[NearRedundancyGroup] = field(default_factory=list)

    # Target info
    target_column: Optional[str] = None
    target_type: Optional[TargetType] = None

    # Feature–target correlations (top-10 each)
    feature_target_numeric: list[NumericTargetCorrelation] = field(default_factory=list)
    feature_target_categorical: list[CategoricalTargetCorrelation] = field(default_factory=list)

    # Mutual information (all features, ranked)
    mutual_information: list[MutualInformationEntry] = field(default_factory=list)

    # ------------------------------------------------------------------
    # Convenience helpers
    # ------------------------------------------------------------------

    def top_mi(self, n: int = 10) -> list[MutualInformationEntry]:
        """
        Return the first *n* entries of ``mutual_information``.

        The list is expected to be stored in rank order already; it is not
        re-sorted here.  A non-positive *n* yields an empty list rather
        than the tail-trimmed slice a negative index would produce.
        """
        return self.mutual_information[: max(n, 0)]

    def get_pearson(self, col_a: str, col_b: str) -> Optional[float]:
        """Return Pearson r for the pair in either order, or None if absent."""
        return self._lookup(self.pearson_matrix, col_a, col_b)

    def get_spearman(self, col_a: str, col_b: str) -> Optional[float]:
        """Return Spearman r for the pair in either order, or None if absent."""
        return self._lookup(self.spearman_matrix, col_a, col_b)

    @staticmethod
    def _lookup(
        matrix: dict[str, dict[str, float]],
        col_a: str,
        col_b: str,
    ) -> Optional[float]:
        """Look up (col_a, col_b), falling back to the mirrored entry."""
        value = matrix.get(col_a, {}).get(col_b)
        if value is None:
            # Correlation is symmetric, so a matrix that stores only one
            # triangle can still answer the reversed query.
            value = matrix.get(col_b, {}).get(col_a)
        return value
|