nextrec 0.4.34__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +7 -13
- nextrec/basic/layers.py +28 -94
- nextrec/basic/model.py +512 -4
- nextrec/cli.py +102 -20
- nextrec/data/data_processing.py +8 -13
- nextrec/data/preprocessor.py +449 -846
- nextrec/models/ranking/afm.py +4 -9
- nextrec/models/ranking/dien.py +7 -8
- nextrec/models/ranking/ffm.py +2 -2
- nextrec/models/retrieval/sdm.py +1 -2
- nextrec/models/sequential/hstu.py +0 -2
- nextrec/utils/onnx_utils.py +252 -0
- nextrec/utils/torch_utils.py +6 -1
- {nextrec-0.4.34.dist-info → nextrec-0.5.1.dist-info}/METADATA +10 -4
- {nextrec-0.4.34.dist-info → nextrec-0.5.1.dist-info}/RECORD +19 -19
- nextrec/models/multi_task/[pre]star.py +0 -192
- {nextrec-0.4.34.dist-info → nextrec-0.5.1.dist-info}/WHEEL +0 -0
- {nextrec-0.4.34.dist-info → nextrec-0.5.1.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.34.dist-info → nextrec-0.5.1.dist-info}/licenses/LICENSE +0 -0
nextrec/cli.py
CHANGED
|
@@ -14,7 +14,7 @@ Examples:
|
|
|
14
14
|
nextrec --mode=predict --predict_config=nextrec_cli_preset/predict_config.yaml
|
|
15
15
|
|
|
16
16
|
Date: create on 06/12/2025
|
|
17
|
-
Checkpoint: edit on
|
|
17
|
+
Checkpoint: edit on 29/01/2026
|
|
18
18
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
19
19
|
"""
|
|
20
20
|
|
|
@@ -251,7 +251,6 @@ def train_model(train_config_path: str) -> None:
|
|
|
251
251
|
processor.fit_from_files(
|
|
252
252
|
file_paths=streaming_train_files or file_paths,
|
|
253
253
|
file_type=file_type,
|
|
254
|
-
chunk_size=dataloader_chunk_size,
|
|
255
254
|
)
|
|
256
255
|
processed = None
|
|
257
256
|
df = None # type: ignore[assignment]
|
|
@@ -422,6 +421,49 @@ def train_model(train_config_path: str) -> None:
|
|
|
422
421
|
note=train_cfg.get("note"),
|
|
423
422
|
)
|
|
424
423
|
|
|
424
|
+
export_cfg = train_cfg.get("export_onnx")
|
|
425
|
+
if export_cfg is None:
|
|
426
|
+
export_cfg = cfg.get("export_onnx")
|
|
427
|
+
export_enabled = False
|
|
428
|
+
export_options: dict[str, Any] = {}
|
|
429
|
+
if isinstance(export_cfg, bool):
|
|
430
|
+
export_enabled = export_cfg
|
|
431
|
+
elif isinstance(export_cfg, dict):
|
|
432
|
+
export_options = export_cfg
|
|
433
|
+
export_enabled = bool(export_cfg.get("enable", False))
|
|
434
|
+
|
|
435
|
+
if export_enabled:
|
|
436
|
+
log_cli_section("ONNX Export")
|
|
437
|
+
onnx_path = None
|
|
438
|
+
if export_options.get("path") or export_options.get("save_path"):
|
|
439
|
+
logger.warning(
|
|
440
|
+
"[NextRec CLI Warning] export_onnx.path/save_path is deprecated; "
|
|
441
|
+
"ONNX will be saved to best/checkpoint paths."
|
|
442
|
+
)
|
|
443
|
+
onnx_best_path = Path(model.best_path).with_suffix(".onnx")
|
|
444
|
+
onnx_ckpt_path = Path(model.checkpoint_path).with_suffix(".onnx")
|
|
445
|
+
onnx_batch_size = export_options.get("batch_size", 1)
|
|
446
|
+
onnx_opset = export_options.get("opset_version", 18)
|
|
447
|
+
log_kv_lines(
|
|
448
|
+
[
|
|
449
|
+
("ONNX best path", onnx_best_path),
|
|
450
|
+
("ONNX checkpoint path", onnx_ckpt_path),
|
|
451
|
+
("Batch size", onnx_batch_size),
|
|
452
|
+
("Opset", onnx_opset),
|
|
453
|
+
("Dynamic batch", False),
|
|
454
|
+
]
|
|
455
|
+
)
|
|
456
|
+
model.export_onnx(
|
|
457
|
+
save_path=onnx_best_path,
|
|
458
|
+
batch_size=onnx_batch_size,
|
|
459
|
+
opset_version=onnx_opset,
|
|
460
|
+
)
|
|
461
|
+
model.export_onnx(
|
|
462
|
+
save_path=onnx_ckpt_path,
|
|
463
|
+
batch_size=onnx_batch_size,
|
|
464
|
+
opset_version=onnx_opset,
|
|
465
|
+
)
|
|
466
|
+
|
|
425
467
|
|
|
426
468
|
def predict_model(predict_config_path: str) -> None:
|
|
427
469
|
"""
|
|
@@ -492,12 +534,16 @@ def predict_model(predict_config_path: str) -> None:
|
|
|
492
534
|
# Load checkpoint and ensure required parameters are passed
|
|
493
535
|
checkpoint_base = Path(session_dir)
|
|
494
536
|
if checkpoint_base.is_dir():
|
|
537
|
+
best_candidates = sorted(checkpoint_base.glob("*_best.pt"))
|
|
495
538
|
candidates = sorted(checkpoint_base.glob("*.pt"))
|
|
496
|
-
if
|
|
539
|
+
if best_candidates:
|
|
540
|
+
model_file = best_candidates[-1]
|
|
541
|
+
elif candidates:
|
|
542
|
+
model_file = candidates[-1]
|
|
543
|
+
else:
|
|
497
544
|
raise FileNotFoundError(
|
|
498
545
|
f"[NextRec CLI Error]: Unable to find model checkpoint: {checkpoint_base}"
|
|
499
546
|
)
|
|
500
|
-
model_file = candidates[-1]
|
|
501
547
|
config_dir_for_features = checkpoint_base
|
|
502
548
|
else:
|
|
503
549
|
model_file = (
|
|
@@ -564,11 +610,32 @@ def predict_model(predict_config_path: str) -> None:
|
|
|
564
610
|
)
|
|
565
611
|
|
|
566
612
|
log_cli_section("Model")
|
|
613
|
+
use_onnx = bool(predict_cfg.get("use_onnx")) or bool(predict_cfg.get("onnx_path"))
|
|
614
|
+
onnx_path = predict_cfg.get("onnx_path") or cfg.get("onnx_path")
|
|
615
|
+
if onnx_path:
|
|
616
|
+
onnx_path = resolve_path(onnx_path, config_dir)
|
|
617
|
+
if use_onnx and onnx_path is None:
|
|
618
|
+
search_dir = (
|
|
619
|
+
checkpoint_base if checkpoint_base.is_dir() else checkpoint_base.parent
|
|
620
|
+
)
|
|
621
|
+
best_candidates = sorted(search_dir.glob("*_best.onnx"))
|
|
622
|
+
if best_candidates:
|
|
623
|
+
onnx_path = best_candidates[-1]
|
|
624
|
+
else:
|
|
625
|
+
candidates = sorted(search_dir.glob("*.onnx"))
|
|
626
|
+
if not candidates:
|
|
627
|
+
raise FileNotFoundError(
|
|
628
|
+
f"[NextRec CLI Error]: Unable to find ONNX model in {search_dir}"
|
|
629
|
+
)
|
|
630
|
+
onnx_path = candidates[-1]
|
|
631
|
+
|
|
567
632
|
log_kv_lines(
|
|
568
633
|
[
|
|
569
634
|
("Model", model.__class__.__name__),
|
|
570
635
|
("Checkpoint", model_file),
|
|
571
636
|
("Device", predict_cfg.get("device", "cpu")),
|
|
637
|
+
("Use ONNX", use_onnx),
|
|
638
|
+
("ONNX path", onnx_path if use_onnx else "(disabled)"),
|
|
572
639
|
]
|
|
573
640
|
)
|
|
574
641
|
|
|
@@ -582,7 +649,10 @@ def predict_model(predict_config_path: str) -> None:
|
|
|
582
649
|
)
|
|
583
650
|
|
|
584
651
|
data_path = resolve_path(predict_cfg["data_path"], config_dir)
|
|
585
|
-
|
|
652
|
+
streaming = bool(predict_cfg.get("streaming", True))
|
|
653
|
+
chunk_size = int(predict_cfg.get("chunk_size", 20000))
|
|
654
|
+
batch_size = int(predict_cfg.get("batch_size", 512))
|
|
655
|
+
effective_batch_size = chunk_size if streaming else batch_size
|
|
586
656
|
|
|
587
657
|
log_cli_section("Data")
|
|
588
658
|
log_kv_lines(
|
|
@@ -594,18 +664,18 @@ def predict_model(predict_config_path: str) -> None:
|
|
|
594
664
|
"source_data_format", predict_cfg.get("data_format", "auto")
|
|
595
665
|
),
|
|
596
666
|
),
|
|
597
|
-
("Batch size",
|
|
598
|
-
("Chunk size",
|
|
599
|
-
("Streaming",
|
|
667
|
+
("Batch size", effective_batch_size),
|
|
668
|
+
("Chunk size", chunk_size),
|
|
669
|
+
("Streaming", streaming),
|
|
600
670
|
]
|
|
601
671
|
)
|
|
602
672
|
logger.info("")
|
|
603
673
|
pred_loader = rec_dataloader.create_dataloader(
|
|
604
674
|
data=str(data_path),
|
|
605
|
-
batch_size=batch_size,
|
|
675
|
+
batch_size=1 if streaming else batch_size,
|
|
606
676
|
shuffle=False,
|
|
607
|
-
streaming=
|
|
608
|
-
chunk_size=
|
|
677
|
+
streaming=streaming,
|
|
678
|
+
chunk_size=chunk_size,
|
|
609
679
|
prefetch_factor=predict_cfg.get("prefetch_factor"),
|
|
610
680
|
)
|
|
611
681
|
|
|
@@ -623,15 +693,27 @@ def predict_model(predict_config_path: str) -> None:
|
|
|
623
693
|
|
|
624
694
|
start = time.time()
|
|
625
695
|
logger.info("")
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
696
|
+
if use_onnx:
|
|
697
|
+
result = model.predict_onnx(
|
|
698
|
+
onnx_path=onnx_path,
|
|
699
|
+
data=pred_loader,
|
|
700
|
+
batch_size=effective_batch_size,
|
|
701
|
+
include_ids=bool(id_columns),
|
|
702
|
+
return_dataframe=False,
|
|
703
|
+
save_path=str(save_path),
|
|
704
|
+
save_format=save_format,
|
|
705
|
+
num_workers=predict_cfg.get("num_workers", 0),
|
|
706
|
+
)
|
|
707
|
+
else:
|
|
708
|
+
result = model.predict(
|
|
709
|
+
data=pred_loader,
|
|
710
|
+
batch_size=effective_batch_size,
|
|
711
|
+
include_ids=bool(id_columns),
|
|
712
|
+
return_dataframe=False,
|
|
713
|
+
save_path=str(save_path),
|
|
714
|
+
save_format=save_format,
|
|
715
|
+
num_workers=predict_cfg.get("num_workers", 0),
|
|
716
|
+
)
|
|
635
717
|
duration = time.time() - start
|
|
636
718
|
# When return_dataframe=False, result is the actual file path
|
|
637
719
|
if isinstance(result, (str, Path)):
|
nextrec/data/data_processing.py
CHANGED
|
@@ -12,18 +12,21 @@ from typing import Any
|
|
|
12
12
|
import numpy as np
|
|
13
13
|
import pandas as pd
|
|
14
14
|
import torch
|
|
15
|
+
import polars as pl
|
|
16
|
+
|
|
15
17
|
|
|
16
18
|
from nextrec.utils.torch_utils import to_numpy
|
|
17
19
|
|
|
18
20
|
|
|
19
|
-
def get_column_data(data: dict | pd.DataFrame, name: str):
|
|
21
|
+
def get_column_data(data: dict | pd.DataFrame | pl.DataFrame, name: str):
|
|
20
22
|
|
|
21
23
|
if isinstance(data, dict):
|
|
22
24
|
return data[name] if name in data else None
|
|
23
25
|
elif isinstance(data, pd.DataFrame):
|
|
24
|
-
if name not in data.columns:
|
|
25
|
-
return None
|
|
26
26
|
return data[name].values
|
|
27
|
+
elif isinstance(data, pl.DataFrame):
|
|
28
|
+
series = data.get_column(name)
|
|
29
|
+
return series.to_numpy()
|
|
27
30
|
else:
|
|
28
31
|
raise KeyError(f"Only dict or DataFrame supported, got {type(data)}")
|
|
29
32
|
|
|
@@ -33,6 +36,8 @@ def get_data_length(data: Any) -> int | None:
|
|
|
33
36
|
return None
|
|
34
37
|
if isinstance(data, pd.DataFrame):
|
|
35
38
|
return len(data)
|
|
39
|
+
if isinstance(data, pl.DataFrame):
|
|
40
|
+
return data.height
|
|
36
41
|
if isinstance(data, dict):
|
|
37
42
|
if not data:
|
|
38
43
|
return None
|
|
@@ -92,16 +97,6 @@ def split_dict_random(data_dict, test_size=0.2, random_state=None):
|
|
|
92
97
|
return train_dict, test_dict
|
|
93
98
|
|
|
94
99
|
|
|
95
|
-
def split_data(
|
|
96
|
-
df: pd.DataFrame, test_size: float = 0.2
|
|
97
|
-
) -> tuple[pd.DataFrame, pd.DataFrame]:
|
|
98
|
-
|
|
99
|
-
split_idx = int(len(df) * (1 - test_size))
|
|
100
|
-
train_df = df.iloc[:split_idx].reset_index(drop=True)
|
|
101
|
-
valid_df = df.iloc[split_idx:].reset_index(drop=True)
|
|
102
|
-
return train_df, valid_df
|
|
103
|
-
|
|
104
|
-
|
|
105
100
|
def build_eval_candidates(
|
|
106
101
|
df_all: pd.DataFrame,
|
|
107
102
|
user_col: str,
|