nextrec 0.4.33__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43):
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +10 -18
  3. nextrec/basic/asserts.py +1 -22
  4. nextrec/basic/callback.py +2 -2
  5. nextrec/basic/features.py +6 -37
  6. nextrec/basic/heads.py +13 -1
  7. nextrec/basic/layers.py +33 -123
  8. nextrec/basic/loggers.py +3 -2
  9. nextrec/basic/metrics.py +85 -4
  10. nextrec/basic/model.py +518 -7
  11. nextrec/basic/summary.py +88 -42
  12. nextrec/cli.py +117 -30
  13. nextrec/data/data_processing.py +8 -13
  14. nextrec/data/preprocessor.py +449 -844
  15. nextrec/loss/grad_norm.py +78 -76
  16. nextrec/models/multi_task/ple.py +1 -0
  17. nextrec/models/multi_task/share_bottom.py +1 -0
  18. nextrec/models/ranking/afm.py +4 -9
  19. nextrec/models/ranking/dien.py +7 -8
  20. nextrec/models/ranking/ffm.py +2 -2
  21. nextrec/models/retrieval/sdm.py +1 -2
  22. nextrec/models/sequential/hstu.py +0 -2
  23. nextrec/models/tree_base/base.py +1 -1
  24. nextrec/utils/__init__.py +2 -1
  25. nextrec/utils/config.py +1 -1
  26. nextrec/utils/console.py +1 -1
  27. nextrec/utils/onnx_utils.py +252 -0
  28. nextrec/utils/torch_utils.py +63 -56
  29. nextrec/utils/types.py +43 -0
  30. {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/METADATA +10 -4
  31. {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/RECORD +34 -42
  32. nextrec/models/multi_task/[pre]star.py +0 -192
  33. nextrec/models/representation/autorec.py +0 -0
  34. nextrec/models/representation/bpr.py +0 -0
  35. nextrec/models/representation/cl4srec.py +0 -0
  36. nextrec/models/representation/lightgcn.py +0 -0
  37. nextrec/models/representation/mf.py +0 -0
  38. nextrec/models/representation/s3rec.py +0 -0
  39. nextrec/models/sequential/sasrec.py +0 -0
  40. nextrec/utils/feature.py +0 -29
  41. {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/WHEEL +0 -0
  42. {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/entry_points.txt +0 -0
  43. {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/summary.py CHANGED
@@ -8,6 +8,7 @@ Author: Yang Zhou,zyaztec@gmail.com
8
8
 
9
9
  from __future__ import annotations
10
10
 
11
+ import inspect
11
12
  import logging
12
13
  from typing import Any, Literal
13
14
 
@@ -34,6 +35,7 @@ class SummarySet:
34
35
  scheduler_name: str | None
35
36
  scheduler_params: dict[str, Any]
36
37
  loss_config: Any
38
+ loss_params: Any
37
39
  loss_weights: Any
38
40
  grad_norm: Any
39
41
  embedding_l1_reg: float
@@ -73,7 +75,7 @@ class SummarySet:
73
75
  def build_data_summary(
74
76
  self, data: Any, data_loader: DataLoader | None, sample_key: str
75
77
  ):
76
-
78
+
77
79
  dataset = data_loader.dataset if data_loader is not None else None
78
80
 
79
81
  train_size = get_data_length(dataset)
@@ -325,6 +327,73 @@ class SummarySet:
325
327
 
326
328
  if hasattr(self, "loss_config"):
327
329
  logger.info(f"Loss Function: {self.loss_config}")
330
+
331
+ loss_params_summary: list[str] = []
332
+ loss_fn = getattr(self, "loss_fn", None)
333
+ if loss_fn is not None:
334
+ loss_modules = (
335
+ list(loss_fn) if isinstance(loss_fn, (list, tuple)) else [loss_fn]
336
+ )
337
+ loss_config = getattr(self, "loss_config", None)
338
+ if isinstance(loss_config, list):
339
+ loss_names = loss_config
340
+ elif loss_config is not None:
341
+ loss_names = [loss_config] * len(loss_modules)
342
+ else:
343
+ loss_names = [None] * len(loss_modules)
344
+
345
+ loss_params = getattr(self, "loss_params", None)
346
+ if isinstance(loss_params, list):
347
+ explicit_params = loss_params
348
+ elif isinstance(loss_params, dict):
349
+ explicit_params = [loss_params] * len(loss_modules)
350
+ else:
351
+ explicit_params = [None] * len(loss_modules)
352
+
353
+ for idx, loss_module in enumerate(loss_modules):
354
+ params: dict[str, Any] = {}
355
+ explicit = (
356
+ explicit_params[idx] if idx < len(explicit_params) else None
357
+ )
358
+ if explicit:
359
+ params.update(explicit)
360
+ try:
361
+ signature = inspect.signature(loss_module.__class__.__init__)
362
+ except (TypeError, ValueError):
363
+ signature = None
364
+ if signature is not None:
365
+ for name, param in signature.parameters.items():
366
+ if name == "self" or name.startswith("_"):
367
+ continue
368
+ if hasattr(loss_module, name):
369
+ value = getattr(loss_module, name)
370
+ if callable(value):
371
+ continue
372
+ params.setdefault(name, value)
373
+ elif (
374
+ param.default is not inspect._empty
375
+ and param.default is not None
376
+ ):
377
+ params.setdefault(name, param.default)
378
+ if not params:
379
+ continue
380
+
381
+ loss_name = loss_names[idx] if idx < len(loss_names) else None
382
+ if len(loss_modules) > 1:
383
+ header = f" [{idx}]"
384
+ if loss_name is not None:
385
+ header = f"{header} {loss_name}"
386
+ loss_params_summary.append(header)
387
+ indent = " "
388
+ else:
389
+ indent = " "
390
+ for key, value in params.items():
391
+ loss_params_summary.append(f"{indent}{key:25s}: {value}")
392
+
393
+ if loss_params_summary:
394
+ logger.info("Loss Params:")
395
+ for line in loss_params_summary:
396
+ logger.info(line)
328
397
  if hasattr(self, "loss_weights"):
329
398
  logger.info(f"Loss Weights: {self.loss_weights}")
330
399
  if hasattr(self, "grad_norm"):
@@ -355,53 +424,30 @@ class SummarySet:
355
424
  logger.info("")
356
425
  logger.info(colorize("Data Summary", color="cyan", bold=True))
357
426
  logger.info(colorize("-" * 80, color="cyan"))
358
- if self.train_data_summary:
359
- train_samples = self.train_data_summary.get("train_samples")
360
- if train_samples is not None:
361
- logger.info(format_kv("Train Samples", f"{train_samples:,}"))
362
-
363
- label_distributions = self.train_data_summary.get("label_distributions")
364
- if isinstance(label_distributions, dict):
365
- for target_name, details in label_distributions.items():
366
- lines = details.get("lines", [])
367
- logger.info(f"{target_name}:")
368
- for label, value in lines:
369
- logger.info(f" {format_kv(label, value)}")
370
-
371
- dataloader_info = self.train_data_summary.get("dataloader")
372
- if isinstance(dataloader_info, dict):
373
- logger.info("Train DataLoader:")
374
- for key in (
375
- "batch_size",
376
- "num_workers",
377
- "pin_memory",
378
- "persistent_workers",
379
- "sampler",
380
- ):
381
- if key in dataloader_info:
382
- label = key.replace("_", " ").title()
383
- logger.info(
384
- format_kv(label, dataloader_info[key], indent=2)
385
- )
386
-
387
- if self.valid_data_summary:
388
- if self.train_data_summary:
427
+ for label, data_summary in (
428
+ ("Train", self.train_data_summary),
429
+ ("Valid", self.valid_data_summary),
430
+ ):
431
+ if not data_summary:
432
+ continue
433
+ if label == "Valid" and self.train_data_summary:
389
434
  logger.info("")
390
- valid_samples = self.valid_data_summary.get("valid_samples")
391
- if valid_samples is not None:
392
- logger.info(format_kv("Valid Samples", f"{valid_samples:,}"))
435
+ sample_key = "train_samples" if label == "Train" else "valid_samples"
436
+ samples = data_summary.get(sample_key)
437
+ if samples is not None:
438
+ logger.info(format_kv(f"{label} Samples", f"{samples:,}"))
393
439
 
394
- label_distributions = self.valid_data_summary.get("label_distributions")
440
+ label_distributions = data_summary.get("label_distributions")
395
441
  if isinstance(label_distributions, dict):
396
442
  for target_name, details in label_distributions.items():
397
443
  lines = details.get("lines", [])
398
444
  logger.info(f"{target_name}:")
399
- for label, value in lines:
400
- logger.info(f" {format_kv(label, value)}")
445
+ for line_label, value in lines:
446
+ logger.info(f" {format_kv(line_label, value)}")
401
447
 
402
- dataloader_info = self.valid_data_summary.get("dataloader")
448
+ dataloader_info = data_summary.get("dataloader")
403
449
  if isinstance(dataloader_info, dict):
404
- logger.info("Valid DataLoader:")
450
+ logger.info(f"{label} DataLoader:")
405
451
  for key in (
406
452
  "batch_size",
407
453
  "num_workers",
@@ -410,7 +456,7 @@ class SummarySet:
410
456
  "sampler",
411
457
  ):
412
458
  if key in dataloader_info:
413
- label = key.replace("_", " ").title()
459
+ field_label = key.replace("_", " ").title()
414
460
  logger.info(
415
- format_kv(label, dataloader_info[key], indent=2)
461
+ format_kv(field_label, dataloader_info[key], indent=2)
416
462
  )
nextrec/cli.py CHANGED
@@ -48,7 +48,7 @@ from nextrec.utils.data import (
48
48
  read_yaml,
49
49
  resolve_file_paths,
50
50
  )
51
- from nextrec.utils.feature import to_list
51
+ from nextrec.utils.torch_utils import to_list
52
52
 
53
53
  logger = logging.getLogger(__name__)
54
54
 
@@ -156,9 +156,7 @@ def train_model(train_config_path: str) -> None:
156
156
  logger.info(
157
157
  format_kv(
158
158
  "Validation path",
159
- resolve_path(
160
- data_cfg.get("valid_path"), config_dir
161
- ),
159
+ resolve_path(data_cfg.get("valid_path"), config_dir),
162
160
  )
163
161
  )
164
162
 
@@ -247,7 +245,9 @@ def train_model(train_config_path: str) -> None:
247
245
 
248
246
  if streaming:
249
247
  if file_type is None:
250
- raise ValueError("[NextRec CLI Error] Streaming mode requires a valid file_type")
248
+ raise ValueError(
249
+ "[NextRec CLI Error] Streaming mode requires a valid file_type"
250
+ )
251
251
  processor.fit_from_files(
252
252
  file_paths=streaming_train_files or file_paths,
253
253
  file_type=file_type,
@@ -422,6 +422,49 @@ def train_model(train_config_path: str) -> None:
422
422
  note=train_cfg.get("note"),
423
423
  )
424
424
 
425
+ export_cfg = train_cfg.get("export_onnx")
426
+ if export_cfg is None:
427
+ export_cfg = cfg.get("export_onnx")
428
+ export_enabled = False
429
+ export_options: dict[str, Any] = {}
430
+ if isinstance(export_cfg, bool):
431
+ export_enabled = export_cfg
432
+ elif isinstance(export_cfg, dict):
433
+ export_options = export_cfg
434
+ export_enabled = bool(export_cfg.get("enable", False))
435
+
436
+ if export_enabled:
437
+ log_cli_section("ONNX Export")
438
+ onnx_path = None
439
+ if export_options.get("path") or export_options.get("save_path"):
440
+ logger.warning(
441
+ "[NextRec CLI Warning] export_onnx.path/save_path is deprecated; "
442
+ "ONNX will be saved to best/checkpoint paths."
443
+ )
444
+ onnx_best_path = Path(model.best_path).with_suffix(".onnx")
445
+ onnx_ckpt_path = Path(model.checkpoint_path).with_suffix(".onnx")
446
+ onnx_batch_size = export_options.get("batch_size", 1)
447
+ onnx_opset = export_options.get("opset_version", 18)
448
+ log_kv_lines(
449
+ [
450
+ ("ONNX best path", onnx_best_path),
451
+ ("ONNX checkpoint path", onnx_ckpt_path),
452
+ ("Batch size", onnx_batch_size),
453
+ ("Opset", onnx_opset),
454
+ ("Dynamic batch", False),
455
+ ]
456
+ )
457
+ model.export_onnx(
458
+ save_path=onnx_best_path,
459
+ batch_size=onnx_batch_size,
460
+ opset_version=onnx_opset,
461
+ )
462
+ model.export_onnx(
463
+ save_path=onnx_ckpt_path,
464
+ batch_size=onnx_batch_size,
465
+ opset_version=onnx_opset,
466
+ )
467
+
425
468
 
426
469
  def predict_model(predict_config_path: str) -> None:
427
470
  """
@@ -492,12 +535,16 @@ def predict_model(predict_config_path: str) -> None:
492
535
  # Load checkpoint and ensure required parameters are passed
493
536
  checkpoint_base = Path(session_dir)
494
537
  if checkpoint_base.is_dir():
538
+ best_candidates = sorted(checkpoint_base.glob("*_best.pt"))
495
539
  candidates = sorted(checkpoint_base.glob("*.pt"))
496
- if not candidates:
540
+ if best_candidates:
541
+ model_file = best_candidates[-1]
542
+ elif candidates:
543
+ model_file = candidates[-1]
544
+ else:
497
545
  raise FileNotFoundError(
498
546
  f"[NextRec CLI Error]: Unable to find model checkpoint: {checkpoint_base}"
499
547
  )
500
- model_file = candidates[-1]
501
548
  config_dir_for_features = checkpoint_base
502
549
  else:
503
550
  model_file = (
@@ -564,11 +611,32 @@ def predict_model(predict_config_path: str) -> None:
564
611
  )
565
612
 
566
613
  log_cli_section("Model")
614
+ use_onnx = bool(predict_cfg.get("use_onnx")) or bool(predict_cfg.get("onnx_path"))
615
+ onnx_path = predict_cfg.get("onnx_path") or cfg.get("onnx_path")
616
+ if onnx_path:
617
+ onnx_path = resolve_path(onnx_path, config_dir)
618
+ if use_onnx and onnx_path is None:
619
+ search_dir = (
620
+ checkpoint_base if checkpoint_base.is_dir() else checkpoint_base.parent
621
+ )
622
+ best_candidates = sorted(search_dir.glob("*_best.onnx"))
623
+ if best_candidates:
624
+ onnx_path = best_candidates[-1]
625
+ else:
626
+ candidates = sorted(search_dir.glob("*.onnx"))
627
+ if not candidates:
628
+ raise FileNotFoundError(
629
+ f"[NextRec CLI Error]: Unable to find ONNX model in {search_dir}"
630
+ )
631
+ onnx_path = candidates[-1]
632
+
567
633
  log_kv_lines(
568
634
  [
569
635
  ("Model", model.__class__.__name__),
570
636
  ("Checkpoint", model_file),
571
637
  ("Device", predict_cfg.get("device", "cpu")),
638
+ ("Use ONNX", use_onnx),
639
+ ("ONNX path", onnx_path if use_onnx else "(disabled)"),
572
640
  ]
573
641
  )
574
642
 
@@ -582,7 +650,10 @@ def predict_model(predict_config_path: str) -> None:
582
650
  )
583
651
 
584
652
  data_path = resolve_path(predict_cfg["data_path"], config_dir)
585
- batch_size = predict_cfg.get("batch_size", 512)
653
+ streaming = bool(predict_cfg.get("streaming", True))
654
+ chunk_size = int(predict_cfg.get("chunk_size", 20000))
655
+ batch_size = int(predict_cfg.get("batch_size", 512))
656
+ effective_batch_size = chunk_size if streaming else batch_size
586
657
 
587
658
  log_cli_section("Data")
588
659
  log_kv_lines(
@@ -594,18 +665,18 @@ def predict_model(predict_config_path: str) -> None:
594
665
  "source_data_format", predict_cfg.get("data_format", "auto")
595
666
  ),
596
667
  ),
597
- ("Batch size", batch_size),
598
- ("Chunk size", predict_cfg.get("chunk_size", 20000)),
599
- ("Streaming", predict_cfg.get("streaming", True)),
668
+ ("Batch size", effective_batch_size),
669
+ ("Chunk size", chunk_size),
670
+ ("Streaming", streaming),
600
671
  ]
601
672
  )
602
673
  logger.info("")
603
674
  pred_loader = rec_dataloader.create_dataloader(
604
675
  data=str(data_path),
605
- batch_size=batch_size,
676
+ batch_size=1 if streaming else batch_size,
606
677
  shuffle=False,
607
- streaming=predict_cfg.get("streaming", True),
608
- chunk_size=predict_cfg.get("chunk_size", 20000),
678
+ streaming=streaming,
679
+ chunk_size=chunk_size,
609
680
  prefetch_factor=predict_cfg.get("prefetch_factor"),
610
681
  )
611
682
 
@@ -613,27 +684,43 @@ def predict_model(predict_config_path: str) -> None:
613
684
  "save_data_format", predict_cfg.get("save_format", "csv")
614
685
  )
615
686
  pred_name = predict_cfg.get("name", "pred")
616
-
617
- save_path = checkpoint_base / "predictions" / f"{pred_name}.{save_format}"
687
+ pred_name_path = Path(pred_name)
688
+ if pred_name_path.is_absolute():
689
+ save_path = pred_name_path
690
+ if save_path.suffix == "":
691
+ save_path = save_path.with_suffix(f".{save_format}")
692
+ else:
693
+ save_path = checkpoint_base / "predictions" / f"{pred_name}.{save_format}"
618
694
 
619
695
  start = time.time()
620
696
  logger.info("")
621
- result = model.predict(
622
- data=pred_loader,
623
- batch_size=batch_size,
624
- include_ids=bool(id_columns),
625
- return_dataframe=False,
626
- save_path=str(save_path),
627
- save_format=save_format,
628
- num_workers=predict_cfg.get("num_workers", 0),
629
- )
697
+ if use_onnx:
698
+ result = model.predict_onnx(
699
+ onnx_path=onnx_path,
700
+ data=pred_loader,
701
+ batch_size=effective_batch_size,
702
+ include_ids=bool(id_columns),
703
+ return_dataframe=False,
704
+ save_path=str(save_path),
705
+ save_format=save_format,
706
+ num_workers=predict_cfg.get("num_workers", 0),
707
+ )
708
+ else:
709
+ result = model.predict(
710
+ data=pred_loader,
711
+ batch_size=effective_batch_size,
712
+ include_ids=bool(id_columns),
713
+ return_dataframe=False,
714
+ save_path=str(save_path),
715
+ save_format=save_format,
716
+ num_workers=predict_cfg.get("num_workers", 0),
717
+ )
630
718
  duration = time.time() - start
631
719
  # When return_dataframe=False, result is the actual file path
632
- output_path = (
633
- result
634
- if isinstance(result, Path)
635
- else checkpoint_base / "predictions" / save_path
636
- )
720
+ if isinstance(result, (str, Path)):
721
+ output_path = Path(result)
722
+ else:
723
+ output_path = save_path
637
724
  logger.info(f"Prediction completed, results saved to: {output_path}")
638
725
  logger.info(f"Total time: {duration:.2f} seconds")
639
726
 
@@ -12,18 +12,21 @@ from typing import Any
12
12
  import numpy as np
13
13
  import pandas as pd
14
14
  import torch
15
+ import polars as pl
16
+
15
17
 
16
18
  from nextrec.utils.torch_utils import to_numpy
17
19
 
18
20
 
19
- def get_column_data(data: dict | pd.DataFrame, name: str):
21
+ def get_column_data(data: dict | pd.DataFrame | pl.DataFrame, name: str):
20
22
 
21
23
  if isinstance(data, dict):
22
24
  return data[name] if name in data else None
23
25
  elif isinstance(data, pd.DataFrame):
24
- if name not in data.columns:
25
- return None
26
26
  return data[name].values
27
+ elif isinstance(data, pl.DataFrame):
28
+ series = data.get_column(name)
29
+ return series.to_numpy()
27
30
  else:
28
31
  raise KeyError(f"Only dict or DataFrame supported, got {type(data)}")
29
32
 
@@ -33,6 +36,8 @@ def get_data_length(data: Any) -> int | None:
33
36
  return None
34
37
  if isinstance(data, pd.DataFrame):
35
38
  return len(data)
39
+ if isinstance(data, pl.DataFrame):
40
+ return data.height
36
41
  if isinstance(data, dict):
37
42
  if not data:
38
43
  return None
@@ -92,16 +97,6 @@ def split_dict_random(data_dict, test_size=0.2, random_state=None):
92
97
  return train_dict, test_dict
93
98
 
94
99
 
95
- def split_data(
96
- df: pd.DataFrame, test_size: float = 0.2
97
- ) -> tuple[pd.DataFrame, pd.DataFrame]:
98
-
99
- split_idx = int(len(df) * (1 - test_size))
100
- train_df = df.iloc[:split_idx].reset_index(drop=True)
101
- valid_df = df.iloc[split_idx:].reset_index(drop=True)
102
- return train_df, valid_df
103
-
104
-
105
100
  def build_eval_candidates(
106
101
  df_all: pd.DataFrame,
107
102
  user_col: str,