nextrec-0.4.32-py3-none-any.whl → nextrec-0.4.34-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +14 -16
- nextrec/basic/asserts.py +1 -22
- nextrec/basic/callback.py +2 -2
- nextrec/basic/features.py +6 -37
- nextrec/basic/heads.py +13 -1
- nextrec/basic/layers.py +9 -33
- nextrec/basic/loggers.py +3 -2
- nextrec/basic/metrics.py +85 -4
- nextrec/basic/model.py +19 -12
- nextrec/basic/summary.py +89 -42
- nextrec/cli.py +54 -41
- nextrec/data/preprocessor.py +74 -25
- nextrec/loss/grad_norm.py +78 -76
- nextrec/models/multi_task/ple.py +1 -0
- nextrec/models/multi_task/share_bottom.py +1 -0
- nextrec/models/tree_base/base.py +1 -1
- nextrec/utils/__init__.py +2 -1
- nextrec/utils/config.py +1 -1
- nextrec/utils/console.py +1 -1
- nextrec/utils/torch_utils.py +63 -56
- nextrec/utils/types.py +43 -0
- {nextrec-0.4.32.dist-info → nextrec-0.4.34.dist-info}/METADATA +4 -4
- {nextrec-0.4.32.dist-info → nextrec-0.4.34.dist-info}/RECORD +27 -35
- nextrec/models/representation/autorec.py +0 -0
- nextrec/models/representation/bpr.py +0 -0
- nextrec/models/representation/cl4srec.py +0 -0
- nextrec/models/representation/lightgcn.py +0 -0
- nextrec/models/representation/mf.py +0 -0
- nextrec/models/representation/s3rec.py +0 -0
- nextrec/models/sequential/sasrec.py +0 -0
- nextrec/utils/feature.py +0 -29
- {nextrec-0.4.32.dist-info → nextrec-0.4.34.dist-info}/WHEEL +0 -0
- {nextrec-0.4.32.dist-info → nextrec-0.4.34.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.32.dist-info → nextrec-0.4.34.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/summary.py
CHANGED
```diff
@@ -8,6 +8,7 @@ Author: Yang Zhou,zyaztec@gmail.com
 
 from __future__ import annotations
 
+import inspect
 import logging
 from typing import Any, Literal
 
@@ -34,6 +35,7 @@ class SummarySet:
     scheduler_name: str | None
     scheduler_params: dict[str, Any]
    loss_config: Any
+    loss_params: Any
     loss_weights: Any
     grad_norm: Any
     embedding_l1_reg: float
@@ -73,7 +75,8 @@ class SummarySet:
     def build_data_summary(
         self, data: Any, data_loader: DataLoader | None, sample_key: str
     ):
-
+
+        dataset = data_loader.dataset if data_loader is not None else None
 
         train_size = get_data_length(dataset)
         if train_size is None:
@@ -324,6 +327,73 @@ class SummarySet:
 
         if hasattr(self, "loss_config"):
             logger.info(f"Loss Function: {self.loss_config}")
+
+            loss_params_summary: list[str] = []
+            loss_fn = getattr(self, "loss_fn", None)
+            if loss_fn is not None:
+                loss_modules = (
+                    list(loss_fn) if isinstance(loss_fn, (list, tuple)) else [loss_fn]
+                )
+                loss_config = getattr(self, "loss_config", None)
+                if isinstance(loss_config, list):
+                    loss_names = loss_config
+                elif loss_config is not None:
+                    loss_names = [loss_config] * len(loss_modules)
+                else:
+                    loss_names = [None] * len(loss_modules)
+
+                loss_params = getattr(self, "loss_params", None)
+                if isinstance(loss_params, list):
+                    explicit_params = loss_params
+                elif isinstance(loss_params, dict):
+                    explicit_params = [loss_params] * len(loss_modules)
+                else:
+                    explicit_params = [None] * len(loss_modules)
+
+                for idx, loss_module in enumerate(loss_modules):
+                    params: dict[str, Any] = {}
+                    explicit = (
+                        explicit_params[idx] if idx < len(explicit_params) else None
+                    )
+                    if explicit:
+                        params.update(explicit)
+                    try:
+                        signature = inspect.signature(loss_module.__class__.__init__)
+                    except (TypeError, ValueError):
+                        signature = None
+                    if signature is not None:
+                        for name, param in signature.parameters.items():
+                            if name == "self" or name.startswith("_"):
+                                continue
+                            if hasattr(loss_module, name):
+                                value = getattr(loss_module, name)
+                                if callable(value):
+                                    continue
+                                params.setdefault(name, value)
+                            elif (
+                                param.default is not inspect._empty
+                                and param.default is not None
+                            ):
+                                params.setdefault(name, param.default)
+                    if not params:
+                        continue
+
+                    loss_name = loss_names[idx] if idx < len(loss_names) else None
+                    if len(loss_modules) > 1:
+                        header = f" [{idx}]"
+                        if loss_name is not None:
+                            header = f"{header} {loss_name}"
+                        loss_params_summary.append(header)
+                        indent = " "
+                    else:
+                        indent = " "
+                    for key, value in params.items():
+                        loss_params_summary.append(f"{indent}{key:25s}: {value}")
+
+            if loss_params_summary:
+                logger.info("Loss Params:")
+                for line in loss_params_summary:
+                    logger.info(line)
         if hasattr(self, "loss_weights"):
             logger.info(f"Loss Weights: {self.loss_weights}")
         if hasattr(self, "grad_norm"):
```
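The largest addition above derives a "Loss Params" section by inspecting each loss module's constructor with `inspect.signature`, preferring live attribute values and falling back to non-None declared defaults. A minimal standalone sketch of that idea, using a hypothetical `ExampleLoss` class and a helper name (`collect_loss_params`) that are illustrative only, not part of NextRec:

```python
import inspect
from typing import Any


class ExampleLoss:
    """Hypothetical loss module standing in for a torch.nn-style loss."""

    def __init__(self, gamma: float = 2.0, alpha: float = 0.25, reduction: str = "mean"):
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction


def collect_loss_params(loss_module: Any) -> dict[str, Any]:
    """Collect displayable constructor parameters from a loss instance."""
    params: dict[str, Any] = {}
    try:
        signature = inspect.signature(loss_module.__class__.__init__)
    except (TypeError, ValueError):
        return params
    for name, param in signature.parameters.items():
        if name == "self" or name.startswith("_"):
            continue
        if hasattr(loss_module, name):
            value = getattr(loss_module, name)
            if callable(value):  # skip methods and other callables
                continue
            params.setdefault(name, value)
        elif param.default is not inspect.Parameter.empty and param.default is not None:
            params.setdefault(name, param.default)  # fall back to the declared default
    return params


print(collect_loss_params(ExampleLoss(gamma=1.5)))
# -> {'gamma': 1.5, 'alpha': 0.25, 'reduction': 'mean'}
```

The version in the diff additionally merges any user-supplied `loss_params`, skips losses with nothing collectable, and prints one indented `key: value` line per parameter under a "Loss Params:" header.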
```diff
@@ -354,53 +424,30 @@ class SummarySet:
         logger.info("")
         logger.info(colorize("Data Summary", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
-
-
-
-
-
-
-        if
-            for target_name, details in label_distributions.items():
-                lines = details.get("lines", [])
-                logger.info(f"{target_name}:")
-                for label, value in lines:
-                    logger.info(f"  {format_kv(label, value)}")
-
-        dataloader_info = self.train_data_summary.get("dataloader")
-        if isinstance(dataloader_info, dict):
-            logger.info("Train DataLoader:")
-            for key in (
-                "batch_size",
-                "num_workers",
-                "pin_memory",
-                "persistent_workers",
-                "sampler",
-            ):
-                if key in dataloader_info:
-                    label = key.replace("_", " ").title()
-                    logger.info(
-                        format_kv(label, dataloader_info[key], indent=2)
-                    )
-
-        if self.valid_data_summary:
-            if self.train_data_summary:
+        for label, data_summary in (
+            ("Train", self.train_data_summary),
+            ("Valid", self.valid_data_summary),
+        ):
+            if not data_summary:
+                continue
+            if label == "Valid" and self.train_data_summary:
                 logger.info("")
-
-
-
+            sample_key = "train_samples" if label == "Train" else "valid_samples"
+            samples = data_summary.get(sample_key)
+            if samples is not None:
+                logger.info(format_kv(f"{label} Samples", f"{samples:,}"))
 
-            label_distributions =
+            label_distributions = data_summary.get("label_distributions")
             if isinstance(label_distributions, dict):
                 for target_name, details in label_distributions.items():
                     lines = details.get("lines", [])
                     logger.info(f"{target_name}:")
-                    for
-                        logger.info(f"  {format_kv(
+                    for line_label, value in lines:
+                        logger.info(f"  {format_kv(line_label, value)}")
 
-            dataloader_info =
+            dataloader_info = data_summary.get("dataloader")
             if isinstance(dataloader_info, dict):
-                logger.info("
+                logger.info(f"{label} DataLoader:")
                 for key in (
                     "batch_size",
                     "num_workers",
@@ -409,7 +456,7 @@ class SummarySet:
                     "sampler",
                 ):
                     if key in dataloader_info:
-
+                        field_label = key.replace("_", " ").title()
                         logger.info(
-                            format_kv(
+                            format_kv(field_label, dataloader_info[key], indent=2)
                         )
```
nextrec/cli.py
CHANGED
```diff
@@ -48,7 +48,7 @@ from nextrec.utils.data import (
     read_yaml,
     resolve_file_paths,
 )
-from nextrec.utils.
+from nextrec.utils.torch_utils import to_list
 
 logger = logging.getLogger(__name__)
 
@@ -152,16 +152,17 @@ def train_model(train_config_path: str) -> None:
     )
     if data_cfg.get("valid_ratio") is not None:
         logger.info(format_kv("Valid ratio", data_cfg.get("valid_ratio")))
-    if data_cfg.get("
+    if data_cfg.get("valid_path"):
         logger.info(
             format_kv(
                 "Validation path",
-                resolve_path(
-                    data_cfg.get("val_path") or data_cfg.get("valid_path"), config_dir
-                ),
+                resolve_path(data_cfg.get("valid_path"), config_dir),
             )
         )
 
+    # Determine validation dataset path early for streaming split / fitting
+    val_data_path = data_cfg.get("valid_path")
+
     if streaming:
         file_paths, file_type = resolve_file_paths(str(data_path))
         log_kv_lines(
@@ -180,6 +181,34 @@ def train_model(train_config_path: str) -> None:
             raise ValueError(f"Data file is empty: {first_file}") from exc
         df_columns = list(first_chunk.columns)
 
+        # Decide training/validation file lists before fitting processor, to avoid
+        # leaking validation statistics into preprocessing (scalers/encoders).
+        streaming_train_files = file_paths
+        streaming_valid_ratio = data_cfg.get("valid_ratio")
+        if val_data_path:
+            streaming_valid_files = None
+        elif streaming_valid_ratio is not None:
+            ratio = float(streaming_valid_ratio)
+            if not (0 < ratio < 1):
+                raise ValueError(
+                    f"[NextRec CLI Error] Valid_ratio must be between 0 and 1, current value is {streaming_valid_ratio}"
+                )
+            total_files = len(file_paths)
+            if total_files < 2:
+                raise ValueError(
+                    "[NextRec CLI Error] Must provide valid_path or increase the number of data files. At least 2 files are required for streaming validation split."
+                )
+            val_count = max(1, int(round(total_files * ratio)))
+            if val_count >= total_files:
+                val_count = total_files - 1
+            streaming_valid_files = file_paths[-val_count:]
+            streaming_train_files = file_paths[:-val_count]
+            logger.info(
+                f"Split files for streaming training and validation using valid_ratio={ratio:.3f}: training {len(streaming_train_files)} files, validation {len(streaming_valid_files)} files"
+            )
+        else:
+            streaming_valid_files = None
+
     else:
         df = read_table(data_path, data_cfg.get("format"))
         logger.info(format_kv("Rows", len(df)))
```
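The block added at new lines 184-210 decides the train/validation file split before the preprocessor is fitted, so validation shards never contribute to scaler or encoder statistics. A rough standalone sketch of the ratio-based part of that split; the function name and file names below are illustrative, not part of the CLI:

```python
def split_files_by_ratio(file_paths: list[str], ratio: float) -> tuple[list[str], list[str]]:
    """Split an ordered file list into train/valid parts by a validation ratio."""
    if not (0 < ratio < 1):
        raise ValueError(f"valid_ratio must be between 0 and 1, got {ratio}")
    total_files = len(file_paths)
    if total_files < 2:
        raise ValueError("At least 2 files are required for a streaming validation split")
    val_count = max(1, int(round(total_files * ratio)))
    val_count = min(val_count, total_files - 1)  # always keep at least one training file
    return file_paths[:-val_count], file_paths[-val_count:]


train_files, valid_files = split_files_by_ratio(
    [f"part-{i:02d}.parquet" for i in range(10)], ratio=0.2
)
# 8 training files, 2 validation files (the last two shards)
```

As in the diff, an explicit `valid_path` takes priority over the ratio, and with neither setting no validation files are carved out.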
```diff
@@ -215,7 +244,15 @@ def train_model(train_config_path: str) -> None:
     )
 
     if streaming:
-
+        if file_type is None:
+            raise ValueError(
+                "[NextRec CLI Error] Streaming mode requires a valid file_type"
+            )
+        processor.fit_from_files(
+            file_paths=streaming_train_files or file_paths,
+            file_type=file_type,
+            chunk_size=dataloader_chunk_size,
+        )
         processed = None
         df = None  # type: ignore[assignment]
     else:
@@ -232,34 +269,6 @@ def train_model(train_config_path: str) -> None:
         sequence_names,
     )
 
-    # Check if validation dataset path is specified
-    val_data_path = data_cfg.get("val_path") or data_cfg.get("valid_path")
-    if streaming:
-        if not file_paths:
-            file_paths, file_type = resolve_file_paths(str(data_path))
-        streaming_train_files = file_paths
-        streaming_valid_ratio = data_cfg.get("valid_ratio")
-        if val_data_path:
-            streaming_valid_files = None
-        elif streaming_valid_ratio is not None:
-            ratio = float(streaming_valid_ratio)
-            if not (0 < ratio < 1):
-                raise ValueError(
-                    f"[NextRec CLI Error] Valid_ratio must be between 0 and 1, current value is {streaming_valid_ratio}"
-                )
-            total_files = len(file_paths)
-            if total_files < 2:
-                raise ValueError(
-                    "[NextRec CLI Error] Must provide val_path or increase the number of data files. At least 2 files are required for streaming validation split."
-                )
-            val_count = max(1, int(round(total_files * ratio)))
-            if val_count >= total_files:
-                val_count = total_files - 1
-            streaming_valid_files = file_paths[-val_count:]
-            streaming_train_files = file_paths[:-val_count]
-            logger.info(
-                f"Split files for streaming training and validation using valid_ratio={ratio:.3f}: training {len(streaming_train_files)} files, validation {len(streaming_valid_files)} files"
-            )
     train_data: Dict[str, Any]
     valid_data: Dict[str, Any] | None
 
@@ -604,8 +613,13 @@ def predict_model(predict_config_path: str) -> None:
         "save_data_format", predict_cfg.get("save_format", "csv")
     )
     pred_name = predict_cfg.get("name", "pred")
-
-
+    pred_name_path = Path(pred_name)
+    if pred_name_path.is_absolute():
+        save_path = pred_name_path
+        if save_path.suffix == "":
+            save_path = save_path.with_suffix(f".{save_format}")
+    else:
+        save_path = checkpoint_base / "predictions" / f"{pred_name}.{save_format}"
 
     start = time.time()
     logger.info("")
@@ -620,11 +634,10 @@ def predict_model(predict_config_path: str) -> None:
     )
     duration = time.time() - start
     # When return_dataframe=False, result is the actual file path
-
-        result
-
-
-    )
+    if isinstance(result, (str, Path)):
+        output_path = Path(result)
+    else:
+        output_path = save_path
     logger.info(f"Prediction completed, results saved to: {output_path}")
     logger.info(f"Total time: {duration:.2f} seconds")
 
```
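For the `predict_model` changes above: an absolute `name` in the predict config is now used directly as the output path (gaining a suffix when it has none), while a bare name still lands under the checkpoint's `predictions/` directory. A small sketch of that resolution; the helper name and the `checkpoint_base` value are illustrative only:

```python
from pathlib import Path


def resolve_save_path(pred_name: str, save_format: str, checkpoint_base: Path) -> Path:
    """Mirror the output-path resolution the CLI applies before writing predictions."""
    pred_name_path = Path(pred_name)
    if pred_name_path.is_absolute():
        save_path = pred_name_path
        if save_path.suffix == "":
            save_path = save_path.with_suffix(f".{save_format}")
        return save_path
    return checkpoint_base / "predictions" / f"{pred_name}.{save_format}"


print(resolve_save_path("pred", "csv", Path("runs/exp1")))
# -> runs/exp1/predictions/pred.csv
print(resolve_save_path("/tmp/scores", "parquet", Path("runs/exp1")))
# -> /tmp/scores.parquet
```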
nextrec/data/preprocessor.py
CHANGED
```diff
@@ -566,35 +566,16 @@ class DataProcessor(FeatureSet):
             return [str(v) for v in value]
         return [str(value)]
 
-    def
-
-
-
-        Args:
-            path (str): File or directory path.
-            chunk_size (int): Number of rows per chunk.
-
-        Returns:
-            DataProcessor: Fitted DataProcessor instance.
-        """
+    def fit_from_file_paths(
+        self, file_paths: list[str], file_type: str, chunk_size: int
+    ) -> "DataProcessor":
         logger = logging.getLogger()
-
-
-                "Fitting DataProcessor (streaming path mode)...",
-                color="cyan",
-                bold=True,
-            )
-        )
-        for config in self.sparse_features.values():
-            config.pop("_min_freq_logged", None)
-        for config in self.sequence_features.values():
-            config.pop("_min_freq_logged", None)
-        file_paths, file_type = resolve_file_paths(path)
+        if not file_paths:
+            raise ValueError("[DataProcessor Error] Empty file list for streaming fit")
         if not check_streaming_support(file_type):
             raise ValueError(
                 f"[DataProcessor Error] Format '{file_type}' does not support streaming. "
-                "
-                "Use fit(dataframe) with in-memory data or convert the data format."
+                "Streaming fit only supports csv, parquet to avoid high memory usage."
             )
 
         numeric_acc = {}
@@ -636,6 +617,7 @@ class DataProcessor(FeatureSet):
         target_values: Dict[str, set[Any]] = {
             name: set() for name in self.target_features.keys()
         }
+
         missing_features = set()
         for file_path in file_paths:
             for chunk in iter_file_chunks(file_path, file_type, chunk_size):
@@ -702,10 +684,12 @@ class DataProcessor(FeatureSet):
                 for name in self.target_features.keys() & columns:
                     vals = chunk[name].dropna().tolist()
                     target_values[name].update(vals)
+
         if missing_features:
             logger.warning(
                 f"The following configured features were not found in provided files: {sorted(missing_features)}"
             )
+
         # finalize numeric scalers
         for name, config in self.numeric_features.items():
             acc = numeric_acc[name]
@@ -895,6 +879,71 @@ class DataProcessor(FeatureSet):
         )
         return self
 
+    def fit_from_files(
+        self, file_paths: list[str], file_type: str, chunk_size: int
+    ) -> "DataProcessor":
+        """Fit processor statistics by streaming an explicit list of files.
+
+        This is useful when you want to fit statistics on training files only (exclude
+        validation files) in streaming mode.
+        """
+        logger = logging.getLogger()
+        logger.info(
+            colorize(
+                "Fitting DataProcessor (streaming files mode)...",
+                color="cyan",
+                bold=True,
+            )
+        )
+        for config in self.sparse_features.values():
+            config.pop("_min_freq_logged", None)
+        for config in self.sequence_features.values():
+            config.pop("_min_freq_logged", None)
+        uses_robust = any(
+            cfg.get("scaler") == "robust" for cfg in self.numeric_features.values()
+        )
+        if uses_robust:
+            logger.warning(
+                "Robust scaler requires full data; loading provided files into memory. "
+                "Consider smaller chunk_size or different scaler if memory is limited."
+            )
+            frames = [read_table(p, file_type) for p in file_paths]
+            df = pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
+            return self.fit(df)
+        return self.fit_from_file_paths(
+            file_paths=file_paths, file_type=file_type, chunk_size=chunk_size
+        )
+
+    def fit_from_path(self, path: str, chunk_size: int) -> "DataProcessor":
+        """
+        Fit processor statistics by streaming files to reduce memory usage.
+
+        Args:
+            path (str): File or directory path.
+            chunk_size (int): Number of rows per chunk.
+
+        Returns:
+            DataProcessor: Fitted DataProcessor instance.
+        """
+        logger = logging.getLogger()
+        logger.info(
+            colorize(
+                "Fitting DataProcessor (streaming path mode)...",
+                color="cyan",
+                bold=True,
+            )
+        )
+        for config in self.sparse_features.values():
+            config.pop("_min_freq_logged", None)
+        for config in self.sequence_features.values():
+            config.pop("_min_freq_logged", None)
+        file_paths, file_type = resolve_file_paths(path)
+        return self.fit_from_file_paths(
+            file_paths=file_paths,
+            file_type=file_type,
+            chunk_size=chunk_size,
+        )
+
     @overload
     def transform_in_memory(
         self,
```