nextrec 0.4.34__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/cli.py CHANGED
@@ -14,7 +14,7 @@ Examples:
     nextrec --mode=predict --predict_config=nextrec_cli_preset/predict_config.yaml
 
 Date: create on 06/12/2025
-Checkpoint: edit on 18/12/2025
+Checkpoint: edit on 29/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -251,7 +251,6 @@ def train_model(train_config_path: str) -> None:
     processor.fit_from_files(
         file_paths=streaming_train_files or file_paths,
         file_type=file_type,
-        chunk_size=dataloader_chunk_size,
     )
     processed = None
     df = None  # type: ignore[assignment]
@@ -422,6 +421,49 @@ def train_model(train_config_path: str) -> None:
         note=train_cfg.get("note"),
     )
 
+    export_cfg = train_cfg.get("export_onnx")
+    if export_cfg is None:
+        export_cfg = cfg.get("export_onnx")
+    export_enabled = False
+    export_options: dict[str, Any] = {}
+    if isinstance(export_cfg, bool):
+        export_enabled = export_cfg
+    elif isinstance(export_cfg, dict):
+        export_options = export_cfg
+        export_enabled = bool(export_cfg.get("enable", False))
+
+    if export_enabled:
+        log_cli_section("ONNX Export")
+        onnx_path = None
+        if export_options.get("path") or export_options.get("save_path"):
+            logger.warning(
+                "[NextRec CLI Warning] export_onnx.path/save_path is deprecated; "
+                "ONNX will be saved to best/checkpoint paths."
+            )
+        onnx_best_path = Path(model.best_path).with_suffix(".onnx")
+        onnx_ckpt_path = Path(model.checkpoint_path).with_suffix(".onnx")
+        onnx_batch_size = export_options.get("batch_size", 1)
+        onnx_opset = export_options.get("opset_version", 18)
+        log_kv_lines(
+            [
+                ("ONNX best path", onnx_best_path),
+                ("ONNX checkpoint path", onnx_ckpt_path),
+                ("Batch size", onnx_batch_size),
+                ("Opset", onnx_opset),
+                ("Dynamic batch", False),
+            ]
+        )
+        model.export_onnx(
+            save_path=onnx_best_path,
+            batch_size=onnx_batch_size,
+            opset_version=onnx_opset,
+        )
+        model.export_onnx(
+            save_path=onnx_ckpt_path,
+            batch_size=onnx_batch_size,
+            opset_version=onnx_opset,
+        )
+
 
 def predict_model(predict_config_path: str) -> None:
     """
@@ -492,12 +534,16 @@ def predict_model(predict_config_path: str) -> None:
     # Load checkpoint and ensure required parameters are passed
     checkpoint_base = Path(session_dir)
     if checkpoint_base.is_dir():
+        best_candidates = sorted(checkpoint_base.glob("*_best.pt"))
         candidates = sorted(checkpoint_base.glob("*.pt"))
-        if not candidates:
+        if best_candidates:
+            model_file = best_candidates[-1]
+        elif candidates:
+            model_file = candidates[-1]
+        else:
             raise FileNotFoundError(
                 f"[NextRec CLI Error]: Unable to find model checkpoint: {checkpoint_base}"
             )
-        model_file = candidates[-1]
         config_dir_for_features = checkpoint_base
     else:
         model_file = (
@@ -564,11 +610,32 @@ def predict_model(predict_config_path: str) -> None:
     )
 
     log_cli_section("Model")
+    use_onnx = bool(predict_cfg.get("use_onnx")) or bool(predict_cfg.get("onnx_path"))
+    onnx_path = predict_cfg.get("onnx_path") or cfg.get("onnx_path")
+    if onnx_path:
+        onnx_path = resolve_path(onnx_path, config_dir)
+    if use_onnx and onnx_path is None:
+        search_dir = (
+            checkpoint_base if checkpoint_base.is_dir() else checkpoint_base.parent
+        )
+        best_candidates = sorted(search_dir.glob("*_best.onnx"))
+        if best_candidates:
+            onnx_path = best_candidates[-1]
+        else:
+            candidates = sorted(search_dir.glob("*.onnx"))
+            if not candidates:
+                raise FileNotFoundError(
+                    f"[NextRec CLI Error]: Unable to find ONNX model in {search_dir}"
+                )
+            onnx_path = candidates[-1]
+
     log_kv_lines(
         [
             ("Model", model.__class__.__name__),
             ("Checkpoint", model_file),
             ("Device", predict_cfg.get("device", "cpu")),
+            ("Use ONNX", use_onnx),
+            ("ONNX path", onnx_path if use_onnx else "(disabled)"),
         ]
     )
 
@@ -582,7 +649,10 @@ def predict_model(predict_config_path: str) -> None:
     )
 
     data_path = resolve_path(predict_cfg["data_path"], config_dir)
-    batch_size = predict_cfg.get("batch_size", 512)
+    streaming = bool(predict_cfg.get("streaming", True))
+    chunk_size = int(predict_cfg.get("chunk_size", 20000))
+    batch_size = int(predict_cfg.get("batch_size", 512))
+    effective_batch_size = chunk_size if streaming else batch_size
 
     log_cli_section("Data")
     log_kv_lines(
@@ -594,18 +664,18 @@ def predict_model(predict_config_path: str) -> None:
                     "source_data_format", predict_cfg.get("data_format", "auto")
                 ),
             ),
-            ("Batch size", batch_size),
-            ("Chunk size", predict_cfg.get("chunk_size", 20000)),
-            ("Streaming", predict_cfg.get("streaming", True)),
+            ("Batch size", effective_batch_size),
+            ("Chunk size", chunk_size),
+            ("Streaming", streaming),
         ]
     )
     logger.info("")
     pred_loader = rec_dataloader.create_dataloader(
         data=str(data_path),
-        batch_size=batch_size,
+        batch_size=1 if streaming else batch_size,
         shuffle=False,
-        streaming=predict_cfg.get("streaming", True),
-        chunk_size=predict_cfg.get("chunk_size", 20000),
+        streaming=streaming,
+        chunk_size=chunk_size,
         prefetch_factor=predict_cfg.get("prefetch_factor"),
     )
 
@@ -623,15 +693,27 @@ def predict_model(predict_config_path: str) -> None:
 
     start = time.time()
     logger.info("")
-    result = model.predict(
-        data=pred_loader,
-        batch_size=batch_size,
-        include_ids=bool(id_columns),
-        return_dataframe=False,
-        save_path=str(save_path),
-        save_format=save_format,
-        num_workers=predict_cfg.get("num_workers", 0),
-    )
+    if use_onnx:
+        result = model.predict_onnx(
+            onnx_path=onnx_path,
+            data=pred_loader,
+            batch_size=effective_batch_size,
+            include_ids=bool(id_columns),
+            return_dataframe=False,
+            save_path=str(save_path),
+            save_format=save_format,
+            num_workers=predict_cfg.get("num_workers", 0),
+        )
+    else:
+        result = model.predict(
+            data=pred_loader,
+            batch_size=effective_batch_size,
+            include_ids=bool(id_columns),
+            return_dataframe=False,
+            save_path=str(save_path),
+            save_format=save_format,
+            num_workers=predict_cfg.get("num_workers", 0),
+        )
    duration = time.time() - start
    # When return_dataframe=False, result is the actual file path
    if isinstance(result, (str, Path)):
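Taken together with the loader changes above, prediction can now run through ONNX with a streaming loader. A minimal predict-config sketch, with keys inferred from the predict_cfg.get(...) calls in this file (values and the commented-out path are illustrative only):

    # predict_config.yaml fragment (hypothetical)
    use_onnx: true        # also implied whenever onnx_path is set
    # onnx_path: ...      # optional explicit file; otherwise the CLI globs
    #                     # *_best.onnx, then *.onnx, in the checkpoint directory
    streaming: true       # default; the loader is built with batch_size=1 so each
                          # chunk passes through whole
    chunk_size: 20000     # default; becomes the effective batch size when streaming
    batch_size: 512       # default; used only when streaming is false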
@@ -12,18 +12,21 @@ from typing import Any
 import numpy as np
 import pandas as pd
 import torch
+import polars as pl
+
 
 from nextrec.utils.torch_utils import to_numpy
 
 
-def get_column_data(data: dict | pd.DataFrame, name: str):
+def get_column_data(data: dict | pd.DataFrame | pl.DataFrame, name: str):
 
     if isinstance(data, dict):
         return data[name] if name in data else None
     elif isinstance(data, pd.DataFrame):
-        if name not in data.columns:
-            return None
         return data[name].values
+    elif isinstance(data, pl.DataFrame):
+        series = data.get_column(name)
+        return series.to_numpy()
     else:
         raise KeyError(f"Only dict or DataFrame supported, got {type(data)}")
 
@@ -33,6 +36,8 @@ def get_data_length(data: Any) -> int | None:
         return None
     if isinstance(data, pd.DataFrame):
         return len(data)
+    if isinstance(data, pl.DataFrame):
+        return data.height
     if isinstance(data, dict):
         if not data:
             return None
@@ -92,16 +97,6 @@ def split_dict_random(data_dict, test_size=0.2, random_state=None):
     return train_dict, test_dict
 
 
-def split_data(
-    df: pd.DataFrame, test_size: float = 0.2
-) -> tuple[pd.DataFrame, pd.DataFrame]:
-
-    split_idx = int(len(df) * (1 - test_size))
-    train_df = df.iloc[:split_idx].reset_index(drop=True)
-    valid_df = df.iloc[split_idx:].reset_index(drop=True)
-    return train_df, valid_df
-
-
 def build_eval_candidates(
     df_all: pd.DataFrame,
     user_col: str,