nextrec 0.3.6__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +244 -113
- nextrec/basic/loggers.py +62 -43
- nextrec/basic/metrics.py +268 -119
- nextrec/basic/model.py +1373 -443
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +498 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +42 -24
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +303 -96
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +106 -40
- nextrec/models/match/dssm.py +82 -69
- nextrec/models/match/dssm_v2.py +72 -58
- nextrec/models/match/mind.py +175 -108
- nextrec/models/match/sdm.py +104 -88
- nextrec/models/match/youtube_dnn.py +73 -60
- nextrec/models/multi_task/esmm.py +53 -39
- nextrec/models/multi_task/mmoe.py +70 -47
- nextrec/models/multi_task/ple.py +107 -50
- nextrec/models/multi_task/poso.py +121 -41
- nextrec/models/multi_task/share_bottom.py +54 -38
- nextrec/models/ranking/afm.py +172 -45
- nextrec/models/ranking/autoint.py +84 -61
- nextrec/models/ranking/dcn.py +59 -42
- nextrec/models/ranking/dcn_v2.py +64 -23
- nextrec/models/ranking/deepfm.py +36 -26
- nextrec/models/ranking/dien.py +158 -102
- nextrec/models/ranking/din.py +88 -60
- nextrec/models/ranking/fibinet.py +55 -35
- nextrec/models/ranking/fm.py +32 -26
- nextrec/models/ranking/masknet.py +95 -34
- nextrec/models/ranking/pnn.py +34 -31
- nextrec/models/ranking/widedeep.py +37 -29
- nextrec/models/ranking/xdeepfm.py +63 -41
- nextrec/utils/__init__.py +61 -32
- nextrec/utils/config.py +490 -0
- nextrec/utils/device.py +52 -12
- nextrec/utils/distributed.py +141 -0
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +32 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +531 -0
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/METADATA +15 -5
- nextrec-0.4.2.dist-info/RECORD +69 -0
- nextrec-0.4.2.dist-info/entry_points.txt +2 -0
- nextrec-0.3.6.dist-info/RECORD +0 -64
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/utils/__init__.py
CHANGED
|
@@ -10,59 +10,88 @@ This package provides various utility functions organized by category:
|
|
|
10
10
|
- file_utils: File I/O operations
|
|
11
11
|
- model_utils: Model-related utilities
|
|
12
12
|
- feature_utils: Feature processing utilities
|
|
13
|
+
- config_utils: Configuration loading and processing utilities
|
|
13
14
|
|
|
14
15
|
Date: create on 13/11/2025
|
|
15
|
-
Last update:
|
|
16
|
+
Last update: 06/12/2025
|
|
16
17
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
17
18
|
"""
|
|
18
19
|
|
|
20
|
+
from . import optimizer, initializer, embedding
|
|
19
21
|
from .optimizer import get_optimizer, get_scheduler
|
|
20
22
|
from .initializer import get_initializer
|
|
21
23
|
from .embedding import get_auto_embedding_dim
|
|
22
24
|
from .device import resolve_device, get_device_info
|
|
23
25
|
from .tensor import to_tensor, stack_tensors, concat_tensors, pad_sequence_tensors
|
|
24
|
-
from .file import
|
|
26
|
+
from .file import (
|
|
27
|
+
resolve_file_paths,
|
|
28
|
+
read_table,
|
|
29
|
+
load_dataframes,
|
|
30
|
+
iter_file_chunks,
|
|
31
|
+
default_output_dir,
|
|
32
|
+
read_yaml,
|
|
33
|
+
)
|
|
25
34
|
from .model import merge_features, get_mlp_output_dim
|
|
26
35
|
from .feature import normalize_to_list
|
|
27
|
-
from . import
|
|
36
|
+
from .synthetic_data import (
|
|
37
|
+
generate_match_data,
|
|
38
|
+
generate_ranking_data,
|
|
39
|
+
generate_multitask_data,
|
|
40
|
+
generate_distributed_ranking_data,
|
|
41
|
+
)
|
|
42
|
+
from .config import (
|
|
43
|
+
resolve_path,
|
|
44
|
+
select_features,
|
|
45
|
+
register_processor_features,
|
|
46
|
+
build_feature_objects,
|
|
47
|
+
extract_feature_groups,
|
|
48
|
+
load_model_class,
|
|
49
|
+
build_model_instance,
|
|
50
|
+
)
|
|
28
51
|
|
|
29
52
|
__all__ = [
    # Optimizer & Scheduler
    "get_optimizer", "get_scheduler",
    # Initializer
    "get_initializer",
    # Embedding
    "get_auto_embedding_dim",
    # Device utilities
    "resolve_device", "get_device_info",
    # Tensor utilities
    "to_tensor", "stack_tensors", "concat_tensors", "pad_sequence_tensors",
    # File utilities
    "resolve_file_paths", "read_table", "read_yaml",
    "load_dataframes", "iter_file_chunks", "default_output_dir",
    # Model utilities
    "merge_features", "get_mlp_output_dim",
    # Feature utilities
    "normalize_to_list",
    # Config utilities
    "resolve_path", "select_features", "register_processor_features",
    "build_feature_objects", "extract_feature_groups",
    "load_model_class", "build_model_instance",
    # Synthetic data utilities
    "generate_ranking_data", "generate_match_data",
    "generate_multitask_data", "generate_distributed_ranking_data",
    # Submodules imported via `from . import optimizer, initializer, embedding`
    "optimizer", "initializer", "embedding",
]
|
nextrec/utils/config.py
ADDED
|
@@ -0,0 +1,490 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Configuration utilities for NextRec
|
|
3
|
+
|
|
4
|
+
This module provides utilities for loading and processing configuration files,
|
|
5
|
+
including feature configuration, model configuration, and training configuration.
|
|
6
|
+
|
|
7
|
+
Date: create on 06/12/2025
|
|
8
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import importlib
|
|
14
|
+
import importlib.util
|
|
15
|
+
import inspect
|
|
16
|
+
from copy import deepcopy
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
|
19
|
+
|
|
20
|
+
from nextrec.utils.feature import normalize_to_list
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
|
|
24
|
+
from nextrec.data.preprocessor import DataProcessor
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def resolve_path(path_str: str | Path, base_dir: Path) -> Path:
    """
    Turn a user-supplied path into a concrete ``Path``.

    Resolution order:
      1. ``~`` is expanded; an absolute result is returned unchanged.
      2. A relative path that already exists relative to the current working
         directory is resolved against the working directory.
      3. Any other relative path is interpreted relative to ``base_dir``.

    Args:
        path_str: Path given as a string or ``Path``.
        base_dir: Directory used for relative paths that do not exist
            relative to the working directory.

    Returns:
        The resolved path (absolute inputs are returned as-is, unresolved).
    """
    candidate = Path(path_str).expanduser()
    if not candidate.is_absolute():
        anchored = candidate if candidate.exists() else base_dir / candidate
        return anchored.resolve()
    return candidate
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def select_features(
    feature_cfg: Dict[str, Any], df_columns: List[str]
) -> Tuple[List[str], List[str], List[str]]:
    """
    Keep the configured features that exist in the data, and report the rest.

    For each of the ``dense``, ``sparse`` and ``sequence`` sections of
    ``feature_cfg`` (a mapping of feature name -> per-feature config), the
    names present in ``df_columns`` are kept in configuration order. Names
    absent from the data are printed as a notice.

    Args:
        feature_cfg: Feature configuration; a missing or ``None`` section is
            treated as empty.
        df_columns: Column names available in the dataframe.

    Returns:
        ``(dense_names, sparse_names, sequence_names)``.
    """
    available = set(df_columns)
    selected: List[List[str]] = []
    for group in ("dense", "sparse", "sequence"):
        section = feature_cfg.get(group, {}) or {}
        present: List[str] = []
        absent: List[str] = []
        for feat in section.keys():
            (present if feat in available else absent).append(feat)
        if absent:
            print(f"[feature_config] skipped missing {group} columns: {absent}")
        selected.append(present)
    return selected[0], selected[1], selected[2]
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def register_processor_features(
    processor: DataProcessor,
    feature_cfg: Dict[str, Any],
    dense_names: List[str],
    sparse_names: List[str],
    sequence_names: List[str],
) -> None:
    """
    Register features to DataProcessor based on feature configuration.

    Per-feature options are read from
    ``feature_cfg[<section>][<name>]["processor_config"]``. A section, a
    feature entry or a ``processor_config`` that is missing or ``None`` (e.g.
    a bare ``age:`` line in YAML) is treated as an empty dict, so the defaults
    below apply.

    Defaults passed to the processor:
      * dense:    ``scaler="standard"``, ``fill_na=None``
      * sparse:   ``encode_method="hash"``, ``fill_na="<UNK>"``
      * sequence: ``encode_method="hash"``, ``max_len=50``, ``pad_value=0``,
                  ``truncate="post"``, ``separator=","``
    For sparse and sequence features, ``hash_size`` falls back to
    ``vocab_size`` when it is not set.

    Args:
        processor: DataProcessor instance
        feature_cfg: Feature configuration dictionary
        dense_names: List of dense feature names
        sparse_names: List of sparse feature names
        sequence_names: List of sequence feature names
    """

    def _processor_cfg(section: str, name: str) -> Dict[str, Any]:
        # `or {}` at every level: YAML keys without a value load as None.
        entry = (feature_cfg.get(section) or {}).get(name) or {}
        return entry.get("processor_config") or {}

    for name in dense_names:
        proc_cfg = _processor_cfg("dense", name)
        processor.add_numeric_feature(
            name,
            scaler=proc_cfg.get("scaler", "standard"),
            fill_na=proc_cfg.get("fill_na"),
        )

    for name in sparse_names:
        proc_cfg = _processor_cfg("sparse", name)
        processor.add_sparse_feature(
            name,
            encode_method=proc_cfg.get("encode_method", "hash"),
            hash_size=proc_cfg.get("hash_size") or proc_cfg.get("vocab_size"),
            fill_na=proc_cfg.get("fill_na", "<UNK>"),
        )

    for name in sequence_names:
        proc_cfg = _processor_cfg("sequence", name)
        processor.add_sequence_feature(
            name,
            encode_method=proc_cfg.get("encode_method", "hash"),
            hash_size=proc_cfg.get("hash_size") or proc_cfg.get("vocab_size"),
            max_len=proc_cfg.get("max_len", 50),
            pad_value=proc_cfg.get("pad_value", 0),
            truncate=proc_cfg.get("truncate", "post"),
            separator=proc_cfg.get("separator", ","),
        )
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def build_feature_objects(
    processor: "DataProcessor",
    feature_cfg: Dict[str, Any],
    dense_names: List[str],
    sparse_names: List[str],
    sequence_names: List[str],
) -> Tuple[List["DenseFeature"], List["SparseFeature"], List["SequenceFeature"]]:
    """
    Build feature objects from processor and feature configuration.

    Per-feature settings come from ``processor_config`` and
    ``embedding_config`` under ``feature_cfg[<section>][<name>]``. A missing
    or ``None`` entry at any level is treated as an empty dict.

    Vocabulary size for sparse and sequence features is taken from the first
    truthy value of: ``embedding_config.vocab_size``,
    ``processor_config.hash_size``, the size reported by
    ``processor.get_vocab_sizes()``, and finally ``1``.

    Args:
        processor: Fitted DataProcessor instance
        feature_cfg: Feature configuration dictionary
        dense_names: List of dense feature names
        sparse_names: List of sparse feature names
        sequence_names: List of sequence feature names

    Returns:
        ``(dense_features, sparse_features, sequence_features)``, each list in
        the same order as the corresponding ``*_names`` argument.
    """
    from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature

    dense_cfg = feature_cfg.get("dense", {}) or {}
    sparse_cfg = feature_cfg.get("sparse", {}) or {}
    sequence_cfg = feature_cfg.get("sequence", {}) or {}
    vocab_sizes = processor.get_vocab_sizes()

    def _configs(section_cfg: Dict[str, Any], name: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Return ``(processor_config, embedding_config)`` for one feature."""
        # YAML keys without a value load as None, hence `or {}` at each level.
        entry = section_cfg.get(name) or {}
        return (
            entry.get("processor_config") or {},
            entry.get("embedding_config") or {},
        )

    def _vocab_size(name: str, proc_cfg: Dict[str, Any], embed_cfg: Dict[str, Any]) -> int:
        """Resolve the embedding vocabulary size for a sparse/sequence feature."""
        return int(
            embed_cfg.get("vocab_size")
            or proc_cfg.get("hash_size")
            or vocab_sizes.get(name, 0)
            or 1
        )

    dense_features: List[DenseFeature] = []
    for name in dense_names:
        _, embed_cfg = _configs(dense_cfg, name)
        dense_features.append(
            DenseFeature(
                name=name,
                embedding_dim=embed_cfg.get("embedding_dim"),
                input_dim=embed_cfg.get("input_dim", 1),
                use_embedding=embed_cfg.get("use_embedding", False),
            )
        )

    sparse_features: List[SparseFeature] = []
    for name in sparse_names:
        proc_cfg, embed_cfg = _configs(sparse_cfg, name)
        sparse_features.append(
            SparseFeature(
                name=name,
                vocab_size=_vocab_size(name, proc_cfg, embed_cfg),
                embedding_dim=embed_cfg.get("embedding_dim"),
                padding_idx=embed_cfg.get("padding_idx"),
                l1_reg=embed_cfg.get("l1_reg", 0.0),
                l2_reg=embed_cfg.get("l2_reg", 1e-5),
                trainable=embed_cfg.get("trainable", True),
            )
        )

    sequence_features: List[SequenceFeature] = []
    for name in sequence_names:
        proc_cfg, embed_cfg = _configs(sequence_cfg, name)
        sequence_features.append(
            SequenceFeature(
                name=name,
                vocab_size=_vocab_size(name, proc_cfg, embed_cfg),
                # The embedding-side max_len overrides the processor-side one.
                max_len=embed_cfg.get("max_len") or proc_cfg.get("max_len", 50),
                embedding_dim=embed_cfg.get("embedding_dim"),
                padding_idx=embed_cfg.get("padding_idx"),
                combiner=embed_cfg.get("combiner", "mean"),
                l1_reg=embed_cfg.get("l1_reg", 0.0),
                l2_reg=embed_cfg.get("l2_reg", 1e-5),
                trainable=embed_cfg.get("trainable", True),
            )
        )

    return dense_features, sparse_features, sequence_features
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def extract_feature_groups(
    feature_cfg: Dict[str, Any], df_columns: List[str]
) -> Tuple[Dict[str, List[str]], List[str]]:
    """
    Extract and validate feature groups from feature configuration.

    Each entry of ``feature_cfg["feature_groups"]`` maps a group name to one
    or more feature names. The entry is normalized with ``normalize_to_list``
    and de-duplicated, keeping the first occurrence. Names that are not
    declared under dense/sparse/sequence are reported but kept if present in
    the data. Names missing from the data are reported and dropped.

    Args:
        feature_cfg: Feature configuration dictionary
        df_columns: Available dataframe columns

    Returns:
        ``(resolved, collected)``:
          * ``resolved`` maps each group name to its feature names present in
            the data.
          * ``collected`` concatenates those lists in group order; a feature
            listed in several groups appears once per group.
        Both are empty when no ``feature_groups`` are configured.
    """
    feature_groups = feature_cfg.get("feature_groups") or {}
    if not feature_groups:
        return {}, []

    defined = set()
    for section in ("dense", "sparse", "sequence"):
        defined.update((feature_cfg.get(section) or {}).keys())
    available_cols = set(df_columns)
    resolved: Dict[str, List[str]] = {}
    collected: List[str] = []

    for group_name, names in feature_groups.items():
        # dict.fromkeys drops duplicates while keeping first-occurrence order.
        unique_names = list(dict.fromkeys(normalize_to_list(names)))
        missing_defined = [n for n in unique_names if n not in defined]
        filtered = [n for n in unique_names if n in available_cols]
        missing_cols = [n for n in unique_names if n not in available_cols]

        if missing_defined:
            print(
                f"[feature_config] feature_groups.{group_name} contains features not defined in dense/sparse/sequence: {missing_defined}"
            )

        if missing_cols:
            print(
                f"[feature_config] feature_groups.{group_name} missing data columns: {missing_cols}"
            )

        resolved[group_name] = filtered
        collected.extend(filtered)

    return resolved, collected
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def load_model_class(model_cfg: Dict[str, Any], base_dir: Path) -> type:
    """
    Load model class from configuration.

    Three configuration styles are checked, in this order:

    1. ``module_path``: a Python file, resolved with ``resolve_path`` against
       ``base_dir``. ``class_name`` selects the class. Without it, a
       ``BaseModel`` subclass is picked, preferring one defined in that file.
    2. ``model`` / ``name`` without ``module``: a builtin model looked up in
       ``nextrec.models[.ranking|.match|.multi_task|.generative].<name>``.
       The class is ``class_name`` if given. Otherwise it is
       ``camelize(name)``, or a BaseModel subclass whose name matches
       ``name`` ignoring case, ``_`` and ``-``.
    3. ``module`` / ``module_name`` plus ``class_name``: explicit import.

    An explicit ``class_name`` is never silently replaced by another class.

    Args:
        model_cfg: Model configuration dictionary
        base_dir: Base directory for resolving relative paths

    Returns:
        The model class (not an instance).

    Raises:
        FileNotFoundError: ``module_path`` does not exist.
        ImportError: the custom file cannot be loaded, or no builtin module
            yields a matching class.
        AttributeError: the requested class is missing from the module.
        ValueError: the configuration matches none of the styles above.
    """

    def camelize(name: str) -> str:
        """Convert snake_case or kebab-case to CamelCase."""
        return "".join(
            part.capitalize()
            for part in name.replace("_", " ").replace("-", " ").split()
        )

    def _compact(text: str) -> str:
        """Case- and separator-insensitive key ("deep_fm" == "DeepFM")."""
        return text.replace("_", "").replace("-", "").lower()

    def _pick_model_class(module: Any, base_cls: type, hint: str | None) -> type | None:
        """
        Choose a ``base_cls`` subclass from ``module``, or return None.

        Classes defined in ``module`` itself win over classes that were only
        imported into it (e.g. intermediate base classes). Among them, a
        class whose name matches ``hint`` (via ``_compact``) is preferred.
        """
        subclasses = [
            attr
            for attr in vars(module).values()
            if isinstance(attr, type)
            and issubclass(attr, base_cls)
            and attr is not base_cls
        ]
        local = [cls for cls in subclasses if cls.__module__ == module.__name__]
        if hint:
            wanted = _compact(hint)
            for cls in local + subclasses:
                if _compact(cls.__name__) == wanted:
                    return cls
        if local:
            return local[0]
        return subclasses[0] if subclasses else None

    module_path = model_cfg.get("module_path")
    name = model_cfg.get("model") or model_cfg.get("name")
    module_name = model_cfg.get("module") or model_cfg.get("module_name")
    class_name = model_cfg.get("class_name")

    # Case 1: Custom file path
    if module_path:
        resolved = resolve_path(module_path, base_dir)
        if not resolved.exists():
            raise FileNotFoundError(f"Custom model file not found: {resolved}")

        spec = importlib.util.spec_from_file_location(resolved.stem, resolved)
        if spec is None or spec.loader is None:
            raise ImportError(f"Unable to load custom model file: {resolved}")

        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        if class_name:
            # A configured class_name must exist; substituting another class
            # would hide typos in the configuration.
            if not hasattr(module, class_name):
                raise AttributeError(f"Class {class_name} not found in {resolved}")
            return getattr(module, class_name)

        from nextrec.basic.model import BaseModel

        found = _pick_model_class(module, BaseModel, None)
        if found is not None:
            return found
        raise AttributeError(
            f"No BaseModel subclass found in {resolved}, please provide class_name"
        )

    # Case 2: Builtin model by short name
    if name and not module_name:
        from nextrec.basic.model import BaseModel

        leaf = name.lower()
        candidates = [
            f"nextrec.models.{leaf}",
            f"nextrec.models.ranking.{leaf}",
            f"nextrec.models.match.{leaf}",
            f"nextrec.models.multi_task.{leaf}",
            f"nextrec.models.generative.{leaf}",
        ]
        cls_name = class_name or camelize(name)
        errors = []

        for mod in candidates:
            try:
                module = importlib.import_module(mod)
                if hasattr(module, cls_name):
                    return getattr(module, cls_name)
                # Only guess when no explicit class_name was given. The
                # case-insensitive match covers names camelize() cannot
                # reproduce (e.g. "deepfm" -> DeepFM).
                if not class_name:
                    found = _pick_model_class(module, BaseModel, name)
                    if found is not None:
                        return found
                errors.append(f"{mod} missing class {cls_name}")
            except Exception as exc:
                errors.append(f"{mod}: {exc}")

        raise ImportError(f"Unable to find model for model='{name}'. Tried: {errors}")

    # Case 3: Explicit module + class
    if module_name and class_name:
        module = importlib.import_module(module_name)
        if not hasattr(module, class_name):
            raise AttributeError(f"Class {class_name} not found in {module_name}")
        return getattr(module, class_name)

    raise ValueError(
        "model configuration must provide 'model' (builtin name), 'module_path' (custom path), or 'module'+'class_name'"
    )
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
def build_model_instance(
    model_cfg: Dict[str, Any],
    model_cfg_path: Path,
    dense_features: List[DenseFeature],
    sparse_features: List[SparseFeature],
    sequence_features: List[SequenceFeature],
    target: List[str],
    device: str,
) -> Any:
    """
    Build model instance from configuration and feature objects.

    Constructor keyword arguments are assembled from the sources below, in
    order. A source only fills parameters that earlier sources left unset.

    1. ``model_cfg["params"]``, excluding ``feature_groups`` and
       ``feature_bindings``.
    2. ``feature_bindings``, taken from ``model_cfg``, or from ``params``
       when ``model_cfg`` has none. A binding is one of:
       * a list/tuple/set of feature names;
       * a dict with ``features``/``feature_names`` or ``group``/``group_key``;
       * a plain group name.
    3. ``params.feature_groups`` keys that exactly equal an ``__init__``
       parameter name.
    4. ``__init__`` parameters matching a group name once suffixes such as
       ``_features``, ``_feats``, ``_list`` and ``_group`` are stripped.
    5. ``dense_features`` / ``sparse_features`` / ``sequence_features``,
       ``target``, ``device`` and ``model_cfg["session_id"]``, when the
       constructor accepts them.

    Args:
        model_cfg: Model configuration dictionary
        model_cfg_path: Path to model config file (for resolving relative paths)
        dense_features: List of dense feature objects
        sparse_features: List of sparse feature objects
        sequence_features: List of sequence feature objects
        target: List of target column names
        device: Device string (e.g., 'cpu', 'cuda:0')

    Returns:
        An instance of the class returned by ``load_model_class``.

    Raises:
        ValueError: a group or binding references a feature that is not among
            the provided feature objects.
    """
    dense_map = {f.name: f for f in dense_features}
    sparse_map = {f.name: f for f in sparse_features}
    sequence_map = {f.name: f for f in sequence_features}
    feature_pool: Dict[str, Any] = {**dense_map, **sparse_map, **sequence_map}

    model_cls = load_model_class(model_cfg, model_cfg_path.parent)
    params_cfg = deepcopy(model_cfg.get("params") or {})
    feature_groups = params_cfg.pop("feature_groups", {}) or {}
    # Popped unconditionally so a params-level "feature_bindings" is never
    # forwarded to the constructor, even when model_cfg's own bindings win.
    params_bindings = params_cfg.pop("feature_bindings", {})
    feature_bindings_cfg = model_cfg.get("feature_bindings") or params_bindings or {}
    sig_params = inspect.signature(model_cls.__init__).parameters

    def _select(names: Any, pool: Dict[str, Any], desc: str) -> List[Any]:
        """Select features from ``pool`` by name; ``None`` selects the whole pool."""
        if names is None:
            return list(pool.values())
        # A single feature may be written as a bare string; without
        # normalization it would be iterated character by character.
        name_list = normalize_to_list(names)
        missing = [n for n in name_list if n not in pool]
        if missing:
            raise ValueError(
                f"feature_groups.{desc} contains unknown features: {missing}"
            )
        return [pool[n] for n in name_list]

    def accepts(name: str) -> bool:
        """Check if parameter name is accepted by model __init__."""
        return name in sig_params

    accepts_var_kwargs = any(
        param.kind == inspect.Parameter.VAR_KEYWORD for param in sig_params.values()
    )

    def can_receive(name: str) -> bool:
        """True if ``name`` may be passed by keyword (named or via **kwargs)."""
        return accepts(name) or accepts_var_kwargs

    init_kwargs: Dict[str, Any] = dict(params_cfg)

    # Explicit bindings (model_config.feature_bindings) take priority
    for param_name, binding in feature_bindings_cfg.items():
        if param_name in init_kwargs:
            continue

        if isinstance(binding, (list, tuple, set)):
            if can_receive(param_name):
                init_kwargs[param_name] = _select(
                    list(binding), feature_pool, f"feature_bindings.{param_name}"
                )
            continue

        if isinstance(binding, dict):
            direct_features = binding.get("features") or binding.get("feature_names")
            if direct_features and can_receive(param_name):
                init_kwargs[param_name] = _select(
                    normalize_to_list(direct_features),
                    feature_pool,
                    f"feature_bindings.{param_name}",
                )
                continue
            group_key = binding.get("group") or binding.get("group_key")
        else:
            group_key = binding

        if group_key not in feature_groups:
            print(
                f"[feature_config] feature_bindings refers to unknown group '{group_key}', skipped"
            )
            continue

        if can_receive(param_name):
            init_kwargs[param_name] = _select(
                feature_groups[group_key], feature_pool, str(group_key)
            )

    # Dynamic feature groups: any key in feature_groups that matches __init__ will be filled.
    # Checking membership first avoids validating a group whose result would be discarded.
    for group_key, names in feature_groups.items():
        key = str(group_key)
        if accepts(key) and key not in init_kwargs:
            init_kwargs[key] = _select(names, feature_pool, key)

    # Generalized mapping: match params to feature_groups by normalized names
    def _normalize_group_key(key: str) -> str:
        """Normalize group key by removing common suffixes."""
        key = key.lower()
        for suffix in ("_features", "_feature", "_feats", "_feat", "_list", "_group"):
            if key.endswith(suffix):
                key = key[: -len(suffix)]
        return key

    normalized_groups: Dict[str, Any] = {}
    for gk in feature_groups:
        normalized_groups.setdefault(_normalize_group_key(str(gk)), gk)

    for param_name in sig_params:
        if param_name == "self" or param_name in init_kwargs:
            continue
        norm_param = _normalize_group_key(param_name)
        if norm_param in normalized_groups:
            group_key = normalized_groups[norm_param]
            init_kwargs[param_name] = _select(
                feature_groups[group_key], feature_pool, str(group_key)
            )

    # Feature wiring: prefer explicit groups when provided
    if accepts("dense_features"):
        init_kwargs.setdefault("dense_features", dense_features)
    if accepts("sparse_features"):
        init_kwargs.setdefault("sparse_features", sparse_features)
    if accepts("sequence_features"):
        init_kwargs.setdefault("sequence_features", sequence_features)

    if accepts("target"):
        init_kwargs.setdefault("target", target)
    if accepts("device"):
        init_kwargs.setdefault("device", device)

    # Pass session_id if model accepts it
    if "session_id" not in init_kwargs and model_cfg.get("session_id") is not None:
        if can_receive("session_id"):
            init_kwargs["session_id"] = model_cfg.get("session_id")

    return model_cls(**init_kwargs)
|
nextrec/utils/device.py
CHANGED
|
@@ -2,12 +2,13 @@
|
|
|
2
2
|
Device management utilities for NextRec
|
|
3
3
|
|
|
4
4
|
Date: create on 03/12/2025
|
|
5
|
+
Checkpoint: edit on 06/12/2025
|
|
5
6
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
6
7
|
"""
|
|
7
|
-
|
|
8
|
+
|
|
8
9
|
import torch
|
|
9
10
|
import platform
|
|
10
|
-
import
|
|
11
|
+
import logging
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
def resolve_device() -> str:
|
|
@@ -16,23 +17,62 @@ def resolve_device() -> str:
|
|
|
16
17
|
if torch.backends.mps.is_available():
|
|
17
18
|
mac_ver = platform.mac_ver()[0]
|
|
18
19
|
try:
|
|
19
|
-
major,
|
|
20
|
+
major, _ = (int(x) for x in mac_ver.split(".")[:2])
|
|
20
21
|
except Exception:
|
|
21
|
-
major,
|
|
22
|
+
major, _ = 0, 0
|
|
22
23
|
if major >= 14:
|
|
23
24
|
return "mps"
|
|
24
25
|
return "cpu"
|
|
25
26
|
|
|
27
|
+
|
|
26
28
|
def get_device_info() -> dict:
    """
    Summarize the accelerators visible to PyTorch.

    Returns:
        A dict with keys ``cuda_available``, ``cuda_device_count`` (0 when
        CUDA is unavailable), ``mps_available`` and ``current_device`` (the
        value of ``resolve_device()``). When CUDA is available it also holds
        ``cuda_device_name`` and ``cuda_capability`` for device 0.
    """
    has_cuda = torch.cuda.is_available()
    info = dict(
        cuda_available=has_cuda,
        cuda_device_count=torch.cuda.device_count() if has_cuda else 0,
        mps_available=torch.backends.mps.is_available(),
        current_device=resolve_device(),
    )
    if has_cuda:
        info.update(
            cuda_device_name=torch.cuda.get_device_name(0),
            cuda_capability=torch.cuda.get_device_capability(0),
        )
    return info
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def configure_device(
    distributed: bool, local_rank: int, base_device: torch.device | str = "cpu"
) -> torch.device:
    """
    Pick the torch device a (possibly distributed) process should run on.

    Non-distributed runs get ``base_device`` as parsed by ``torch.device``;
    availability is not checked.

    Distributed runs only use CUDA. When ``base_device`` is a CUDA device,
    ``torch.cuda.set_device(local_rank)`` is called and
    ``cuda:{local_rank}`` is returned. Any non-CUDA request returns CPU.
    CUDA being unavailable, an out-of-range ``local_rank``, or a failing
    ``set_device`` each log a warning and return CPU.

    Args:
        distributed: Whether the process is part of a distributed job.
        local_rank: Process index on this node, used as the CUDA ordinal.
        base_device: Requested device; an unparsable value falls back to CPU
            with a warning.

    Returns:
        The selected ``torch.device``.
    """
    try:
        requested = torch.device(base_device)
    except Exception:
        logging.warning(
            "[configure_device Warning] Invalid base_device, falling back to CPU."
        )
        return torch.device("cpu")

    if not distributed:
        return requested

    cpu = torch.device("cpu")
    if requested.type != "cuda":
        return cpu

    if not torch.cuda.is_available():
        logging.warning(
            "[Distributed Warning] CUDA requested but unavailable. Falling back to CPU."
        )
        return cpu

    if not (0 <= local_rank < torch.cuda.device_count()):
        logging.warning(
            f"[Distributed Warning] local_rank {local_rank} is invalid for available CUDA devices. Falling back to CPU."
        )
        return cpu

    try:
        torch.cuda.set_device(local_rank)
        return torch.device(f"cuda:{local_rank}")
    except Exception as exc:
        logging.warning(
            f"[Distributed Warning] Failed to set CUDA device for local_rank {local_rank}: {exc}. Falling back to CPU."
        )
        return cpu
|