nextrec 0.4.32__py3-none-any.whl → 0.4.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/basic/summary.py CHANGED
@@ -8,6 +8,7 @@ Author: Yang Zhou,zyaztec@gmail.com
 
 from __future__ import annotations
 
+import inspect
 import logging
 from typing import Any, Literal
 
@@ -34,6 +35,7 @@ class SummarySet:
     scheduler_name: str | None
     scheduler_params: dict[str, Any]
     loss_config: Any
+    loss_params: Any
     loss_weights: Any
     grad_norm: Any
     embedding_l1_reg: float
@@ -73,7 +75,8 @@ class SummarySet:
     def build_data_summary(
         self, data: Any, data_loader: DataLoader | None, sample_key: str
     ):
-        dataset = data_loader.dataset if data_loader else None
+
+        dataset = data_loader.dataset if data_loader is not None else None
 
         train_size = get_data_length(dataset)
         if train_size is None:
@@ -324,6 +327,73 @@ class SummarySet:
 
         if hasattr(self, "loss_config"):
             logger.info(f"Loss Function: {self.loss_config}")
+
+        loss_params_summary: list[str] = []
+        loss_fn = getattr(self, "loss_fn", None)
+        if loss_fn is not None:
+            loss_modules = (
+                list(loss_fn) if isinstance(loss_fn, (list, tuple)) else [loss_fn]
+            )
+            loss_config = getattr(self, "loss_config", None)
+            if isinstance(loss_config, list):
+                loss_names = loss_config
+            elif loss_config is not None:
+                loss_names = [loss_config] * len(loss_modules)
+            else:
+                loss_names = [None] * len(loss_modules)
+
+            loss_params = getattr(self, "loss_params", None)
+            if isinstance(loss_params, list):
+                explicit_params = loss_params
+            elif isinstance(loss_params, dict):
+                explicit_params = [loss_params] * len(loss_modules)
+            else:
+                explicit_params = [None] * len(loss_modules)
+
+            for idx, loss_module in enumerate(loss_modules):
+                params: dict[str, Any] = {}
+                explicit = (
+                    explicit_params[idx] if idx < len(explicit_params) else None
+                )
+                if explicit:
+                    params.update(explicit)
+                try:
+                    signature = inspect.signature(loss_module.__class__.__init__)
+                except (TypeError, ValueError):
+                    signature = None
+                if signature is not None:
+                    for name, param in signature.parameters.items():
+                        if name == "self" or name.startswith("_"):
+                            continue
+                        if hasattr(loss_module, name):
+                            value = getattr(loss_module, name)
+                            if callable(value):
+                                continue
+                            params.setdefault(name, value)
+                        elif (
+                            param.default is not inspect._empty
+                            and param.default is not None
+                        ):
+                            params.setdefault(name, param.default)
+                if not params:
+                    continue
+
+                loss_name = loss_names[idx] if idx < len(loss_names) else None
+                if len(loss_modules) > 1:
+                    header = f" [{idx}]"
+                    if loss_name is not None:
+                        header = f"{header} {loss_name}"
+                    loss_params_summary.append(header)
+                    indent = " "
+                else:
+                    indent = " "
+                for key, value in params.items():
+                    loss_params_summary.append(f"{indent}{key:25s}: {value}")
+
+        if loss_params_summary:
+            logger.info("Loss Params:")
+            for line in loss_params_summary:
+                logger.info(line)
         if hasattr(self, "loss_weights"):
             logger.info(f"Loss Weights: {self.loss_weights}")
         if hasattr(self, "grad_norm"):
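The added block recovers each loss module's constructor arguments by inspecting `__init__` with `inspect.signature` and reading back matching instance attributes. A minimal standalone sketch of that technique; the `FocalLoss` class here is hypothetical, not a NextRec loss:

```python
import inspect


class FocalLoss:
    """Hypothetical loss module whose __init__ args are stored as attributes."""

    def __init__(self, gamma: float = 2.0, alpha: float = 0.25, reduction: str = "mean"):
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction


def summarize_loss_params(loss_module) -> dict:
    """Collect non-callable constructor params that exist as attributes on the instance."""
    params = {}
    try:
        signature = inspect.signature(loss_module.__class__.__init__)
    except (TypeError, ValueError):
        return params
    for name, param in signature.parameters.items():
        if name == "self" or name.startswith("_"):
            continue
        if hasattr(loss_module, name):
            value = getattr(loss_module, name)
            if not callable(value):
                params[name] = value
        elif param.default is not inspect.Parameter.empty and param.default is not None:
            # Fall back to the declared default when the attribute is absent
            params[name] = param.default
    return params


print(summarize_loss_params(FocalLoss(gamma=1.5)))
# {'gamma': 1.5, 'alpha': 0.25, 'reduction': 'mean'}
```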
@@ -354,53 +424,30 @@ class SummarySet:
         logger.info("")
         logger.info(colorize("Data Summary", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
-        if self.train_data_summary:
-            train_samples = self.train_data_summary.get("train_samples")
-            if train_samples is not None:
-                logger.info(format_kv("Train Samples", f"{train_samples:,}"))
-
-            label_distributions = self.train_data_summary.get("label_distributions")
-            if isinstance(label_distributions, dict):
-                for target_name, details in label_distributions.items():
-                    lines = details.get("lines", [])
-                    logger.info(f"{target_name}:")
-                    for label, value in lines:
-                        logger.info(f" {format_kv(label, value)}")
-
-            dataloader_info = self.train_data_summary.get("dataloader")
-            if isinstance(dataloader_info, dict):
-                logger.info("Train DataLoader:")
-                for key in (
-                    "batch_size",
-                    "num_workers",
-                    "pin_memory",
-                    "persistent_workers",
-                    "sampler",
-                ):
-                    if key in dataloader_info:
-                        label = key.replace("_", " ").title()
-                        logger.info(
-                            format_kv(label, dataloader_info[key], indent=2)
-                        )
-
-        if self.valid_data_summary:
-            if self.train_data_summary:
+        for label, data_summary in (
+            ("Train", self.train_data_summary),
+            ("Valid", self.valid_data_summary),
+        ):
+            if not data_summary:
+                continue
+            if label == "Valid" and self.train_data_summary:
                 logger.info("")
-            valid_samples = self.valid_data_summary.get("valid_samples")
-            if valid_samples is not None:
-                logger.info(format_kv("Valid Samples", f"{valid_samples:,}"))
+            sample_key = "train_samples" if label == "Train" else "valid_samples"
+            samples = data_summary.get(sample_key)
+            if samples is not None:
+                logger.info(format_kv(f"{label} Samples", f"{samples:,}"))
 
-            label_distributions = self.valid_data_summary.get("label_distributions")
+            label_distributions = data_summary.get("label_distributions")
             if isinstance(label_distributions, dict):
                 for target_name, details in label_distributions.items():
                     lines = details.get("lines", [])
                     logger.info(f"{target_name}:")
-                    for label, value in lines:
-                        logger.info(f" {format_kv(label, value)}")
+                    for line_label, value in lines:
+                        logger.info(f" {format_kv(line_label, value)}")
 
-            dataloader_info = self.valid_data_summary.get("dataloader")
+            dataloader_info = data_summary.get("dataloader")
             if isinstance(dataloader_info, dict):
-                logger.info("Valid DataLoader:")
+                logger.info(f"{label} DataLoader:")
                 for key in (
                     "batch_size",
                     "num_workers",
@@ -409,7 +456,7 @@ class SummarySet:
                     "sampler",
                 ):
                     if key in dataloader_info:
-                        label = key.replace("_", " ").title()
+                        field_label = key.replace("_", " ").title()
                         logger.info(
-                            format_kv(label, dataloader_info[key], indent=2)
+                            format_kv(field_label, dataloader_info[key], indent=2)
                         )
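The Data Summary change collapses the duplicated Train/Valid logging branches into one loop over labeled summaries. A small sketch of the pattern with made-up sample counts:

```python
import logging

logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)

train_data_summary = {"train_samples": 80_000}
valid_data_summary = {"valid_samples": 20_000}

# One loop handles both splits instead of two near-identical Train/Valid blocks.
for label, data_summary in (("Train", train_data_summary), ("Valid", valid_data_summary)):
    if not data_summary:
        continue
    sample_key = "train_samples" if label == "Train" else "valid_samples"
    samples = data_summary.get(sample_key)
    if samples is not None:
        logger.info(f"{label} Samples: {samples:,}")
```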
nextrec/cli.py CHANGED
@@ -48,7 +48,7 @@ from nextrec.utils.data import (
     read_yaml,
     resolve_file_paths,
 )
-from nextrec.utils.feature import to_list
+from nextrec.utils.torch_utils import to_list
 
 logger = logging.getLogger(__name__)
 
@@ -152,16 +152,17 @@ def train_model(train_config_path: str) -> None:
     )
     if data_cfg.get("valid_ratio") is not None:
         logger.info(format_kv("Valid ratio", data_cfg.get("valid_ratio")))
-    if data_cfg.get("val_path") or data_cfg.get("valid_path"):
+    if data_cfg.get("valid_path"):
         logger.info(
             format_kv(
                 "Validation path",
-                resolve_path(
-                    data_cfg.get("val_path") or data_cfg.get("valid_path"), config_dir
-                ),
+                resolve_path(data_cfg.get("valid_path"), config_dir),
             )
         )
 
+    # Determine validation dataset path early for streaming split / fitting
+    val_data_path = data_cfg.get("valid_path")
+
     if streaming:
         file_paths, file_type = resolve_file_paths(str(data_path))
         log_kv_lines(
@@ -180,6 +181,34 @@ def train_model(train_config_path: str) -> None:
             raise ValueError(f"Data file is empty: {first_file}") from exc
         df_columns = list(first_chunk.columns)
 
+        # Decide training/validation file lists before fitting processor, to avoid
+        # leaking validation statistics into preprocessing (scalers/encoders).
+        streaming_train_files = file_paths
+        streaming_valid_ratio = data_cfg.get("valid_ratio")
+        if val_data_path:
+            streaming_valid_files = None
+        elif streaming_valid_ratio is not None:
+            ratio = float(streaming_valid_ratio)
+            if not (0 < ratio < 1):
+                raise ValueError(
+                    f"[NextRec CLI Error] Valid_ratio must be between 0 and 1, current value is {streaming_valid_ratio}"
+                )
+            total_files = len(file_paths)
+            if total_files < 2:
+                raise ValueError(
+                    "[NextRec CLI Error] Must provide valid_path or increase the number of data files. At least 2 files are required for streaming validation split."
+                )
+            val_count = max(1, int(round(total_files * ratio)))
+            if val_count >= total_files:
+                val_count = total_files - 1
+            streaming_valid_files = file_paths[-val_count:]
+            streaming_train_files = file_paths[:-val_count]
+            logger.info(
+                f"Split files for streaming training and validation using valid_ratio={ratio:.3f}: training {len(streaming_train_files)} files, validation {len(streaming_valid_files)} files"
+            )
+        else:
+            streaming_valid_files = None
+
     else:
         df = read_table(data_path, data_cfg.get("format"))
         logger.info(format_kv("Rows", len(df)))
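The new streaming split happens at file granularity: `valid_ratio` is rounded against the number of files and clamped so both sides keep at least one file, and the trailing files become the validation set. A sketch of that arithmetic, with hypothetical file names:

```python
def split_streaming_files(file_paths: list[str], valid_ratio: float) -> tuple[list[str], list[str]]:
    """Split an ordered file list into train/valid partitions by ratio."""
    if not (0 < valid_ratio < 1):
        raise ValueError(f"valid_ratio must be between 0 and 1, got {valid_ratio}")
    total = len(file_paths)
    if total < 2:
        raise ValueError("At least 2 files are required for a streaming validation split")
    val_count = max(1, int(round(total * valid_ratio)))
    val_count = min(val_count, total - 1)  # keep at least one training file
    return file_paths[:-val_count], file_paths[-val_count:]


train_files, valid_files = split_streaming_files(
    [f"part-{i:03d}.parquet" for i in range(10)], valid_ratio=0.2
)
print(len(train_files), len(valid_files))  # 8 2
```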
@@ -215,7 +244,15 @@ def train_model(train_config_path: str) -> None:
     )
 
     if streaming:
-        processor.fit(str(data_path), chunk_size=dataloader_chunk_size)
+        if file_type is None:
+            raise ValueError(
+                "[NextRec CLI Error] Streaming mode requires a valid file_type"
+            )
+        processor.fit_from_files(
+            file_paths=streaming_train_files or file_paths,
+            file_type=file_type,
+            chunk_size=dataloader_chunk_size,
+        )
         processed = None
         df = None  # type: ignore[assignment]
     else:
@@ -232,34 +269,6 @@ def train_model(train_config_path: str) -> None:
         sequence_names,
     )
 
-    # Check if validation dataset path is specified
-    val_data_path = data_cfg.get("val_path") or data_cfg.get("valid_path")
-    if streaming:
-        if not file_paths:
-            file_paths, file_type = resolve_file_paths(str(data_path))
-        streaming_train_files = file_paths
-        streaming_valid_ratio = data_cfg.get("valid_ratio")
-        if val_data_path:
-            streaming_valid_files = None
-        elif streaming_valid_ratio is not None:
-            ratio = float(streaming_valid_ratio)
-            if not (0 < ratio < 1):
-                raise ValueError(
-                    f"[NextRec CLI Error] Valid_ratio must be between 0 and 1, current value is {streaming_valid_ratio}"
-                )
-            total_files = len(file_paths)
-            if total_files < 2:
-                raise ValueError(
-                    "[NextRec CLI Error] Must provide val_path or increase the number of data files. At least 2 files are required for streaming validation split."
-                )
-            val_count = max(1, int(round(total_files * ratio)))
-            if val_count >= total_files:
-                val_count = total_files - 1
-            streaming_valid_files = file_paths[-val_count:]
-            streaming_train_files = file_paths[:-val_count]
-            logger.info(
-                f"Split files for streaming training and validation using valid_ratio={ratio:.3f}: training {len(streaming_train_files)} files, validation {len(streaming_valid_files)} files"
-            )
     train_data: Dict[str, Any]
     valid_data: Dict[str, Any] | None
 
@@ -604,8 +613,13 @@ def predict_model(predict_config_path: str) -> None:
         "save_data_format", predict_cfg.get("save_format", "csv")
     )
     pred_name = predict_cfg.get("name", "pred")
-
-    save_path = checkpoint_base / "predictions" / f"{pred_name}.{save_format}"
+    pred_name_path = Path(pred_name)
+    if pred_name_path.is_absolute():
+        save_path = pred_name_path
+        if save_path.suffix == "":
+            save_path = save_path.with_suffix(f".{save_format}")
+    else:
+        save_path = checkpoint_base / "predictions" / f"{pred_name}.{save_format}"
 
     start = time.time()
     logger.info("")
@@ -620,11 +634,10 @@ def predict_model(predict_config_path: str) -> None:
     )
     duration = time.time() - start
     # When return_dataframe=False, result is the actual file path
-    output_path = (
-        result
-        if isinstance(result, Path)
-        else checkpoint_base / "predictions" / save_path
-    )
+    if isinstance(result, (str, Path)):
+        output_path = Path(result)
+    else:
+        output_path = save_path
     logger.info(f"Prediction completed, results saved to: {output_path}")
     logger.info(f"Total time: {duration:.2f} seconds")
 
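With the `predict_model` change, the predict config `name` may be either a bare name, resolved under `<checkpoint>/predictions/`, or an absolute path used as-is with the save format appended only when no suffix is present. A sketch of that resolution logic, using hypothetical paths:

```python
from pathlib import Path


def resolve_save_path(pred_name: str, save_format: str, checkpoint_base: Path) -> Path:
    """Resolve the prediction output path from a bare name or an absolute path."""
    pred_name_path = Path(pred_name)
    if pred_name_path.is_absolute():
        save_path = pred_name_path
        if save_path.suffix == "":
            save_path = save_path.with_suffix(f".{save_format}")
        return save_path
    return checkpoint_base / "predictions" / f"{pred_name}.{save_format}"


base = Path("/tmp/checkpoints/run1")
print(resolve_save_path("pred", "csv", base))                 # /tmp/checkpoints/run1/predictions/pred.csv
print(resolve_save_path("/data/out/scores", "parquet", base)) # /data/out/scores.parquet
```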
@@ -566,35 +566,16 @@ class DataProcessor(FeatureSet):
             return [str(v) for v in value]
         return [str(value)]
 
-    def fit_from_path(self, path: str, chunk_size: int) -> "DataProcessor":
-        """
-        Fit processor statistics by streaming files to reduce memory usage.
-
-        Args:
-            path (str): File or directory path.
-            chunk_size (int): Number of rows per chunk.
-
-        Returns:
-            DataProcessor: Fitted DataProcessor instance.
-        """
+    def fit_from_file_paths(
+        self, file_paths: list[str], file_type: str, chunk_size: int
+    ) -> "DataProcessor":
         logger = logging.getLogger()
-        logger.info(
-            colorize(
-                "Fitting DataProcessor (streaming path mode)...",
-                color="cyan",
-                bold=True,
-            )
-        )
-        for config in self.sparse_features.values():
-            config.pop("_min_freq_logged", None)
-        for config in self.sequence_features.values():
-            config.pop("_min_freq_logged", None)
-        file_paths, file_type = resolve_file_paths(path)
+        if not file_paths:
+            raise ValueError("[DataProcessor Error] Empty file list for streaming fit")
         if not check_streaming_support(file_type):
             raise ValueError(
                 f"[DataProcessor Error] Format '{file_type}' does not support streaming. "
-                "fit_from_path only supports streaming formats (csv, parquet) to avoid high memory usage. "
-                "Use fit(dataframe) with in-memory data or convert the data format."
+                "Streaming fit only supports csv, parquet to avoid high memory usage."
             )
 
         numeric_acc = {}
@@ -636,6 +617,7 @@ class DataProcessor(FeatureSet):
         target_values: Dict[str, set[Any]] = {
             name: set() for name in self.target_features.keys()
         }
+
         missing_features = set()
         for file_path in file_paths:
             for chunk in iter_file_chunks(file_path, file_type, chunk_size):
@@ -702,10 +684,12 @@ class DataProcessor(FeatureSet):
                 for name in self.target_features.keys() & columns:
                     vals = chunk[name].dropna().tolist()
                     target_values[name].update(vals)
+
         if missing_features:
             logger.warning(
                 f"The following configured features were not found in provided files: {sorted(missing_features)}"
             )
+
         # finalize numeric scalers
         for name, config in self.numeric_features.items():
             acc = numeric_acc[name]
@@ -895,6 +879,71 @@ class DataProcessor(FeatureSet):
         )
         return self
 
+    def fit_from_files(
+        self, file_paths: list[str], file_type: str, chunk_size: int
+    ) -> "DataProcessor":
+        """Fit processor statistics by streaming an explicit list of files.
+
+        This is useful when you want to fit statistics on training files only (exclude
+        validation files) in streaming mode.
+        """
+        logger = logging.getLogger()
+        logger.info(
+            colorize(
+                "Fitting DataProcessor (streaming files mode)...",
+                color="cyan",
+                bold=True,
+            )
+        )
+        for config in self.sparse_features.values():
+            config.pop("_min_freq_logged", None)
+        for config in self.sequence_features.values():
+            config.pop("_min_freq_logged", None)
+        uses_robust = any(
+            cfg.get("scaler") == "robust" for cfg in self.numeric_features.values()
+        )
+        if uses_robust:
+            logger.warning(
+                "Robust scaler requires full data; loading provided files into memory. "
+                "Consider smaller chunk_size or different scaler if memory is limited."
+            )
+            frames = [read_table(p, file_type) for p in file_paths]
+            df = pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
+            return self.fit(df)
+        return self.fit_from_file_paths(
+            file_paths=file_paths, file_type=file_type, chunk_size=chunk_size
+        )
+
+    def fit_from_path(self, path: str, chunk_size: int) -> "DataProcessor":
+        """
+        Fit processor statistics by streaming files to reduce memory usage.
+
+        Args:
+            path (str): File or directory path.
+            chunk_size (int): Number of rows per chunk.
+
+        Returns:
+            DataProcessor: Fitted DataProcessor instance.
+        """
+        logger = logging.getLogger()
+        logger.info(
+            colorize(
+                "Fitting DataProcessor (streaming path mode)...",
+                color="cyan",
+                bold=True,
+            )
+        )
+        for config in self.sparse_features.values():
+            config.pop("_min_freq_logged", None)
+        for config in self.sequence_features.values():
+            config.pop("_min_freq_logged", None)
+        file_paths, file_type = resolve_file_paths(path)
+        return self.fit_from_file_paths(
+            file_paths=file_paths,
+            file_type=file_type,
+            chunk_size=chunk_size,
+        )
+
     @overload
     def transform_in_memory(
         self,
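The new `fit_from_files` lets callers fit preprocessing statistics on a chosen subset of files (typically training files only, so validation data never influences scalers/encoders), while `fit_from_path` remains the path-based entry point and both delegate to `fit_from_file_paths`. A rough usage sketch; the helper and its argument values below are illustrative, and only `resolve_file_paths` (imported in cli.py above) and the `fit_from_files` signature come from this diff:

```python
from nextrec.utils.data import resolve_file_paths  # import path as used in nextrec/cli.py


def fit_on_training_files_only(processor, data_dir: str, holdout: int = 2, chunk_size: int = 50_000):
    """Fit an already-configured DataProcessor on all but the last `holdout` files."""
    file_paths, file_type = resolve_file_paths(data_dir)
    train_files = file_paths[:-holdout]  # hold the trailing files out for validation
    # Streaming fit over the training files only, so validation statistics
    # never leak into the fitted scalers/encoders.
    return processor.fit_from_files(
        file_paths=train_files,
        file_type=file_type,
        chunk_size=chunk_size,
    )
```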