nextrec 0.1.11__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +1 -2
- nextrec/basic/callback.py +1 -2
- nextrec/basic/features.py +39 -8
- nextrec/basic/layers.py +1 -2
- nextrec/basic/loggers.py +15 -10
- nextrec/basic/metrics.py +1 -2
- nextrec/basic/model.py +87 -85
- nextrec/basic/session.py +150 -0
- nextrec/data/__init__.py +13 -2
- nextrec/data/data_utils.py +74 -22
- nextrec/data/dataloader.py +513 -0
- nextrec/data/preprocessor.py +494 -134
- nextrec/loss/listwise.py +6 -0
- nextrec/loss/loss_utils.py +1 -2
- nextrec/loss/match_losses.py +4 -5
- nextrec/loss/pairwise.py +6 -0
- nextrec/loss/pointwise.py +6 -0
- nextrec/models/match/dssm.py +2 -2
- nextrec/models/match/dssm_v2.py +2 -2
- nextrec/models/match/mind.py +2 -2
- nextrec/models/match/sdm.py +2 -2
- nextrec/models/match/youtube_dnn.py +2 -2
- nextrec/models/multi_task/esmm.py +3 -3
- nextrec/models/multi_task/mmoe.py +3 -3
- nextrec/models/multi_task/ple.py +3 -3
- nextrec/models/multi_task/share_bottom.py +3 -3
- nextrec/models/ranking/afm.py +2 -3
- nextrec/models/ranking/autoint.py +3 -3
- nextrec/models/ranking/dcn.py +3 -3
- nextrec/models/ranking/deepfm.py +2 -3
- nextrec/models/ranking/dien.py +3 -3
- nextrec/models/ranking/din.py +3 -3
- nextrec/models/ranking/fibinet.py +3 -3
- nextrec/models/ranking/fm.py +3 -3
- nextrec/models/ranking/masknet.py +3 -3
- nextrec/models/ranking/pnn.py +3 -3
- nextrec/models/ranking/widedeep.py +3 -3
- nextrec/models/ranking/xdeepfm.py +3 -3
- nextrec/utils/__init__.py +4 -8
- nextrec/utils/embedding.py +2 -4
- nextrec/utils/initializer.py +1 -2
- nextrec/utils/optimizer.py +1 -2
- {nextrec-0.1.11.dist-info → nextrec-0.2.1.dist-info}/METADATA +3 -3
- nextrec-0.2.1.dist-info/RECORD +54 -0
- nextrec/basic/dataloader.py +0 -447
- nextrec/utils/common.py +0 -14
- nextrec-0.1.11.dist-info/RECORD +0 -51
- {nextrec-0.1.11.dist-info → nextrec-0.2.1.dist-info}/WHEEL +0 -0
- {nextrec-0.1.11.dist-info → nextrec-0.2.1.dist-info}/licenses/LICENSE +0 -0
nextrec/data/preprocessor.py
CHANGED
```diff
@@ -2,17 +2,17 @@
 DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.
 
 Date: create on 13/11/2025
-Author:
-    Yang Zhou, zyaztec@gmail.com
+Author: Yang Zhou, zyaztec@gmail.com
 """
-
+from __future__ import annotations
 import os
-import pandas as pd
-import numpy as np
 import pickle
 import hashlib
 import logging
+import numpy as np
+import pandas as pd
 
+from pathlib import Path
 from typing import Dict, Union, Optional, Literal, Any
 from sklearn.preprocessing import (
     StandardScaler,
@@ -22,11 +22,18 @@ from sklearn.preprocessing import (
     LabelEncoder
 )
 
-
 from nextrec.basic.loggers import setup_logger, colorize
+from nextrec.data.data_utils import (
+    resolve_file_paths,
+    iter_file_chunks,
+    read_table,
+    load_dataframes,
+    default_output_dir,
+)
+from nextrec.basic.session import create_session, resolve_save_path
+from nextrec.basic.features import FeatureConfig
 
-
-class DataProcessor:
+class DataProcessor(FeatureConfig):
     """DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.
 
     Examples:
```
```diff
@@ -47,23 +54,26 @@ class DataProcessor:
     >>> # Get vocabulary sizes for embedding layers
     >>> vocab_sizes = processor.get_vocab_sizes()
     """
-    def __init__(self):
+    def __init__(self, session_id: str | None = None ):
         self.numeric_features: Dict[str, Dict[str, Any]] = {}
         self.sparse_features: Dict[str, Dict[str, Any]] = {}
         self.sequence_features: Dict[str, Dict[str, Any]] = {}
         self.target_features: Dict[str, Dict[str, Any]] = {}
-
+        self.session_id = session_id
+        self.session = create_session(session_id)
+
         self.is_fitted = False
         self._transform_summary_printed = False # Track if summary has been printed during transform
 
         self.scalers: Dict[str, Any] = {}
         self.label_encoders: Dict[str, LabelEncoder] = {}
         self.target_encoders: Dict[str, Dict[str, int]] = {}
+        self._set_target_config([], [])
 
         # Initialize logger if not already initialized
         self._logger_initialized = False
         if not logging.getLogger().hasHandlers():
-            setup_logger()
+            setup_logger(session_id=self.session_id)
             self._logger_initialized = True
 
     def add_numeric_feature(
```
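The constructor now threads an optional session identifier through the processor and its logger. A minimal construction sketch for 0.2.1; the identifier value is invented, only the `session_id` parameter itself comes from this hunk:

```python
# Minimal sketch: DataProcessor accepts an optional session_id in 0.2.1.
# "exp-001" is a made-up identifier; omitting it keeps the previous behaviour.
from nextrec.data.preprocessor import DataProcessor

processor = DataProcessor(session_id="exp-001")  # also seeds setup_logger(session_id=...)
default_processor = DataProcessor()              # session_id defaults to None
```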
```diff
@@ -126,6 +136,7 @@ class DataProcessor:
             'target_type': target_type,
             'label_map': label_map
         }
+        self._set_target_config(list(self.target_features.keys()), [])
 
     def _hash_string(self, s: str, hash_size: int) -> int:
         return int(hashlib.md5(str(s).encode()).hexdigest(), 16) % hash_size
@@ -212,30 +223,35 @@ class DataProcessor:
         data: pd.Series,
         config: Dict[str, Any]
     ) -> np.ndarray:
-
+        """Fast path sparse feature transform using cached dict mapping or hashing."""
         name = str(data.name)
         encode_method = config['encode_method']
         fill_na = config['fill_na']
 
-
-
+        sparse_series = pd.Series(data, name=name).fillna(fill_na).astype(str)
+
         if encode_method == 'label':
             le = self.label_encoders.get(name)
             if le is None:
                 raise ValueError(f"LabelEncoder for {name} not fitted")
 
-
-
-
-
-
-
-
-            return
-
-
+            class_to_idx = config.get('_class_to_idx')
+            if class_to_idx is None:
+                class_to_idx = {cls: idx for idx, cls in enumerate(le.classes_)}
+                config['_class_to_idx'] = class_to_idx
+
+            encoded = sparse_series.map(class_to_idx)
+            encoded = encoded.fillna(0).astype(np.int64)
+            return encoded.to_numpy()
+
+        if encode_method == 'hash':
             hash_size = config['hash_size']
-
+            hash_fn = self._hash_string
+            return np.fromiter(
+                (hash_fn(v, hash_size) for v in sparse_series.to_numpy()),
+                dtype=np.int64,
+                count=sparse_series.size,
+            )
 
         return np.array([], dtype=np.int64)
 
```
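The sparse transform no longer calls the LabelEncoder per value: it builds a `{class: index}` dict once from `le.classes_`, caches it on the feature config, and vectorises the hashing path with `np.fromiter`. A standalone sketch of the same two fast paths, using plain pandas/sklearn and invented feature values rather than nextrec internals:

```python
import hashlib

import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder

values = pd.Series(["red", "blue", "red", None, "green"]).fillna("unknown").astype(str)

# Label path: one dict lookup per value instead of LabelEncoder.transform per value
le = LabelEncoder().fit(values)
class_to_idx = {cls: idx for idx, cls in enumerate(le.classes_)}
label_encoded = values.map(class_to_idx).fillna(0).astype(np.int64).to_numpy()

# Hash path: stable md5 hashing into a fixed-size vocabulary (same scheme as _hash_string above)
def hash_string(s: str, hash_size: int) -> int:
    return int(hashlib.md5(str(s).encode()).hexdigest(), 16) % hash_size

hash_encoded = np.fromiter(
    (hash_string(v, 1000) for v in values.to_numpy()), dtype=np.int64, count=values.size
)
print(label_encoded, hash_encoded)
```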
```diff
@@ -282,64 +298,78 @@ class DataProcessor:
         data: pd.Series,
         config: Dict[str, Any]
     ) -> np.ndarray:
+        """Optimized sequence transform with preallocation and cached vocab map."""
         name = str(data.name)
         encode_method = config['encode_method']
         max_len = config['max_len']
         pad_value = config['pad_value']
         truncate = config['truncate']
         separator = config['separator']
-
-
-
+
+        arr = np.asarray(data, dtype=object)
+        n = arr.shape[0]
+        output = np.full((n, max_len), pad_value, dtype=np.int64)
+
+        # Shared helpers cached locally for speed and cross-platform consistency
+        split_fn = str.split
+        is_nan = np.isnan
+
+        if encode_method == 'label':
+            le = self.label_encoders.get(name)
+            if le is None:
+                raise ValueError(f"LabelEncoder for {name} not fitted")
+            class_to_idx = config.get('_class_to_idx')
+            if class_to_idx is None:
+                class_to_idx = {cls: idx for idx, cls in enumerate(le.classes_)}
+                config['_class_to_idx'] = class_to_idx
+        else:
+            class_to_idx = None  # type: ignore
+
+        hash_fn = self._hash_string
+        hash_size = config.get('hash_size')
+
+        for i, seq in enumerate(arr):
+            # normalize sequence to a list of strings
             tokens = []
-
             if seq is None:
                 tokens = []
-            elif isinstance(seq, (float, np.floating))
-                tokens = []
+            elif isinstance(seq, (float, np.floating)):
+                tokens = [] if is_nan(seq) else [str(seq)]
             elif isinstance(seq, str):
-
-
-
-                tokens = seq.split(separator)
-            elif isinstance(seq, (list, tuple)):
+                seq_str = seq.strip()
+                tokens = [] if not seq_str else split_fn(seq_str, separator)
+            elif isinstance(seq, (list, tuple, np.ndarray)):
                 tokens = [str(t) for t in seq]
-            elif isinstance(seq, np.ndarray):
-                tokens = [str(t) for t in seq.tolist()]
             else:
                 tokens = []
-
+
             if encode_method == 'label':
-
-
-
-
-
-
-                token_str = str(token).strip()
-                if token_str and token_str in le.classes_:
-                    encoded_val = le.transform([token_str])
-                    encoded.append(int(encoded_val[0]))
-                else:
-                    encoded.append(0)  # UNK
+                encoded = [
+                    class_to_idx.get(token.strip(), 0)  # type: ignore[union-attr]
+                    for token in tokens
+                    if token is not None and token != ''
+                ]
+
             elif encode_method == 'hash':
-                hash_size
-
+                if hash_size is None:
+                    raise ValueError("hash_size must be set for hash encoding")
+                encoded = [
+                    hash_fn(str(token), hash_size)
+                    for token in tokens
+                    if str(token).strip()
+                ]
             else:
                 encoded = []
-
+
+            if not encoded:
+                continue
+
             if len(encoded) > max_len:
-                if truncate == 'pre'
-
-
-
-
-                padding = [pad_value] * (max_len - len(encoded))
-                encoded = encoded + padding
-
-            result.append(encoded)
-
-        return np.array(result, dtype=np.int64)
+                encoded = encoded[-max_len:] if truncate == 'pre' else encoded[:max_len]
+
+            output[i, : len(encoded)] = encoded
+
+        return output
 
     def _process_target_fit(self, data: pd.Series, config: Dict[str, Any]):
         name = str(data.name)
```
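Instead of collecting per-row Python lists and converting them at the end, the new sequence transform preallocates an `(n, max_len)` array filled with `pad_value` and writes each truncated sequence into its row. The padding/truncation rule in isolation, with invented token ids:

```python
import numpy as np

sequences = [[3, 1, 4, 1, 5, 9, 2, 6], [2, 7], []]   # already-encoded token ids
max_len, pad_value, truncate = 5, 0, "pre"            # "pre" keeps the most recent tokens

output = np.full((len(sequences), max_len), pad_value, dtype=np.int64)
for i, encoded in enumerate(sequences):
    if not encoded:
        continue                                      # row stays fully padded
    if len(encoded) > max_len:
        encoded = encoded[-max_len:] if truncate == "pre" else encoded[:max_len]
    output[i, : len(encoded)] = encoded

print(output)
# [[1 5 9 2 6]
#  [2 7 0 0 0]
#  [0 0 0 0 0]]
```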
```diff
@@ -393,54 +423,212 @@ class DataProcessor:
 
         return np.array(result, dtype=np.int64 if target_type == 'multiclass' else np.float32)
 
-
-
+    def _load_dataframe_from_path(self, path: str) -> pd.DataFrame:
+        """Load all data from a file or directory path into a single DataFrame."""
+        file_paths, file_type = resolve_file_paths(path)
+        frames = load_dataframes(file_paths, file_type)
+        return pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
+
+    def _extract_sequence_tokens(self, value: Any, separator: str) -> list[str]:
+        """Extract sequence tokens from a single value."""
+        if value is None:
+            return []
+        if isinstance(value, (float, np.floating)) and np.isnan(value):
+            return []
+        if isinstance(value, str):
+            stripped = value.strip()
+            return [] if not stripped else stripped.split(separator)
+        if isinstance(value, (list, tuple, np.ndarray)):
+            return [str(v) for v in value]
+        return [str(value)]
+
+    def _fit_from_path(self, path: str, chunk_size: int) -> 'DataProcessor':
+        """Fit processor statistics by streaming files to reduce memory usage."""
         logger = logging.getLogger()
-
-
-
-
-
+        logger.info(colorize("Fitting DataProcessor (streaming path mode)...", color="cyan", bold=True))
+        file_paths, file_type = resolve_file_paths(path)
+
+        numeric_acc: Dict[str, Dict[str, float]] = {}
+        for name in self.numeric_features.keys():
+            numeric_acc[name] = {
+                "sum": 0.0,
+                "sumsq": 0.0,
+                "count": 0.0,
+                "min": np.inf,
+                "max": -np.inf,
+                "max_abs": 0.0,
+            }
+
+        sparse_vocab: Dict[str, set[str]] = {name: set() for name in self.sparse_features.keys()}
+        seq_vocab: Dict[str, set[str]] = {name: set() for name in self.sequence_features.keys()}
+        target_values: Dict[str, set[Any]] = {name: set() for name in self.target_features.keys()}
 
+        missing_features = set()
+
+        for file_path in file_paths:
+            for chunk in iter_file_chunks(file_path, file_type, chunk_size):
+                # numeric features
+                for name, config in self.numeric_features.items():
+                    if name not in chunk.columns:
+                        missing_features.add(name)
+                        continue
+                    series = chunk[name]
+                    values = pd.to_numeric(series, errors="coerce")
+                    values = values.dropna()
+                    if values.empty:
+                        continue
+                    acc = numeric_acc[name]
+                    arr = values.to_numpy(dtype=np.float64, copy=False)
+                    acc["count"] += arr.size
+                    acc["sum"] += float(arr.sum())
+                    acc["sumsq"] += float(np.square(arr).sum())
+                    acc["min"] = min(acc["min"], float(arr.min()))
+                    acc["max"] = max(acc["max"], float(arr.max()))
+                    acc["max_abs"] = max(acc["max_abs"], float(np.abs(arr).max()))
+
+                # sparse features
+                for name, config in self.sparse_features.items():
+                    if name not in chunk.columns:
+                        missing_features.add(name)
+                        continue
+                    fill_na = config["fill_na"]
+                    series = chunk[name].fillna(fill_na).astype(str)
+                    sparse_vocab[name].update(series.tolist())
+
+                # sequence features
+                for name, config in self.sequence_features.items():
+                    if name not in chunk.columns:
+                        missing_features.add(name)
+                        continue
+                    separator = config["separator"]
+                    series = chunk[name]
+                    tokens = []
+                    for val in series:
+                        tokens.extend(self._extract_sequence_tokens(val, separator))
+                    seq_vocab[name].update(tokens)
+
+                # target features
+                for name in self.target_features.keys():
+                    if name not in chunk.columns:
+                        missing_features.add(name)
+                        continue
+                    vals = chunk[name].dropna().tolist()
+                    target_values[name].update(vals)
+
+        if missing_features:
+            logger.warning(
+                f"The following configured features were not found in provided files: {sorted(missing_features)}"
+            )
+
+        # finalize numeric scalers
         for name, config in self.numeric_features.items():
-
-
+            acc = numeric_acc[name]
+            if acc["count"] == 0:
+                logger.warning(f"Numeric feature {name} has no valid values in provided files")
                 continue
-
-
-
-
-
+
+            mean_val = acc["sum"] / acc["count"]
+            if config["fill_na"] is not None:
+                config["fill_na_value"] = config["fill_na"]
+            else:
+                config["fill_na_value"] = mean_val
+
+            scaler_type = config["scaler"]
+            if scaler_type == "standard":
+                var = max(acc["sumsq"] / acc["count"] - mean_val * mean_val, 0.0)
+                scaler = StandardScaler()
+                scaler.mean_ = np.array([mean_val], dtype=np.float64)
+                scaler.var_ = np.array([var], dtype=np.float64)
+                scaler.scale_ = np.array([np.sqrt(var) if var > 0 else 1.0], dtype=np.float64)
+                scaler.n_samples_seen_ = np.array([int(acc["count"])], dtype=np.int64)
+                self.scalers[name] = scaler
+            elif scaler_type == "minmax":
+                data_min = acc["min"] if np.isfinite(acc["min"]) else 0.0
+                data_max = acc["max"] if np.isfinite(acc["max"]) else data_min
+                scaler = MinMaxScaler()
+                scaler.data_min_ = np.array([data_min], dtype=np.float64)
+                scaler.data_max_ = np.array([data_max], dtype=np.float64)
+                scaler.data_range_ = scaler.data_max_ - scaler.data_min_
+                scaler.data_range_[scaler.data_range_ == 0] = 1.0
+                scaler.n_samples_seen_ = np.array([int(acc["count"])], dtype=np.int64)
+                self.scalers[name] = scaler
+            elif scaler_type == "maxabs":
+                scaler = MaxAbsScaler()
+                scaler.max_abs_ = np.array([acc["max_abs"]], dtype=np.float64)
+                scaler.n_samples_seen_ = np.array([int(acc["count"])], dtype=np.int64)
+                self.scalers[name] = scaler
+            elif scaler_type in ("log", "none", "robust"):
+                # log and none do not require fitting; robust requires full data and is handled earlier
                 continue
-
-
+            else:
+                raise ValueError(f"Unknown scaler type: {scaler_type}")
+
+        # finalize sparse label encoders
+        for name, config in self.sparse_features.items():
+            if config["encode_method"] == "label":
+                vocab = sparse_vocab[name]
+                if not vocab:
+                    logger.warning(f"Sparse feature {name} has empty vocabulary")
+                    continue
+                le = LabelEncoder()
+                le.fit(list(vocab))
+                self.label_encoders[name] = le
+                config["vocab_size"] = len(le.classes_)
+            elif config["encode_method"] == "hash":
+                config["vocab_size"] = config["hash_size"]
+
+        # finalize sequence vocabularies
         for name, config in self.sequence_features.items():
-            if
-
-
-
+            if config["encode_method"] == "label":
+                vocab = seq_vocab[name] or {"<PAD>"}
+                le = LabelEncoder()
+                le.fit(list(vocab))
+                self.label_encoders[name] = le
+                config["vocab_size"] = len(le.classes_)
+            elif config["encode_method"] == "hash":
+                config["vocab_size"] = config["hash_size"]
 
+        # finalize targets
         for name, config in self.target_features.items():
-            if
-                logger.warning(f"Target {name}
+            if not target_values[name]:
+                logger.warning(f"Target {name} has no valid values in provided files")
                 continue
-
-
+
+            target_type = config["target_type"]
+            if target_type in ["binary", "multiclass"]:
+                unique_values = list(target_values[name])
+                try:
+                    sorted_values = sorted(unique_values)
+                except TypeError:
+                    sorted_values = sorted(unique_values, key=lambda x: str(x))
+
+                label_map = config["label_map"]
+                if label_map is None:
+                    try:
+                        int_values = [int(v) for v in sorted_values]
+                        if int_values == list(range(len(int_values))):
+                            label_map = {str(val): int(val) for val in sorted_values}
+                        else:
+                            label_map = {str(val): idx for idx, val in enumerate(sorted_values)}
+                    except (ValueError, TypeError):
+                        label_map = {str(val): idx for idx, val in enumerate(sorted_values)}
+                config["label_map"] = label_map
+
+                self.target_encoders[name] = label_map
+
         self.is_fitted = True
-        logger.info(colorize("DataProcessor fitted successfully", color="green", bold=True))
+        logger.info(colorize("DataProcessor fitted successfully (streaming path mode)", color="green", bold=True))
         return self
-
-    def
-        self,
+
+    def _transform_in_memory(
+        self,
         data: Union[pd.DataFrame, Dict[str, Any]],
-        return_dict: bool
+        return_dict: bool,
+        persist: bool,
+        save_format: Optional[Literal["csv", "parquet"]],
     ) -> Union[pd.DataFrame, Dict[str, np.ndarray]]:
         logger = logging.getLogger()
-
 
-        if not self.is_fitted:
-            raise ValueError("DataProcessor must be fitted before transform")
-
         # Convert input to dict format for unified processing
         if isinstance(data, pd.DataFrame):
             data_dict = {col: data[col] for col in data.columns}
```
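`_fit_from_path` never materialises the full dataset: each chunk only updates running accumulators (`sum`, `sumsq`, `count`, `min`, `max`, `max_abs`) per numeric feature, and the scalers are rebuilt from those moments afterwards (mean = sum/count, var = sumsq/count − mean²). A standalone sketch of that reconstruction for the standard scaler, using made-up chunks in place of streamed files:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

chunks = [np.array([1.0, 2.0, 3.0]), np.array([4.0, 5.0])]   # stand-in for streamed file chunks

acc = {"sum": 0.0, "sumsq": 0.0, "count": 0}
for arr in chunks:
    acc["count"] += arr.size
    acc["sum"] += float(arr.sum())
    acc["sumsq"] += float(np.square(arr).sum())

mean = acc["sum"] / acc["count"]
var = max(acc["sumsq"] / acc["count"] - mean * mean, 0.0)     # population variance from the moments

# Rehydrate a StandardScaler without ever calling fit() on the concatenated data
scaler = StandardScaler()
scaler.mean_ = np.array([mean])
scaler.var_ = np.array([var])
scaler.scale_ = np.array([np.sqrt(var) if var > 0 else 1.0])
scaler.n_samples_seen_ = np.array([acc["count"]])

print(scaler.transform(np.array([[1.0], [5.0]])))             # matches fitting on all rows at once
```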
```diff
@@ -494,61 +682,233 @@ class DataProcessor:
             series_data = pd.Series(data_dict[name], name=name)
             processed = self._process_target_transform(series_data, config)
             result_dict[name] = processed
-
-
-        return result_dict
-        else:
+
+        def _dict_to_dataframe(result: Dict[str, np.ndarray]) -> pd.DataFrame:
             # Convert all arrays to Series/lists at once to avoid fragmentation
             columns_dict = {}
-            for key, value in
+            for key, value in result.items():
                 if key in self.sequence_features:
                     columns_dict[key] = [list(seq) for seq in value]
                 else:
                     columns_dict[key] = value
+            return pd.DataFrame(columns_dict)
+
+        assert save_format in [None, "csv", "parquet"], "save_format must be either 'csv', 'parquet', or None"
+        if persist and save_format is None:
+            save_format = "parquet"
+
+        result_df = None
+        if (not return_dict) or (save_format is not None):
+            result_df = _dict_to_dataframe(result_dict)
+            assert result_df is not None, "DataFrame is None after transform"
+
+        if save_format is not None:
+            save_path = resolve_save_path(
+                path=None,
+                default_dir=self.session_dir / "processor" / "preprocessed_data",
+                default_name="data_processed",
+                suffix=f".{save_format}",
+                add_timestamp=True,
+            )
+
+            if save_format == "parquet":
+                result_df.to_parquet(save_path, index=False)
+            else:
+                result_df.to_csv(save_path, index=False)
+
+            logger.info(colorize(
+                f"Transformed data saved to: {save_path}",
+                color="green"
+            ))
+
+        if return_dict:
+            return result_dict
+        return result_df
+
+    def _transform_path(self, path: str, output_path: Optional[str]) -> list[str]:
+        """Transform data from files under a path and save them to a new location."""
+        logger = logging.getLogger()
+
+        file_paths, file_type = resolve_file_paths(path)
+        default_root = self.session_dir / "processor" / default_output_dir(path).name
+        output_root = default_root
+        target_file_override: Optional[Path] = None
+
+        if output_path:
+            output_path_obj = Path(output_path)
+            if not output_path_obj.is_absolute():
+                output_path_obj = self.session_dir / output_path_obj
+            if output_path_obj.suffix.lower() in {".csv", ".parquet"}:
+                if len(file_paths) != 1:
+                    raise ValueError("output_path points to a file but multiple input files were provided.")
+                target_file_override = output_path_obj
+                output_root = output_path_obj.parent
+            else:
+                output_root = output_path_obj
+
+        output_root.mkdir(parents=True, exist_ok=True)
+
+        saved_paths: list[str] = []
+        for file_path in file_paths:
+            df = read_table(file_path, file_type)
+
+            transformed_df = self._transform_in_memory(
+                df,
+                return_dict=False,
+                persist=False,
+                save_format=None,
+            )
+            assert isinstance(transformed_df, pd.DataFrame), "Expected DataFrame when return_dict=False"
+
+            source_path = Path(file_path)
+            target_file = (
+                target_file_override
+                if target_file_override is not None
+                else output_root / f"{source_path.stem}_preprocessed{source_path.suffix}"
+            )
+
+            if file_type == "csv":
+                transformed_df.to_csv(target_file, index=False)
+            else:
+                transformed_df.to_parquet(target_file, index=False)
+
+            saved_paths.append(str(target_file.resolve()))
+
+        logger.info(colorize(
+            f"Transformed {len(saved_paths)} file(s) saved to: {output_root.resolve()}",
+            color="green",
+        ))
+        return saved_paths
+
+    # fit is nothing but registering the statistics from data so that we can transform the data later
+    def fit(
+        self,
+        data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
+        chunk_size: int = 200000,
+    ):
+        logger = logging.getLogger()
+
+        if isinstance(data, (str, os.PathLike)):
+            path_str = str(data)
+            uses_robust = any(cfg.get("scaler") == "robust" for cfg in self.numeric_features.values())
+            if uses_robust:
+                logger.warning(
+                    "Robust scaler requires full data; loading all files into memory. "
+                    "Consider smaller chunk_size or different scaler if memory is limited."
+                )
+                data = self._load_dataframe_from_path(path_str)
+            else:
+                return self._fit_from_path(path_str, chunk_size)
+        if isinstance(data, dict):
+            data = pd.DataFrame(data)
 
-
-
+        logger.info(colorize("Fitting DataProcessor...", color="cyan", bold=True))
+
+        for name, config in self.numeric_features.items():
+            if name not in data.columns:
+                logger.warning(f"Numeric feature {name} not found in data")
+                continue
+            self._process_numeric_feature_fit(data[name], config)
+
+        for name, config in self.sparse_features.items():
+            if name not in data.columns:
+                logger.warning(f"Sparse feature {name} not found in data")
+                continue
+            self._process_sparse_feature_fit(data[name], config)
+
+        for name, config in self.sequence_features.items():
+            if name not in data.columns:
+                logger.warning(f"Sequence feature {name} not found in data")
+                continue
+            self._process_sequence_feature_fit(data[name], config)
+
+        for name, config in self.target_features.items():
+            if name not in data.columns:
+                logger.warning(f"Target {name} not found in data")
+                continue
+            self._process_target_fit(data[name], config)
+
+        self.is_fitted = True
+        logger.info(colorize("DataProcessor fitted successfully", color="green", bold=True))
+        return self
+
+    def transform(
+        self,
+        data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
+        return_dict: bool = True,
+        persist: bool = False,
+        save_format: Optional[Literal["csv", "parquet"]] = None,
+        output_path: Optional[str] = None,
+    ) -> Union[pd.DataFrame, Dict[str, np.ndarray], list[str]]:
+        logger = logging.getLogger()
+
+        if not self.is_fitted:
+            raise ValueError("DataProcessor must be fitted before transform")
+
+        if isinstance(data, (str, os.PathLike)):
+            if return_dict or persist or save_format is not None:
+                raise ValueError("Path transform writes files only; use output_path and leave return_dict/persist/save_format defaults.")
+            return self._transform_path(str(data), output_path)
+
+        return self._transform_in_memory(
+            data=data,
+            return_dict=return_dict,
+            persist=persist,
+            save_format=save_format,
+        )
 
     def fit_transform(
         self,
-        data: Union[pd.DataFrame, Dict[str, Any]],
-        return_dict: bool = True
-
-
-
-
-
+        data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
+        return_dict: bool = True,
+        save_format: Optional[Literal["csv", "parquet"]] = None,
+        output_path: Optional[str] = None,
+        chunk_size: int = 200000,
+    ) -> Union[pd.DataFrame, Dict[str, np.ndarray], list[str]]:
+        self.fit(data, chunk_size=chunk_size)
+        return self.transform(
+            data,
+            return_dict=return_dict,
+            save_format=save_format,
+            output_path=output_path,
+        )
+
+    def save(self, save_path: str):
         logger = logging.getLogger()
-
+
         if not self.is_fitted:
             logger.warning("Saving unfitted DataProcessor")
 
-
-
-
-
-
+        target_path = resolve_save_path(
+            path=save_path,
+            default_dir=self.session.processor_dir,
+            default_name="processor",
+            suffix=".pkl",
+        )
+
+        # Prepare state dict
         state = {
-
-
-
-
-
-
-
-
+            "numeric_features": self.numeric_features,
+            "sparse_features": self.sparse_features,
+            "sequence_features": self.sequence_features,
+            "target_features": self.target_features,
+            "is_fitted": self.is_fitted,
+            "scalers": self.scalers,
+            "label_encoders": self.label_encoders,
+            "target_encoders": self.target_encoders,
         }
-
-
+
+        # Save with pickle
+        with open(target_path, "wb") as f:
             pickle.dump(state, f)
-
-        logger.info(f"DataProcessor saved to {
+
+        logger.info(colorize(f"DataProcessor saved to: {target_path}", color="green"))
 
     @classmethod
-    def load(cls,
+    def load(cls, load_path: str) -> 'DataProcessor':
         logger = logging.getLogger()
 
-        with open(
+        with open(load_path, 'rb') as f:
             state = pickle.load(f)
 
         processor = cls()
```
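With this hunk, `fit`, `transform` and `fit_transform` also accept a file or directory path: fitting streams `chunk_size` rows at a time, and a path-based transform writes `*_preprocessed` files rather than returning arrays. A hedged usage sketch; the directory names are placeholders and feature registration is elided because its signatures are not part of this hunk. As the path branch is written above, `return_dict` has to be passed as `False` for a path input:

```python
# Sketch only: "data/train/" and "out/" are invented paths; feature registration
# (add_numeric_feature and the related add_* methods) is omitted here.
from nextrec.data.preprocessor import DataProcessor

processor = DataProcessor(session_id="exp-001")
# ... register numeric / sparse / sequence / target features ...

processor.fit("data/train/", chunk_size=100_000)      # streams CSV/Parquet chunks
saved_files = processor.transform(                    # writes <name>_preprocessed files
    "data/train/", return_dict=False, output_path="out/"
)
processor.save("out/processor.pkl")                   # pickles feature configs + fitted encoders

# In-memory inputs keep the old behaviour and return a dict of np.ndarray by default:
# encoded = processor.transform(train_df)
```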
```diff
@@ -561,7 +921,7 @@ class DataProcessor:
         processor.label_encoders = state['label_encoders']
         processor.target_encoders = state['target_encoders']
 
-        logger.info(f"DataProcessor loaded from {
+        logger.info(f"DataProcessor loaded from {load_path}")
         return processor
 
     def get_vocab_sizes(self) -> Dict[str, int]:
```
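`load` is the explicit round-trip counterpart to `save`, taking the pickle path directly. A small sketch; the path is a placeholder:

```python
from nextrec.data.preprocessor import DataProcessor

processor = DataProcessor.load("out/processor.pkl")   # classmethod; restores configs, scalers, encoders
vocab_sizes = processor.get_vocab_sizes()             # e.g. to size embedding tables
```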
```diff
@@ -659,4 +1019,4 @@ class DataProcessor:
 
         logger.info("")
         logger.info("")
-        logger.info(colorize("=" * 80, color="bright_blue", bold=True))
+        logger.info(colorize("=" * 80, color="bright_blue", bold=True))
```