nextrec 0.4.34__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
nextrec/cli.py CHANGED
@@ -422,6 +422,49 @@ def train_model(train_config_path: str) -> None:
422
422
  note=train_cfg.get("note"),
423
423
  )
424
424
 
425
+ export_cfg = train_cfg.get("export_onnx")
426
+ if export_cfg is None:
427
+ export_cfg = cfg.get("export_onnx")
428
+ export_enabled = False
429
+ export_options: dict[str, Any] = {}
430
+ if isinstance(export_cfg, bool):
431
+ export_enabled = export_cfg
432
+ elif isinstance(export_cfg, dict):
433
+ export_options = export_cfg
434
+ export_enabled = bool(export_cfg.get("enable", False))
435
+
436
+ if export_enabled:
437
+ log_cli_section("ONNX Export")
438
+ onnx_path = None
439
+ if export_options.get("path") or export_options.get("save_path"):
440
+ logger.warning(
441
+ "[NextRec CLI Warning] export_onnx.path/save_path is deprecated; "
442
+ "ONNX will be saved to best/checkpoint paths."
443
+ )
444
+ onnx_best_path = Path(model.best_path).with_suffix(".onnx")
445
+ onnx_ckpt_path = Path(model.checkpoint_path).with_suffix(".onnx")
446
+ onnx_batch_size = export_options.get("batch_size", 1)
447
+ onnx_opset = export_options.get("opset_version", 18)
448
+ log_kv_lines(
449
+ [
450
+ ("ONNX best path", onnx_best_path),
451
+ ("ONNX checkpoint path", onnx_ckpt_path),
452
+ ("Batch size", onnx_batch_size),
453
+ ("Opset", onnx_opset),
454
+ ("Dynamic batch", False),
455
+ ]
456
+ )
457
+ model.export_onnx(
458
+ save_path=onnx_best_path,
459
+ batch_size=onnx_batch_size,
460
+ opset_version=onnx_opset,
461
+ )
462
+ model.export_onnx(
463
+ save_path=onnx_ckpt_path,
464
+ batch_size=onnx_batch_size,
465
+ opset_version=onnx_opset,
466
+ )
467
+
425
468
 
426
469
  def predict_model(predict_config_path: str) -> None:
427
470
  """
@@ -492,12 +535,16 @@ def predict_model(predict_config_path: str) -> None:
492
535
  # Load checkpoint and ensure required parameters are passed
493
536
  checkpoint_base = Path(session_dir)
494
537
  if checkpoint_base.is_dir():
538
+ best_candidates = sorted(checkpoint_base.glob("*_best.pt"))
495
539
  candidates = sorted(checkpoint_base.glob("*.pt"))
496
- if not candidates:
540
+ if best_candidates:
541
+ model_file = best_candidates[-1]
542
+ elif candidates:
543
+ model_file = candidates[-1]
544
+ else:
497
545
  raise FileNotFoundError(
498
546
  f"[NextRec CLI Error]: Unable to find model checkpoint: {checkpoint_base}"
499
547
  )
500
- model_file = candidates[-1]
501
548
  config_dir_for_features = checkpoint_base
502
549
  else:
503
550
  model_file = (
@@ -564,11 +611,32 @@ def predict_model(predict_config_path: str) -> None:
564
611
  )
565
612
 
566
613
  log_cli_section("Model")
614
+ use_onnx = bool(predict_cfg.get("use_onnx")) or bool(predict_cfg.get("onnx_path"))
615
+ onnx_path = predict_cfg.get("onnx_path") or cfg.get("onnx_path")
616
+ if onnx_path:
617
+ onnx_path = resolve_path(onnx_path, config_dir)
618
+ if use_onnx and onnx_path is None:
619
+ search_dir = (
620
+ checkpoint_base if checkpoint_base.is_dir() else checkpoint_base.parent
621
+ )
622
+ best_candidates = sorted(search_dir.glob("*_best.onnx"))
623
+ if best_candidates:
624
+ onnx_path = best_candidates[-1]
625
+ else:
626
+ candidates = sorted(search_dir.glob("*.onnx"))
627
+ if not candidates:
628
+ raise FileNotFoundError(
629
+ f"[NextRec CLI Error]: Unable to find ONNX model in {search_dir}"
630
+ )
631
+ onnx_path = candidates[-1]
632
+
567
633
  log_kv_lines(
568
634
  [
569
635
  ("Model", model.__class__.__name__),
570
636
  ("Checkpoint", model_file),
571
637
  ("Device", predict_cfg.get("device", "cpu")),
638
+ ("Use ONNX", use_onnx),
639
+ ("ONNX path", onnx_path if use_onnx else "(disabled)"),
572
640
  ]
573
641
  )
574
642
 
@@ -582,7 +650,10 @@ def predict_model(predict_config_path: str) -> None:
582
650
  )
583
651
 
584
652
  data_path = resolve_path(predict_cfg["data_path"], config_dir)
585
- batch_size = predict_cfg.get("batch_size", 512)
653
+ streaming = bool(predict_cfg.get("streaming", True))
654
+ chunk_size = int(predict_cfg.get("chunk_size", 20000))
655
+ batch_size = int(predict_cfg.get("batch_size", 512))
656
+ effective_batch_size = chunk_size if streaming else batch_size
586
657
 
587
658
  log_cli_section("Data")
588
659
  log_kv_lines(
@@ -594,18 +665,18 @@ def predict_model(predict_config_path: str) -> None:
594
665
  "source_data_format", predict_cfg.get("data_format", "auto")
595
666
  ),
596
667
  ),
597
- ("Batch size", batch_size),
598
- ("Chunk size", predict_cfg.get("chunk_size", 20000)),
599
- ("Streaming", predict_cfg.get("streaming", True)),
668
+ ("Batch size", effective_batch_size),
669
+ ("Chunk size", chunk_size),
670
+ ("Streaming", streaming),
600
671
  ]
601
672
  )
602
673
  logger.info("")
603
674
  pred_loader = rec_dataloader.create_dataloader(
604
675
  data=str(data_path),
605
- batch_size=batch_size,
676
+ batch_size=1 if streaming else batch_size,
606
677
  shuffle=False,
607
- streaming=predict_cfg.get("streaming", True),
608
- chunk_size=predict_cfg.get("chunk_size", 20000),
678
+ streaming=streaming,
679
+ chunk_size=chunk_size,
609
680
  prefetch_factor=predict_cfg.get("prefetch_factor"),
610
681
  )
611
682
 
@@ -623,15 +694,27 @@ def predict_model(predict_config_path: str) -> None:
623
694
 
624
695
  start = time.time()
625
696
  logger.info("")
626
- result = model.predict(
627
- data=pred_loader,
628
- batch_size=batch_size,
629
- include_ids=bool(id_columns),
630
- return_dataframe=False,
631
- save_path=str(save_path),
632
- save_format=save_format,
633
- num_workers=predict_cfg.get("num_workers", 0),
634
- )
697
+ if use_onnx:
698
+ result = model.predict_onnx(
699
+ onnx_path=onnx_path,
700
+ data=pred_loader,
701
+ batch_size=effective_batch_size,
702
+ include_ids=bool(id_columns),
703
+ return_dataframe=False,
704
+ save_path=str(save_path),
705
+ save_format=save_format,
706
+ num_workers=predict_cfg.get("num_workers", 0),
707
+ )
708
+ else:
709
+ result = model.predict(
710
+ data=pred_loader,
711
+ batch_size=effective_batch_size,
712
+ include_ids=bool(id_columns),
713
+ return_dataframe=False,
714
+ save_path=str(save_path),
715
+ save_format=save_format,
716
+ num_workers=predict_cfg.get("num_workers", 0),
717
+ )
635
718
  duration = time.time() - start
636
719
  # When return_dataframe=False, result is the actual file path
637
720
  if isinstance(result, (str, Path)):
@@ -12,18 +12,21 @@ from typing import Any
12
12
  import numpy as np
13
13
  import pandas as pd
14
14
  import torch
15
+ import polars as pl
16
+
15
17
 
16
18
  from nextrec.utils.torch_utils import to_numpy
17
19
 
18
20
 
19
- def get_column_data(data: dict | pd.DataFrame, name: str):
21
+ def get_column_data(data: dict | pd.DataFrame | pl.DataFrame, name: str):
20
22
 
21
23
  if isinstance(data, dict):
22
24
  return data[name] if name in data else None
23
25
  elif isinstance(data, pd.DataFrame):
24
- if name not in data.columns:
25
- return None
26
26
  return data[name].values
27
+ elif isinstance(data, pl.DataFrame):
28
+ series = data.get_column(name)
29
+ return series.to_numpy()
27
30
  else:
28
31
  raise KeyError(f"Only dict or DataFrame supported, got {type(data)}")
29
32
 
@@ -33,6 +36,8 @@ def get_data_length(data: Any) -> int | None:
33
36
  return None
34
37
  if isinstance(data, pd.DataFrame):
35
38
  return len(data)
39
+ if isinstance(data, pl.DataFrame):
40
+ return data.height
36
41
  if isinstance(data, dict):
37
42
  if not data:
38
43
  return None
@@ -92,16 +97,6 @@ def split_dict_random(data_dict, test_size=0.2, random_state=None):
92
97
  return train_dict, test_dict
93
98
 
94
99
 
95
- def split_data(
96
- df: pd.DataFrame, test_size: float = 0.2
97
- ) -> tuple[pd.DataFrame, pd.DataFrame]:
98
-
99
- split_idx = int(len(df) * (1 - test_size))
100
- train_df = df.iloc[:split_idx].reset_index(drop=True)
101
- valid_df = df.iloc[split_idx:].reset_index(drop=True)
102
- return train_df, valid_df
103
-
104
-
105
100
  def build_eval_candidates(
106
101
  df_all: pd.DataFrame,
107
102
  user_col: str,