dataforge-ml 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataforge_ml-0.1.0.dist-info/METADATA +34 -0
- dataforge_ml-0.1.0.dist-info/RECORD +54 -0
- dataforge_ml-0.1.0.dist-info/WHEEL +5 -0
- dataforge_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
- dataforge_ml-0.1.0.dist-info/top_level.txt +5 -0
- models/__init__.py +0 -0
- models/_data_structure.py +7 -0
- models/_data_types.py +12 -0
- profiling/__init__.py +35 -0
- profiling/_base.py +101 -0
- profiling/_boolean_config.py +37 -0
- profiling/_boolean_profiler.py +191 -0
- profiling/_categorical.py +315 -0
- profiling/_categorical_config.py +87 -0
- profiling/_correlation_config.py +225 -0
- profiling/_correlation_profiler.py +544 -0
- profiling/_datetime_config.py +98 -0
- profiling/_datetime_profiler.py +406 -0
- profiling/_missingness_config.py +137 -0
- profiling/_missingness_profiler.py +252 -0
- profiling/_numeric_config.py +116 -0
- profiling/_numeric_profiler.py +403 -0
- profiling/_tabular.py +249 -0
- profiling/_target_config.py +74 -0
- profiling/_target_profiler.py +156 -0
- profiling/_text_config.py +40 -0
- profiling/_text_profiler.py +194 -0
- profiling/_type_detector.py +463 -0
- profiling/config.py +236 -0
- profiling/structural.py +280 -0
- splitting/__init__.py +4 -0
- splitting/_config.py +56 -0
- splitting/_splitter.py +202 -0
- tests/__init__.py +0 -0
- tests/conftest.py +7 -0
- tests/integration/__init__.py +0 -0
- tests/integration/conftest.py +82 -0
- tests/integration/test_structural_end_to_end.py +219 -0
- tests/unit/__init__.py +0 -0
- tests/unit/profiling/__init__.py +0 -0
- tests/unit/profiling/conftest.py +81 -0
- tests/unit/profiling/test_boolean_profiler.py +91 -0
- tests/unit/profiling/test_categorical_profiler.py +182 -0
- tests/unit/profiling/test_correlation_profiler.py +124 -0
- tests/unit/profiling/test_datetime_profiler.py +133 -0
- tests/unit/profiling/test_missingness_profiler.py +51 -0
- tests/unit/profiling/test_numeric_profiler.py +212 -0
- tests/unit/profiling/test_target_profiler.py +44 -0
- tests/unit/profiling/test_text_profiler.py +61 -0
- tests/unit/profiling/test_type_detector.py +32 -0
- tests/unit/splitting/__init__.py +0 -0
- tests/unit/splitting/test_data_splitter.py +417 -0
- utils/__init__.py +0 -0
- utils/data_loader.py +110 -0
|
@@ -0,0 +1,417 @@
|
|
|
1
|
+
import polars as pl
|
|
2
|
+
import pytest
|
|
3
|
+
|
|
4
|
+
from ....splitting._splitter import DataSplitter
|
|
5
|
+
from ....splitting._config import FoldResult, SplitResult
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# ---------------------------------------------------------------------------
|
|
9
|
+
# Fixtures
|
|
10
|
+
# ---------------------------------------------------------------------------
|
|
11
|
+
|
|
12
|
+
_N = 100  # row count shared by the random-split fixtures


@pytest.fixture(scope="module")
def df() -> pl.DataFrame:
    """Return an ``_N``-row frame: two float features plus an alternating cat/dog label."""
    indices = range(_N)
    feature_a = pl.Series(list(indices), dtype=pl.Float64)
    feature_b = pl.Series([i * 0.5 for i in indices], dtype=pl.Float64)
    # Even rows are "cat", odd rows are "dog", so the classes are balanced.
    label = pl.Series(
        ["dog" if i % 2 else "cat" for i in indices],
        dtype=pl.Utf8,
    )
    return pl.DataFrame(
        {"feature_a": feature_a, "feature_b": feature_b, "label": label}
    )
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@pytest.fixture(scope="module")
def df_no_target() -> pl.DataFrame:
    """Return an ``_N``-row frame of two float columns and no label column."""
    x_values = list(range(_N))
    y_values = list(range(_N, _N * 2))
    return pl.DataFrame(
        {
            "x": pl.Series(x_values, dtype=pl.Float64),
            "y": pl.Series(y_values, dtype=pl.Float64),
        }
    )
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# ---------------------------------------------------------------------------
|
|
37
|
+
# Constructor validation
|
|
38
|
+
# ---------------------------------------------------------------------------
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def test_valid_construction(df):
    """The constructor keeps the frame by identity and records target and seed."""
    built = DataSplitter(df, target="label", random_seed=42)
    assert built._df is df
    assert (built._target, built._random_seed) == ("label", 42)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def test_constructor_no_target(df_no_target):
    """Omitting target and seed leaves both private attributes as ``None``."""
    bare = DataSplitter(df_no_target)
    assert bare._random_seed is None
    assert bare._target is None
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def test_constructor_raises_type_error_for_non_polars():
    """A plain nested list is rejected with ``TypeError``."""
    nested_lists = [[1, 2], [3, 4]]
    with pytest.raises(TypeError):
        DataSplitter(nested_lists)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def test_constructor_raises_type_error_for_numpy_array():
    """A NumPy array is rejected with ``TypeError``.

    The test is skipped, rather than erroring at import time, when NumPy is
    not installed in the test environment.
    """
    np = pytest.importorskip("numpy")
    with pytest.raises(TypeError):
        DataSplitter(np.zeros((10, 3)))
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def test_constructor_raises_value_error_for_empty_df():
    """A zero-row frame is rejected with a ``ValueError`` mentioning "empty"."""
    no_rows = pl.DataFrame({"x": pl.Series([], dtype=pl.Float64)})
    expect_empty_error = pytest.raises(ValueError, match="empty")
    with expect_empty_error:
        DataSplitter(no_rows)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def test_constructor_raises_value_error_for_missing_target(df):
    """Naming a target column that is absent raises a "not found" ``ValueError``."""
    expect_not_found = pytest.raises(ValueError, match="not found")
    with expect_not_found:
        DataSplitter(df, target="nonexistent_column")
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
# ---------------------------------------------------------------------------
|
|
77
|
+
# random_split — sizes and ratios
|
|
78
|
+
# ---------------------------------------------------------------------------
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def test_random_split_sizes_sum_to_total(df):
    """Reported train and test sizes together account for every input row."""
    outcome = DataSplitter(df, target="label", random_seed=0).random_split(
        test_size=0.2
    )
    reported_total = outcome.train_size + outcome.test_size
    assert reported_total == len(df)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def test_random_split_dataframe_row_counts_match_sizes(df):
    """The size metadata matches the actual row counts of both partitions."""
    outcome = DataSplitter(df, target="label", random_seed=0).random_split(
        test_size=0.2
    )
    assert len(outcome.train) == outcome.train_size
    assert len(outcome.test) == outcome.test_size
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def test_random_split_ratios_reflect_actual_proportions(df):
    """Reported ratios equal each partition's size divided by the input size."""
    outcome = DataSplitter(df, target="label", random_seed=0).random_split(
        test_size=0.2
    )
    n_rows = len(df)
    expected_train_ratio = outcome.train_size / n_rows
    expected_test_ratio = outcome.test_size / n_rows
    assert outcome.train_ratio == pytest.approx(expected_train_ratio)
    assert outcome.test_ratio == pytest.approx(expected_test_ratio)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def test_random_split_returns_split_result(df):
    """``random_split`` hands back a ``SplitResult`` instance."""
    outcome = DataSplitter(df, target="label", random_seed=0).random_split(
        test_size=0.2
    )
    assert isinstance(outcome, SplitResult)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
# ---------------------------------------------------------------------------
|
|
109
|
+
# random_split — stratification
|
|
110
|
+
# ---------------------------------------------------------------------------
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def test_stratified_split_preserves_class_ratios(df):
    """Stratified splitting keeps the majority-class share close to the input's.

    The expected share is measured on the source frame instead of being
    hard-coded, so the check stays valid if the fixture's class balance
    changes. Each partition may deviate by less than 0.1.
    """

    def majority_share(frame: pl.DataFrame) -> float:
        # sort=True puts the most frequent label first.
        counts = frame["label"].value_counts(sort=True)["count"].to_list()
        return counts[0] / sum(counts)

    splitter = DataSplitter(df, target="label", random_seed=42)
    result = splitter.random_split(test_size=0.2, stratify=True)

    expected_share = majority_share(df)
    assert abs(majority_share(result.train) - expected_share) < 0.1
    assert abs(majority_share(result.test) - expected_share) < 0.1
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def test_stratify_false_produces_valid_split(df_no_target):
    """An explicit ``stratify=False`` split still accounts for every row."""
    outcome = DataSplitter(df_no_target, random_seed=7).random_split(
        test_size=0.3, stratify=False
    )
    reported_total = outcome.train_size + outcome.test_size
    assert reported_total == len(df_no_target)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def test_stratify_defaults_true_when_target_set(df):
    """With a target and no ``stratify`` argument the split covers every row."""
    outcome = DataSplitter(df, target="label", random_seed=1).random_split(
        test_size=0.2
    )
    reported_total = outcome.train_size + outcome.test_size
    assert reported_total == len(df)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def test_stratify_defaults_false_when_no_target(df_no_target):
    """Without a target and no ``stratify`` argument the split covers every row."""
    outcome = DataSplitter(df_no_target, random_seed=1).random_split(test_size=0.2)
    reported_total = outcome.train_size + outcome.test_size
    assert reported_total == len(df_no_target)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def test_stratify_true_without_target_raises_value_error(df_no_target):
    """Requesting stratification without a target raises a ``ValueError`` naming "target"."""
    targetless = DataSplitter(df_no_target)
    expect_target_error = pytest.raises(ValueError, match="target")
    with expect_target_error:
        targetless.random_split(test_size=0.2, stratify=True)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
# ---------------------------------------------------------------------------
|
|
151
|
+
# Reproducibility
|
|
152
|
+
# ---------------------------------------------------------------------------
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def test_same_seed_produces_identical_splits(df):
    """Two splitters built with the same seed yield equal train and test frames."""
    first, second = (
        DataSplitter(df, target="label", random_seed=99).random_split(test_size=0.2)
        for _ in range(2)
    )
    assert first.train.equals(second.train)
    assert first.test.equals(second.test)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def test_different_seeds_produce_different_splits(df):
    """Seeds 1 and 2 must not yield the same train frame."""
    by_seed = {
        seed: DataSplitter(df, target="label", random_seed=seed).random_split(
            test_size=0.2
        )
        for seed in (1, 2)
    }
    assert not by_seed[1].train.equals(by_seed[2].train)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
# ---------------------------------------------------------------------------
|
|
173
|
+
# No profiling import leakage
|
|
174
|
+
# ---------------------------------------------------------------------------
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def test_no_profiling_import():
    """The splitter module must not bind a ``profiling`` name in its namespace."""
    # Use the same relative import path as the module-level imports of this
    # file, so the check inspects the module object that provides DataSplitter.
    from ....splitting import _splitter as mod

    # DataSplitter module itself must not have caused profiling to be imported
    assert "profiling" not in mod.__dict__
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
# ---------------------------------------------------------------------------
|
|
186
|
+
# time_split — fixtures
|
|
187
|
+
# ---------------------------------------------------------------------------
|
|
188
|
+
|
|
189
|
+
from datetime import date, timedelta
|
|
190
|
+
|
|
191
|
+
_BASE = date(2024, 1, 1)  # first day in the time-split fixture
_TIME_N = 50  # number of consecutive days / rows


@pytest.fixture(scope="module")
def time_df() -> pl.DataFrame:
    """Return ``_TIME_N`` rows, one per consecutive day starting at ``_BASE``."""
    offsets = range(_TIME_N)
    days = [_BASE + timedelta(days=offset) for offset in offsets]
    return pl.DataFrame(
        {
            "date": pl.Series(days, dtype=pl.Date),
            "value": pl.Series(list(offsets), dtype=pl.Float64),
        }
    )
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
@pytest.fixture(scope="module")
def time_splitter(time_df) -> DataSplitter:
    """Return a target-less, unseeded splitter wrapping the ``time_df`` fixture."""
    splitter = DataSplitter(time_df)
    return splitter
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
# ---------------------------------------------------------------------------
|
|
212
|
+
# time_split — error cases
|
|
213
|
+
# ---------------------------------------------------------------------------
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def test_time_split_raises_for_missing_column(time_splitter):
    """An unknown time column raises a "not found" ``ValueError``."""
    expect_not_found = pytest.raises(ValueError, match="not found")
    with expect_not_found:
        time_splitter.time_split("nonexistent")
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def test_time_split_raises_when_neither_arg_provided(time_splitter):
    """Calling ``time_split`` with neither ``test_size`` nor ``cutoff`` raises ``ValueError``."""
    expect_either_error = pytest.raises(ValueError, match="Either")
    with expect_either_error:
        time_splitter.time_split("date")
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
# ---------------------------------------------------------------------------
|
|
227
|
+
# time_split — fraction mode
|
|
228
|
+
# ---------------------------------------------------------------------------
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def test_fraction_mode_sizes_sum_to_total(time_df, time_splitter):
    """In fraction mode the reported sizes account for every input row."""
    outcome = time_splitter.time_split("date", test_size=0.2)
    reported_total = outcome.train_size + outcome.test_size
    assert reported_total == len(time_df)
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def test_fraction_mode_test_size_is_floor(time_df, time_splitter):
    """The test partition holds ``floor(n_rows * test_size)`` rows."""
    from math import floor

    outcome = time_splitter.time_split("date", test_size=0.2)
    expected_rows = floor(len(time_df) * 0.2)
    assert outcome.test_size == expected_rows
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def test_fraction_mode_no_temporal_leakage(time_splitter):
    """Every training date is strictly earlier than every test date."""
    outcome = time_splitter.time_split("date", test_size=0.2)
    latest_train = outcome.train["date"].max()
    earliest_test = outcome.test["date"].min()
    assert latest_train < earliest_test
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def test_fraction_mode_metadata_accurate(time_df, time_splitter):
    """Reported ratios equal each partition's size divided by the input size."""
    outcome = time_splitter.time_split("date", test_size=0.2)
    n_rows = len(time_df)
    expected_train_ratio = outcome.train_size / n_rows
    expected_test_ratio = outcome.test_size / n_rows
    assert outcome.train_ratio == pytest.approx(expected_train_ratio)
    assert outcome.test_ratio == pytest.approx(expected_test_ratio)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
# ---------------------------------------------------------------------------
|
|
257
|
+
# time_split — cutoff mode
|
|
258
|
+
# ---------------------------------------------------------------------------
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def test_cutoff_mode_rows_before_cutoff_are_train(time_df, time_splitter):
    """The latest training date falls strictly before the cutoff."""
    boundary = _BASE + timedelta(days=40)
    outcome = time_splitter.time_split("date", cutoff=boundary)
    latest_train = outcome.train["date"].max()
    assert latest_train < boundary
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def test_cutoff_mode_rows_on_or_after_cutoff_are_test(time_df, time_splitter):
    """The earliest test date is the cutoff day itself."""
    boundary = _BASE + timedelta(days=40)
    outcome = time_splitter.time_split("date", cutoff=boundary)
    earliest_test = outcome.test["date"].min()
    assert earliest_test == boundary
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def test_cutoff_mode_sizes_sum_to_total(time_df, time_splitter):
    """In cutoff mode the reported sizes account for every input row."""
    boundary = _BASE + timedelta(days=40)
    outcome = time_splitter.time_split("date", cutoff=boundary)
    reported_total = outcome.train_size + outcome.test_size
    assert reported_total == len(time_df)
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def test_cutoff_mode_no_temporal_leakage(time_splitter):
    """With a day-25 cutoff, every training date precedes every test date."""
    boundary = _BASE + timedelta(days=25)
    outcome = time_splitter.time_split("date", cutoff=boundary)
    latest_train = outcome.train["date"].max()
    earliest_test = outcome.test["date"].min()
    assert latest_train < earliest_test
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
# ---------------------------------------------------------------------------
|
|
286
|
+
# time_split — cutoff takes priority over test_size
|
|
287
|
+
# ---------------------------------------------------------------------------
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def test_cutoff_takes_priority_over_test_size(time_df, time_splitter):
    """When both arguments are given, the result equals the cutoff-only split."""
    boundary = _BASE + timedelta(days=40)
    # test_size=0.5 would give 25 test rows; cutoff=day40 gives 10 test rows
    with_both = time_splitter.time_split("date", test_size=0.5, cutoff=boundary)
    cutoff_only = time_splitter.time_split("date", cutoff=boundary)
    assert with_both.test.equals(cutoff_only.test)
    assert with_both.train.equals(cutoff_only.train)
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
# ---------------------------------------------------------------------------
|
|
300
|
+
# kfold — fixtures
|
|
301
|
+
# ---------------------------------------------------------------------------
|
|
302
|
+
|
|
303
|
+
_KFOLD_N = 100  # rows in the k-fold fixture
_K = 5  # number of folds requested throughout the k-fold tests


@pytest.fixture(scope="module")
def kfold_df() -> pl.DataFrame:
    """Return ``_KFOLD_N`` rows with a float feature and an alternating A/B label."""
    indices = range(_KFOLD_N)
    feature = pl.Series(list(indices), dtype=pl.Float64)
    # Even rows are "A", odd rows are "B", so the classes are balanced.
    label = pl.Series(["B" if i % 2 else "A" for i in indices], dtype=pl.Utf8)
    return pl.DataFrame({"feature": feature, "label": label})
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
@pytest.fixture(scope="module")
def kfold_splitter(kfold_df) -> DataSplitter:
    """Return a seeded splitter over ``kfold_df`` with ``label`` as the target."""
    splitter = DataSplitter(kfold_df, target="label", random_seed=42)
    return splitter
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
@pytest.fixture(scope="module")
def kfold_splitter_no_target(kfold_df) -> DataSplitter:
    """Return a seeded splitter over ``kfold_df`` with no target configured."""
    splitter = DataSplitter(kfold_df, random_seed=42)
    return splitter
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
# ---------------------------------------------------------------------------
|
|
328
|
+
# kfold — basic structure
|
|
329
|
+
# ---------------------------------------------------------------------------
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def test_kfold_returns_exactly_k_folds(kfold_splitter):
    """``kfold(_K)`` yields a collection of length ``_K``."""
    produced = kfold_splitter.kfold(_K)
    n_folds = len(produced)
    assert n_folds == _K
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
def test_kfold_fold_indices_zero_to_k_minus_one(kfold_splitter):
    """Fold indices run 0, 1, ..., ``_K - 1`` in order."""
    produced = kfold_splitter.kfold(_K)
    observed_indices = [fold.fold_index for fold in produced]
    expected_indices = list(range(_K))
    assert observed_indices == expected_indices
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def test_kfold_returns_fold_result_instances(kfold_splitter):
    """Every element returned by ``kfold`` is a ``FoldResult``."""
    produced = kfold_splitter.kfold(_K)
    assert all(isinstance(fold, FoldResult) for fold in produced)
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def test_kfold_sizes_sum_to_total(kfold_df, kfold_splitter):
    """Within each fold, train and validation sizes account for every row."""
    n_rows = len(kfold_df)
    for current in kfold_splitter.kfold(_K):
        assert current.train_size + current.val_size == n_rows
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def test_kfold_dataframe_row_counts_match_sizes(kfold_splitter):
    """Each fold's size metadata matches the row counts of its two frames."""
    for current in kfold_splitter.kfold(_K):
        assert len(current.train) == current.train_size
        assert len(current.val) == current.val_size
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
# ---------------------------------------------------------------------------
|
|
361
|
+
# kfold — non-overlapping and complete coverage
|
|
362
|
+
# ---------------------------------------------------------------------------
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def test_kfold_val_sets_non_overlapping(kfold_df, kfold_splitter):
    """No row tuple appears in the validation set of more than one fold."""
    produced = kfold_splitter.kfold(_K)
    # Walk every validation row of every fold in order; a repeat is a failure.
    already_seen = set()
    validation_rows = (row for fold in produced for row in fold.val.iter_rows())
    for row in validation_rows:
        assert row not in already_seen, f"Row {row} appeared in multiple val sets"
        already_seen.add(row)
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def test_kfold_val_sets_cover_all_rows(kfold_df, kfold_splitter):
    """The union of all validation sets equals the set of input rows."""
    produced = kfold_splitter.kfold(_K)
    validation_union = {row for fold in produced for row in fold.val.iter_rows()}
    source_rows = set(kfold_df.iter_rows())
    assert validation_union == source_rows
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
# ---------------------------------------------------------------------------
|
|
386
|
+
# kfold — stratification
|
|
387
|
+
# ---------------------------------------------------------------------------
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def test_stratified_kfold_preserves_class_ratios(kfold_splitter):
    """In every stratified fold, the first-listed class share is within 0.15 of one half."""
    for current in kfold_splitter.kfold(_K, stratify=True):
        label_counts = current.val["label"].value_counts()["count"].to_list()
        first_share = label_counts[0] / sum(label_counts)
        assert abs(first_share - 0.5) < 0.15
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
def test_kfold_stratify_false_produces_valid_folds(kfold_df, kfold_splitter_no_target):
    """Unstratified k-fold yields ``_K`` folds, each accounting for every row."""
    produced = kfold_splitter_no_target.kfold(_K, stratify=False)
    assert len(produced) == _K
    n_rows = len(kfold_df)
    for current in produced:
        assert current.train_size + current.val_size == n_rows
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
def test_kfold_stratify_defaults_true_when_target_set(kfold_splitter):
    """With a target and no ``stratify`` argument, ``kfold`` yields ``_K`` folds."""
    produced = kfold_splitter.kfold(_K)
    n_folds = len(produced)
    assert n_folds == _K
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
def test_kfold_stratify_defaults_false_when_no_target(kfold_df, kfold_splitter_no_target):
    """Without a target and no ``stratify`` argument, ``kfold`` yields ``_K`` folds."""
    produced = kfold_splitter_no_target.kfold(_K)
    n_folds = len(produced)
    assert n_folds == _K
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
def test_kfold_stratify_true_without_target_raises(kfold_splitter_no_target):
    """Stratified k-fold without a target raises a ``ValueError`` naming "target"."""
    expect_target_error = pytest.raises(ValueError, match="target")
    with expect_target_error:
        kfold_splitter_no_target.kfold(_K, stratify=True)
|
utils/__init__.py
ADDED
|
File without changes
|
utils/data_loader.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
from __future__ import annotations

import csv
import io
from pathlib import Path
from typing import Callable, Union

import chardet
import polars as pl
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class UnsupportedFormatError(Exception):
    """Signal that no loader is registered for the requested file extension."""
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
PathOrBuffer = Union[str, Path, io.IOBase, io.RawIOBase, io.BufferedIOBase]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _read_raw(source: PathOrBuffer) -> tuple[bytes, str | None]:
    """Read the entire content of *source* into memory.

    Args:
        source: A filesystem path (``str`` or ``Path``) or an open file-like
            object exposing ``read()``.

    Returns:
        A ``(raw, ext)`` pair. ``raw`` is the content as bytes. ``ext`` is the
        lower-cased path suffix (for example ``".csv"``) when *source* is a
        path, and ``None`` when it is a buffer.

    Raises:
        FileNotFoundError: If *source* is a path that does not exist.
    """
    if isinstance(source, (str, Path)):
        path = Path(source)
        if not path.exists():
            raise FileNotFoundError(f"no such file or directory: '{path}'")
        return path.read_bytes(), path.suffix.lower()

    # Remember the caller's position so the buffer can be handed back where it
    # was. Non-seekable streams (pipes, sockets) raise from tell(); treat them
    # as "cannot rewind" instead of failing the whole load.
    try:
        pos = source.tell()
    except (AttributeError, OSError, ValueError):
        pos = None

    raw = source.read()
    if isinstance(raw, str):
        # Text-mode buffers yield str, but every loader downstream expects bytes.
        raw = raw.encode("utf-8")

    if pos is not None:
        try:
            source.seek(pos)
        except (AttributeError, OSError, ValueError):
            # Rewinding is best-effort: the data has already been read.
            pass

    return raw, None
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _detect_encoding(raw: bytes) -> str:
    """Return chardet's encoding guess for *raw*, or ``"utf-8"`` when it has none."""
    guess = chardet.detect(raw).get("encoding")
    if guess:
        return guess
    return "utf-8"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _sniff_csv_delimiter(text: str) -> str:
    """Guess the column separator from the first 4096 characters of *text*.

    Candidates are comma, semicolon, tab and pipe. A comma is returned when
    ``csv.Sniffer`` raises ``csv.Error``.
    """
    sniffer = csv.Sniffer()
    head = text[:4096]
    try:
        return sniffer.sniff(head, delimiters=",;\t|").delimiter
    except csv.Error:
        return ","
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _load_csv(raw: bytes) -> pl.DataFrame:
    """Parse delimited-text bytes, detecting the encoding and separator first."""
    detected_encoding = _detect_encoding(raw)
    # The decoded copy is only used for sniffing; undecodable bytes are
    # replaced rather than raising.
    separator = _sniff_csv_delimiter(raw.decode(detected_encoding, errors="replace"))
    read_options = {
        "separator": separator,
        "encoding": detected_encoding,
        "infer_schema_length": 10_000,
        "try_parse_dates": True,
    }
    return pl.read_csv(io.BytesIO(raw), **read_options)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# Registry mapping a lower-case, dot-prefixed file extension to a loader that
# takes the file's raw bytes and returns a DataFrame. ``.csv``/``.tsv`` share
# the delimiter-sniffing CSV loader; the other entries wrap the bytes in an
# in-memory buffer and call the matching polars reader.
_EXT_LOADERS: dict[str, Callable[[bytes], pl.DataFrame]] = {
    ".csv": _load_csv,
    ".tsv": _load_csv,
    ".parquet": lambda raw: pl.read_parquet(io.BytesIO(raw)),
    ".json": lambda raw: pl.read_json(io.BytesIO(raw)),
    ".ndjson": lambda raw: pl.read_ndjson(io.BytesIO(raw)),
    ".jsonl": lambda raw: pl.read_ndjson(io.BytesIO(raw)),
    ".xlsx": lambda raw: pl.read_excel(io.BytesIO(raw)),
    ".xls": lambda raw: pl.read_excel(io.BytesIO(raw)),
    ".arrow": lambda raw: pl.read_ipc(io.BytesIO(raw)),
    ".feather": lambda raw: pl.read_ipc(io.BytesIO(raw)),
}
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class DataLoader:
    """Load tabular data from a path or buffer via the extension-keyed loader registry.

    Args:
        fmt: Optional default format used by ``load`` calls that do not pass
            their own ``fmt``. Accepted with or without the leading dot and in
            any letter case (``"csv"``, ``".CSV"``).
    """

    def __init__(self, fmt: str | None = None) -> None:
        self._fmt_override = self._normalize_fmt(fmt) or None

    @staticmethod
    def _normalize_fmt(fmt: str | None) -> str:
        """Return *fmt* as a lower-case, dot-prefixed extension (``""`` when unset)."""
        if not fmt:
            return ""
        cleaned = fmt.strip().lower()
        if cleaned and not cleaned.startswith("."):
            # The registry keys carry the dot, so "csv" must become ".csv".
            cleaned = f".{cleaned}"
        return cleaned

    def load(
        self,
        source: PathOrBuffer,
        fmt: str | None = None,
    ) -> pl.DataFrame:
        """Read *source* and parse it with the loader for its format.

        The format is the first available of: the ``fmt`` argument, the
        instance-level default, the suffix of the source path.

        Args:
            source: Path or file-like object to read.
            fmt: Optional per-call format, with or without the leading dot.

        Returns:
            The DataFrame produced by the selected loader.

        Raises:
            UnsupportedFormatError: If no loader is registered for the
                resolved format, or no format could be resolved.
        """
        raw, ext_from_path = _read_raw(source)

        resolved_fmt = self._normalize_fmt(fmt or self._fmt_override or ext_from_path)

        if resolved_fmt not in _EXT_LOADERS:
            label = resolved_fmt if resolved_fmt else "<unknown>"
            raise UnsupportedFormatError(
                f"Unsupported file format: '{label}'. "
                f"Supported extensions: {sorted(_EXT_LOADERS)}"
            )

        loader = _EXT_LOADERS[resolved_fmt]
        return loader(raw)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def load(source: PathOrBuffer, fmt: str | None = None) -> pl.DataFrame:
    """Load *source* with a throwaway loader; shorthand for ``DataLoader().load(source, fmt)``."""
    one_shot_loader = DataLoader()
    return one_shot_loader.load(source, fmt=fmt)
|