nextrec 0.4.34__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +7 -13
- nextrec/basic/layers.py +28 -94
- nextrec/basic/model.py +512 -4
- nextrec/cli.py +101 -18
- nextrec/data/data_processing.py +8 -13
- nextrec/data/preprocessor.py +449 -846
- nextrec/models/ranking/afm.py +4 -9
- nextrec/models/ranking/dien.py +7 -8
- nextrec/models/ranking/ffm.py +2 -2
- nextrec/models/retrieval/sdm.py +1 -2
- nextrec/models/sequential/hstu.py +0 -2
- nextrec/utils/onnx_utils.py +252 -0
- {nextrec-0.4.34.dist-info → nextrec-0.5.0.dist-info}/METADATA +10 -4
- {nextrec-0.4.34.dist-info → nextrec-0.5.0.dist-info}/RECORD +18 -18
- nextrec/models/multi_task/[pre]star.py +0 -192
- {nextrec-0.4.34.dist-info → nextrec-0.5.0.dist-info}/WHEEL +0 -0
- {nextrec-0.4.34.dist-info → nextrec-0.5.0.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.34.dist-info → nextrec-0.5.0.dist-info}/licenses/LICENSE +0 -0
nextrec/cli.py
CHANGED
|
@@ -422,6 +422,49 @@ def train_model(train_config_path: str) -> None:
|
|
|
422
422
|
note=train_cfg.get("note"),
|
|
423
423
|
)
|
|
424
424
|
|
|
425
|
+
export_cfg = train_cfg.get("export_onnx")
|
|
426
|
+
if export_cfg is None:
|
|
427
|
+
export_cfg = cfg.get("export_onnx")
|
|
428
|
+
export_enabled = False
|
|
429
|
+
export_options: dict[str, Any] = {}
|
|
430
|
+
if isinstance(export_cfg, bool):
|
|
431
|
+
export_enabled = export_cfg
|
|
432
|
+
elif isinstance(export_cfg, dict):
|
|
433
|
+
export_options = export_cfg
|
|
434
|
+
export_enabled = bool(export_cfg.get("enable", False))
|
|
435
|
+
|
|
436
|
+
if export_enabled:
|
|
437
|
+
log_cli_section("ONNX Export")
|
|
438
|
+
onnx_path = None
|
|
439
|
+
if export_options.get("path") or export_options.get("save_path"):
|
|
440
|
+
logger.warning(
|
|
441
|
+
"[NextRec CLI Warning] export_onnx.path/save_path is deprecated; "
|
|
442
|
+
"ONNX will be saved to best/checkpoint paths."
|
|
443
|
+
)
|
|
444
|
+
onnx_best_path = Path(model.best_path).with_suffix(".onnx")
|
|
445
|
+
onnx_ckpt_path = Path(model.checkpoint_path).with_suffix(".onnx")
|
|
446
|
+
onnx_batch_size = export_options.get("batch_size", 1)
|
|
447
|
+
onnx_opset = export_options.get("opset_version", 18)
|
|
448
|
+
log_kv_lines(
|
|
449
|
+
[
|
|
450
|
+
("ONNX best path", onnx_best_path),
|
|
451
|
+
("ONNX checkpoint path", onnx_ckpt_path),
|
|
452
|
+
("Batch size", onnx_batch_size),
|
|
453
|
+
("Opset", onnx_opset),
|
|
454
|
+
("Dynamic batch", False),
|
|
455
|
+
]
|
|
456
|
+
)
|
|
457
|
+
model.export_onnx(
|
|
458
|
+
save_path=onnx_best_path,
|
|
459
|
+
batch_size=onnx_batch_size,
|
|
460
|
+
opset_version=onnx_opset,
|
|
461
|
+
)
|
|
462
|
+
model.export_onnx(
|
|
463
|
+
save_path=onnx_ckpt_path,
|
|
464
|
+
batch_size=onnx_batch_size,
|
|
465
|
+
opset_version=onnx_opset,
|
|
466
|
+
)
|
|
467
|
+
|
|
425
468
|
|
|
426
469
|
def predict_model(predict_config_path: str) -> None:
|
|
427
470
|
"""
|
|
@@ -492,12 +535,16 @@ def predict_model(predict_config_path: str) -> None:
|
|
|
492
535
|
# Load checkpoint and ensure required parameters are passed
|
|
493
536
|
checkpoint_base = Path(session_dir)
|
|
494
537
|
if checkpoint_base.is_dir():
|
|
538
|
+
best_candidates = sorted(checkpoint_base.glob("*_best.pt"))
|
|
495
539
|
candidates = sorted(checkpoint_base.glob("*.pt"))
|
|
496
|
-
if
|
|
540
|
+
if best_candidates:
|
|
541
|
+
model_file = best_candidates[-1]
|
|
542
|
+
elif candidates:
|
|
543
|
+
model_file = candidates[-1]
|
|
544
|
+
else:
|
|
497
545
|
raise FileNotFoundError(
|
|
498
546
|
f"[NextRec CLI Error]: Unable to find model checkpoint: {checkpoint_base}"
|
|
499
547
|
)
|
|
500
|
-
model_file = candidates[-1]
|
|
501
548
|
config_dir_for_features = checkpoint_base
|
|
502
549
|
else:
|
|
503
550
|
model_file = (
|
|
@@ -564,11 +611,32 @@ def predict_model(predict_config_path: str) -> None:
|
|
|
564
611
|
)
|
|
565
612
|
|
|
566
613
|
log_cli_section("Model")
|
|
614
|
+
use_onnx = bool(predict_cfg.get("use_onnx")) or bool(predict_cfg.get("onnx_path"))
|
|
615
|
+
onnx_path = predict_cfg.get("onnx_path") or cfg.get("onnx_path")
|
|
616
|
+
if onnx_path:
|
|
617
|
+
onnx_path = resolve_path(onnx_path, config_dir)
|
|
618
|
+
if use_onnx and onnx_path is None:
|
|
619
|
+
search_dir = (
|
|
620
|
+
checkpoint_base if checkpoint_base.is_dir() else checkpoint_base.parent
|
|
621
|
+
)
|
|
622
|
+
best_candidates = sorted(search_dir.glob("*_best.onnx"))
|
|
623
|
+
if best_candidates:
|
|
624
|
+
onnx_path = best_candidates[-1]
|
|
625
|
+
else:
|
|
626
|
+
candidates = sorted(search_dir.glob("*.onnx"))
|
|
627
|
+
if not candidates:
|
|
628
|
+
raise FileNotFoundError(
|
|
629
|
+
f"[NextRec CLI Error]: Unable to find ONNX model in {search_dir}"
|
|
630
|
+
)
|
|
631
|
+
onnx_path = candidates[-1]
|
|
632
|
+
|
|
567
633
|
log_kv_lines(
|
|
568
634
|
[
|
|
569
635
|
("Model", model.__class__.__name__),
|
|
570
636
|
("Checkpoint", model_file),
|
|
571
637
|
("Device", predict_cfg.get("device", "cpu")),
|
|
638
|
+
("Use ONNX", use_onnx),
|
|
639
|
+
("ONNX path", onnx_path if use_onnx else "(disabled)"),
|
|
572
640
|
]
|
|
573
641
|
)
|
|
574
642
|
|
|
@@ -582,7 +650,10 @@ def predict_model(predict_config_path: str) -> None:
|
|
|
582
650
|
)
|
|
583
651
|
|
|
584
652
|
data_path = resolve_path(predict_cfg["data_path"], config_dir)
|
|
585
|
-
|
|
653
|
+
streaming = bool(predict_cfg.get("streaming", True))
|
|
654
|
+
chunk_size = int(predict_cfg.get("chunk_size", 20000))
|
|
655
|
+
batch_size = int(predict_cfg.get("batch_size", 512))
|
|
656
|
+
effective_batch_size = chunk_size if streaming else batch_size
|
|
586
657
|
|
|
587
658
|
log_cli_section("Data")
|
|
588
659
|
log_kv_lines(
|
|
@@ -594,18 +665,18 @@ def predict_model(predict_config_path: str) -> None:
|
|
|
594
665
|
"source_data_format", predict_cfg.get("data_format", "auto")
|
|
595
666
|
),
|
|
596
667
|
),
|
|
597
|
-
("Batch size",
|
|
598
|
-
("Chunk size",
|
|
599
|
-
("Streaming",
|
|
668
|
+
("Batch size", effective_batch_size),
|
|
669
|
+
("Chunk size", chunk_size),
|
|
670
|
+
("Streaming", streaming),
|
|
600
671
|
]
|
|
601
672
|
)
|
|
602
673
|
logger.info("")
|
|
603
674
|
pred_loader = rec_dataloader.create_dataloader(
|
|
604
675
|
data=str(data_path),
|
|
605
|
-
batch_size=batch_size,
|
|
676
|
+
batch_size=1 if streaming else batch_size,
|
|
606
677
|
shuffle=False,
|
|
607
|
-
streaming=
|
|
608
|
-
chunk_size=
|
|
678
|
+
streaming=streaming,
|
|
679
|
+
chunk_size=chunk_size,
|
|
609
680
|
prefetch_factor=predict_cfg.get("prefetch_factor"),
|
|
610
681
|
)
|
|
611
682
|
|
|
@@ -623,15 +694,27 @@ def predict_model(predict_config_path: str) -> None:
|
|
|
623
694
|
|
|
624
695
|
start = time.time()
|
|
625
696
|
logger.info("")
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
697
|
+
if use_onnx:
|
|
698
|
+
result = model.predict_onnx(
|
|
699
|
+
onnx_path=onnx_path,
|
|
700
|
+
data=pred_loader,
|
|
701
|
+
batch_size=effective_batch_size,
|
|
702
|
+
include_ids=bool(id_columns),
|
|
703
|
+
return_dataframe=False,
|
|
704
|
+
save_path=str(save_path),
|
|
705
|
+
save_format=save_format,
|
|
706
|
+
num_workers=predict_cfg.get("num_workers", 0),
|
|
707
|
+
)
|
|
708
|
+
else:
|
|
709
|
+
result = model.predict(
|
|
710
|
+
data=pred_loader,
|
|
711
|
+
batch_size=effective_batch_size,
|
|
712
|
+
include_ids=bool(id_columns),
|
|
713
|
+
return_dataframe=False,
|
|
714
|
+
save_path=str(save_path),
|
|
715
|
+
save_format=save_format,
|
|
716
|
+
num_workers=predict_cfg.get("num_workers", 0),
|
|
717
|
+
)
|
|
635
718
|
duration = time.time() - start
|
|
636
719
|
# When return_dataframe=False, result is the actual file path
|
|
637
720
|
if isinstance(result, (str, Path)):
|
nextrec/data/data_processing.py
CHANGED
|
@@ -12,18 +12,21 @@ from typing import Any
|
|
|
12
12
|
import numpy as np
|
|
13
13
|
import pandas as pd
|
|
14
14
|
import torch
|
|
15
|
+
import polars as pl
|
|
16
|
+
|
|
15
17
|
|
|
16
18
|
from nextrec.utils.torch_utils import to_numpy
|
|
17
19
|
|
|
18
20
|
|
|
19
|
-
def get_column_data(data: dict | pd.DataFrame, name: str):
|
|
21
|
+
def get_column_data(data: dict | pd.DataFrame | pl.DataFrame, name: str):
|
|
20
22
|
|
|
21
23
|
if isinstance(data, dict):
|
|
22
24
|
return data[name] if name in data else None
|
|
23
25
|
elif isinstance(data, pd.DataFrame):
|
|
24
|
-
if name not in data.columns:
|
|
25
|
-
return None
|
|
26
26
|
return data[name].values
|
|
27
|
+
elif isinstance(data, pl.DataFrame):
|
|
28
|
+
series = data.get_column(name)
|
|
29
|
+
return series.to_numpy()
|
|
27
30
|
else:
|
|
28
31
|
raise KeyError(f"Only dict or DataFrame supported, got {type(data)}")
|
|
29
32
|
|
|
@@ -33,6 +36,8 @@ def get_data_length(data: Any) -> int | None:
|
|
|
33
36
|
return None
|
|
34
37
|
if isinstance(data, pd.DataFrame):
|
|
35
38
|
return len(data)
|
|
39
|
+
if isinstance(data, pl.DataFrame):
|
|
40
|
+
return data.height
|
|
36
41
|
if isinstance(data, dict):
|
|
37
42
|
if not data:
|
|
38
43
|
return None
|
|
@@ -92,16 +97,6 @@ def split_dict_random(data_dict, test_size=0.2, random_state=None):
|
|
|
92
97
|
return train_dict, test_dict
|
|
93
98
|
|
|
94
99
|
|
|
95
|
-
def split_data(
|
|
96
|
-
df: pd.DataFrame, test_size: float = 0.2
|
|
97
|
-
) -> tuple[pd.DataFrame, pd.DataFrame]:
|
|
98
|
-
|
|
99
|
-
split_idx = int(len(df) * (1 - test_size))
|
|
100
|
-
train_df = df.iloc[:split_idx].reset_index(drop=True)
|
|
101
|
-
valid_df = df.iloc[split_idx:].reset_index(drop=True)
|
|
102
|
-
return train_df, valid_df
|
|
103
|
-
|
|
104
|
-
|
|
105
100
|
def build_eval_candidates(
|
|
106
101
|
df_all: pd.DataFrame,
|
|
107
102
|
user_col: str,
|