nextrec 0.4.20__py3-none-any.whl → 0.4.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +9 -4
- nextrec/basic/callback.py +39 -87
- nextrec/basic/features.py +149 -28
- nextrec/basic/heads.py +3 -1
- nextrec/basic/layers.py +375 -94
- nextrec/basic/loggers.py +236 -39
- nextrec/basic/model.py +259 -326
- nextrec/basic/session.py +2 -2
- nextrec/basic/summary.py +323 -0
- nextrec/cli.py +3 -3
- nextrec/data/data_processing.py +45 -1
- nextrec/data/dataloader.py +2 -2
- nextrec/data/preprocessor.py +2 -2
- nextrec/loss/__init__.py +0 -4
- nextrec/loss/grad_norm.py +3 -3
- nextrec/models/multi_task/esmm.py +4 -6
- nextrec/models/multi_task/mmoe.py +4 -6
- nextrec/models/multi_task/ple.py +6 -8
- nextrec/models/multi_task/poso.py +5 -7
- nextrec/models/multi_task/share_bottom.py +6 -8
- nextrec/models/ranking/afm.py +4 -6
- nextrec/models/ranking/autoint.py +4 -6
- nextrec/models/ranking/dcn.py +8 -7
- nextrec/models/ranking/dcn_v2.py +4 -6
- nextrec/models/ranking/deepfm.py +5 -7
- nextrec/models/ranking/dien.py +8 -7
- nextrec/models/ranking/din.py +8 -7
- nextrec/models/ranking/eulernet.py +5 -7
- nextrec/models/ranking/ffm.py +5 -7
- nextrec/models/ranking/fibinet.py +4 -6
- nextrec/models/ranking/fm.py +4 -6
- nextrec/models/ranking/lr.py +4 -6
- nextrec/models/ranking/masknet.py +8 -9
- nextrec/models/ranking/pnn.py +4 -6
- nextrec/models/ranking/widedeep.py +5 -7
- nextrec/models/ranking/xdeepfm.py +8 -7
- nextrec/models/retrieval/dssm.py +4 -10
- nextrec/models/retrieval/dssm_v2.py +0 -6
- nextrec/models/retrieval/mind.py +4 -10
- nextrec/models/retrieval/sdm.py +4 -10
- nextrec/models/retrieval/youtube_dnn.py +4 -10
- nextrec/models/sequential/hstu.py +1 -3
- nextrec/utils/__init__.py +17 -15
- nextrec/utils/config.py +15 -5
- nextrec/utils/console.py +2 -2
- nextrec/utils/feature.py +2 -2
- nextrec/{loss/loss_utils.py → utils/loss.py} +21 -36
- nextrec/utils/torch_utils.py +57 -112
- nextrec/utils/types.py +63 -0
- {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/METADATA +8 -6
- nextrec-0.4.22.dist-info/RECORD +81 -0
- nextrec-0.4.20.dist-info/RECORD +0 -79
- {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/WHEEL +0 -0
- {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/licenses/LICENSE +0 -0
|
@@ -50,12 +50,10 @@ class YoutubeDNN(BaseMatchModel):
|
|
|
50
50
|
num_negative_samples: int = 100,
|
|
51
51
|
temperature: float = 1.0,
|
|
52
52
|
similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
dense_l2_reg: float = 0.0,
|
|
58
|
-
early_stop_patience: int = 20,
|
|
53
|
+
embedding_l1_reg=0.0,
|
|
54
|
+
dense_l1_reg=0.0,
|
|
55
|
+
embedding_l2_reg=0.0,
|
|
56
|
+
dense_l2_reg=0.0,
|
|
59
57
|
optimizer: str | torch.optim.Optimizer = "adam",
|
|
60
58
|
optimizer_params: dict | None = None,
|
|
61
59
|
scheduler: (
|
|
@@ -81,12 +79,10 @@ class YoutubeDNN(BaseMatchModel):
|
|
|
81
79
|
num_negative_samples=num_negative_samples,
|
|
82
80
|
temperature=temperature,
|
|
83
81
|
similarity_metric=similarity_metric,
|
|
84
|
-
device=device,
|
|
85
82
|
embedding_l1_reg=embedding_l1_reg,
|
|
86
83
|
dense_l1_reg=dense_l1_reg,
|
|
87
84
|
embedding_l2_reg=embedding_l2_reg,
|
|
88
85
|
dense_l2_reg=dense_l2_reg,
|
|
89
|
-
early_stop_patience=early_stop_patience,
|
|
90
86
|
**kwargs,
|
|
91
87
|
)
|
|
92
88
|
|
|
@@ -169,8 +165,6 @@ class YoutubeDNN(BaseMatchModel):
|
|
|
169
165
|
loss_params=loss_params,
|
|
170
166
|
)
|
|
171
167
|
|
|
172
|
-
self.to(device)
|
|
173
|
-
|
|
174
168
|
def user_tower(self, user_input: dict) -> torch.Tensor:
|
|
175
169
|
"""
|
|
176
170
|
User tower to encode historical behavior sequences and user features.
|
|
@@ -332,7 +332,6 @@ class HSTU(BaseModel):
|
|
|
332
332
|
dense_l1_reg: float = 0.0,
|
|
333
333
|
embedding_l2_reg: float = 0.0,
|
|
334
334
|
dense_l2_reg: float = 0.0,
|
|
335
|
-
device: str = "cpu",
|
|
336
335
|
**kwargs,
|
|
337
336
|
):
|
|
338
337
|
raise NotImplementedError(
|
|
@@ -348,7 +347,7 @@ class HSTU(BaseModel):
|
|
|
348
347
|
)[0]
|
|
349
348
|
|
|
350
349
|
self.hidden_dim = hidden_dim or max(
|
|
351
|
-
int(
|
|
350
|
+
int(self.item_history_feature.embedding_dim or 0), 32
|
|
352
351
|
)
|
|
353
352
|
# Make hidden_dim divisible by num_heads
|
|
354
353
|
if self.hidden_dim % num_heads != 0:
|
|
@@ -368,7 +367,6 @@ class HSTU(BaseModel):
|
|
|
368
367
|
sequence_features=sequence_features,
|
|
369
368
|
target=target,
|
|
370
369
|
task=task or self.default_task,
|
|
371
|
-
device=device,
|
|
372
370
|
embedding_l1_reg=embedding_l1_reg,
|
|
373
371
|
dense_l1_reg=dense_l1_reg,
|
|
374
372
|
embedding_l2_reg=embedding_l2_reg,
|
nextrec/utils/__init__.py
CHANGED
|
@@ -6,7 +6,7 @@ Last update: 19/12/2025
|
|
|
6
6
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
|
-
from . import console, data, embedding, torch_utils
|
|
9
|
+
from . import console, data, embedding, loss, torch_utils
|
|
10
10
|
from .config import (
|
|
11
11
|
build_feature_objects,
|
|
12
12
|
build_model_instance,
|
|
@@ -14,6 +14,7 @@ from .config import (
|
|
|
14
14
|
load_model_class,
|
|
15
15
|
register_processor_features,
|
|
16
16
|
resolve_path,
|
|
17
|
+
safe_value,
|
|
17
18
|
select_features,
|
|
18
19
|
)
|
|
19
20
|
from .console import (
|
|
@@ -35,23 +36,20 @@ from .data import (
|
|
|
35
36
|
resolve_file_paths,
|
|
36
37
|
)
|
|
37
38
|
from .embedding import get_auto_embedding_dim
|
|
38
|
-
from .feature import
|
|
39
|
+
from .feature import to_list
|
|
39
40
|
from .model import compute_pair_scores, get_mlp_output_dim, merge_features
|
|
41
|
+
from .loss import normalize_task_loss
|
|
40
42
|
from .torch_utils import (
|
|
41
43
|
add_distributed_sampler,
|
|
42
|
-
|
|
43
|
-
configure_device,
|
|
44
|
+
get_device,
|
|
44
45
|
gather_numpy,
|
|
45
|
-
get_device_info,
|
|
46
46
|
get_initializer,
|
|
47
47
|
get_optimizer,
|
|
48
48
|
get_scheduler,
|
|
49
49
|
init_process_group,
|
|
50
|
-
pad_sequence_tensors,
|
|
51
|
-
resolve_device,
|
|
52
|
-
stack_tensors,
|
|
53
50
|
to_tensor,
|
|
54
51
|
)
|
|
52
|
+
from .types import LossName, OptimizerName, SchedulerName, ActivationName
|
|
55
53
|
|
|
56
54
|
__all__ = [
|
|
57
55
|
# Console utilities
|
|
@@ -67,17 +65,12 @@ __all__ = [
|
|
|
67
65
|
# Embedding utilities
|
|
68
66
|
"get_auto_embedding_dim",
|
|
69
67
|
# Device utilities (torch utils)
|
|
70
|
-
"
|
|
71
|
-
"get_device_info",
|
|
72
|
-
"configure_device",
|
|
68
|
+
"get_device",
|
|
73
69
|
"init_process_group",
|
|
74
70
|
"gather_numpy",
|
|
75
71
|
"add_distributed_sampler",
|
|
76
72
|
# Tensor utilities
|
|
77
73
|
"to_tensor",
|
|
78
|
-
"stack_tensors",
|
|
79
|
-
"concat_tensors",
|
|
80
|
-
"pad_sequence_tensors",
|
|
81
74
|
# Data utilities
|
|
82
75
|
"resolve_file_paths",
|
|
83
76
|
"read_table",
|
|
@@ -89,10 +82,13 @@ __all__ = [
|
|
|
89
82
|
"merge_features",
|
|
90
83
|
"get_mlp_output_dim",
|
|
91
84
|
"compute_pair_scores",
|
|
85
|
+
# Loss utilities
|
|
86
|
+
"normalize_task_loss",
|
|
92
87
|
# Feature utilities
|
|
93
|
-
"
|
|
88
|
+
"to_list",
|
|
94
89
|
# Config utilities
|
|
95
90
|
"resolve_path",
|
|
91
|
+
"safe_value",
|
|
96
92
|
"register_processor_features",
|
|
97
93
|
"build_feature_objects",
|
|
98
94
|
"extract_feature_groups",
|
|
@@ -108,5 +104,11 @@ __all__ = [
|
|
|
108
104
|
"console",
|
|
109
105
|
"data",
|
|
110
106
|
"embedding",
|
|
107
|
+
"loss",
|
|
111
108
|
"torch_utils",
|
|
109
|
+
# Type aliases
|
|
110
|
+
"OptimizerName",
|
|
111
|
+
"SchedulerName",
|
|
112
|
+
"LossName",
|
|
113
|
+
"ActivationName",
|
|
112
114
|
]
|
nextrec/utils/config.py
CHANGED
|
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
|
|
21
21
|
import pandas as pd
|
|
22
22
|
import torch
|
|
23
23
|
|
|
24
|
-
from nextrec.utils.feature import
|
|
24
|
+
from nextrec.utils.feature import to_list
|
|
25
25
|
|
|
26
26
|
if TYPE_CHECKING:
|
|
27
27
|
from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
@@ -52,6 +52,16 @@ def resolve_path(
|
|
|
52
52
|
)
|
|
53
53
|
|
|
54
54
|
|
|
55
|
+
def safe_value(value: Any):
|
|
56
|
+
if isinstance(value, (str, int, float, bool)) or value is None:
|
|
57
|
+
return value
|
|
58
|
+
if isinstance(value, dict):
|
|
59
|
+
return {str(k): safe_value(v) for k, v in value.items()}
|
|
60
|
+
if isinstance(value, (list, tuple)):
|
|
61
|
+
return [safe_value(v) for v in value]
|
|
62
|
+
return str(value)
|
|
63
|
+
|
|
64
|
+
|
|
55
65
|
def select_features(
|
|
56
66
|
feature_cfg: Dict[str, Any], df_columns: List[str]
|
|
57
67
|
) -> Tuple[List[str], List[str], List[str]]:
|
|
@@ -152,9 +162,9 @@ def build_feature_objects(
|
|
|
152
162
|
dense_features.append(
|
|
153
163
|
DenseFeature(
|
|
154
164
|
name=name,
|
|
155
|
-
|
|
165
|
+
proj_dim=embed_cfg.get("proj_dim"),
|
|
156
166
|
input_dim=embed_cfg.get("input_dim", 1),
|
|
157
|
-
|
|
167
|
+
use_projection=embed_cfg.get("use_projection", False),
|
|
158
168
|
)
|
|
159
169
|
)
|
|
160
170
|
|
|
@@ -239,7 +249,7 @@ def extract_feature_groups(
|
|
|
239
249
|
collected: List[str] = []
|
|
240
250
|
|
|
241
251
|
for group_name, names in feature_groups.items():
|
|
242
|
-
name_list =
|
|
252
|
+
name_list = to_list(names)
|
|
243
253
|
filtered = []
|
|
244
254
|
missing_defined = [n for n in name_list if n not in defined]
|
|
245
255
|
missing_cols = [n for n in name_list if n not in available_cols]
|
|
@@ -441,7 +451,7 @@ def build_model_instance(
|
|
|
441
451
|
direct_features = binding.get("features") or binding.get("feature_names")
|
|
442
452
|
if direct_features and (accepts(param_name) or accepts_var_kwargs):
|
|
443
453
|
init_kwargs[param_name] = _select(
|
|
444
|
-
|
|
454
|
+
to_list(direct_features),
|
|
445
455
|
feature_pool,
|
|
446
456
|
f"feature_bindings.{param_name}",
|
|
447
457
|
)
|
nextrec/utils/console.py
CHANGED
|
@@ -36,7 +36,7 @@ from rich.progress import (
|
|
|
36
36
|
from rich.table import Table
|
|
37
37
|
from rich.text import Text
|
|
38
38
|
|
|
39
|
-
from nextrec.utils.feature import as_float,
|
|
39
|
+
from nextrec.utils.feature import as_float, to_list
|
|
40
40
|
|
|
41
41
|
T = TypeVar("T")
|
|
42
42
|
|
|
@@ -283,7 +283,7 @@ def display_metrics_table(
|
|
|
283
283
|
if not is_main_process:
|
|
284
284
|
return
|
|
285
285
|
|
|
286
|
-
target_list =
|
|
286
|
+
target_list = to_list(target_names)
|
|
287
287
|
task_order, grouped = group_metrics_by_task(metrics, target_names=target_names)
|
|
288
288
|
|
|
289
289
|
if isinstance(loss, np.ndarray) and target_list:
|
nextrec/utils/feature.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
Feature processing utilities for NextRec
|
|
3
3
|
|
|
4
4
|
Date: create on 03/12/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 27/12/2025
|
|
6
6
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
@@ -10,7 +10,7 @@ import numbers
|
|
|
10
10
|
from typing import Any
|
|
11
11
|
|
|
12
12
|
|
|
13
|
-
def
|
|
13
|
+
def to_list(value: str | list[str] | None) -> list[str]:
|
|
14
14
|
if value is None:
|
|
15
15
|
return []
|
|
16
16
|
if isinstance(value, str):
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Loss utilities for NextRec.
|
|
3
3
|
|
|
4
|
-
Date: create on
|
|
5
|
-
Checkpoint: edit on 19/12/2025
|
|
4
|
+
Date: create on 28/12/2025
|
|
6
5
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
6
|
"""
|
|
8
7
|
|
|
9
|
-
from
|
|
8
|
+
from __future__ import annotations
|
|
10
9
|
|
|
10
|
+
import torch
|
|
11
11
|
import torch.nn as nn
|
|
12
12
|
|
|
13
13
|
from nextrec.loss.listwise import (
|
|
@@ -19,39 +19,21 @@ from nextrec.loss.listwise import (
|
|
|
19
19
|
)
|
|
20
20
|
from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
|
|
21
21
|
from nextrec.loss.pointwise import ClassBalancedFocalLoss, FocalLoss, WeightedBCELoss
|
|
22
|
+
from nextrec.utils.types import LossName
|
|
22
23
|
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
"focal_loss",
|
|
37
|
-
"cb_focal",
|
|
38
|
-
"class_balanced_focal",
|
|
39
|
-
"crossentropy",
|
|
40
|
-
"ce",
|
|
41
|
-
"mse",
|
|
42
|
-
"mae",
|
|
43
|
-
# Pairwise ranking losses
|
|
44
|
-
"bpr",
|
|
45
|
-
"hinge",
|
|
46
|
-
"triplet",
|
|
47
|
-
# Listwise ranking losses
|
|
48
|
-
"sampled_softmax",
|
|
49
|
-
"softmax",
|
|
50
|
-
"infonce",
|
|
51
|
-
"listnet",
|
|
52
|
-
"listmle",
|
|
53
|
-
"approx_ndcg",
|
|
54
|
-
]
|
|
24
|
+
|
|
25
|
+
def normalize_task_loss(
|
|
26
|
+
task_loss,
|
|
27
|
+
valid_count,
|
|
28
|
+
total_count,
|
|
29
|
+
eps=1e-8,
|
|
30
|
+
) -> torch.Tensor:
|
|
31
|
+
if not torch.is_tensor(valid_count):
|
|
32
|
+
valid_count = torch.tensor(float(valid_count), device=task_loss.device)
|
|
33
|
+
if not torch.is_tensor(total_count):
|
|
34
|
+
total_count = torch.tensor(float(total_count), device=task_loss.device)
|
|
35
|
+
scale = valid_count.to(task_loss.dtype) / (total_count.to(task_loss.dtype) + eps)
|
|
36
|
+
return task_loss * scale
|
|
55
37
|
|
|
56
38
|
|
|
57
39
|
def build_cb_focal(kw):
|
|
@@ -60,7 +42,10 @@ def build_cb_focal(kw):
|
|
|
60
42
|
return ClassBalancedFocalLoss(**kw)
|
|
61
43
|
|
|
62
44
|
|
|
63
|
-
def get_loss_fn(
|
|
45
|
+
def get_loss_fn(
|
|
46
|
+
loss: LossName | None | nn.Module = None,
|
|
47
|
+
**kw,
|
|
48
|
+
) -> nn.Module:
|
|
64
49
|
"""
|
|
65
50
|
Get loss function by name or return the provided loss module.
|
|
66
51
|
|
nextrec/utils/torch_utils.py
CHANGED
|
@@ -3,12 +3,16 @@ PyTorch-related utilities for NextRec.
|
|
|
3
3
|
|
|
4
4
|
This module groups device setup, distributed helpers, optimizers/schedulers,
|
|
5
5
|
initialization, and tensor helpers.
|
|
6
|
+
|
|
7
|
+
Date: create on 27/10/2025
|
|
8
|
+
Checkpoint: edit on 27/12/2025
|
|
9
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
6
10
|
"""
|
|
7
11
|
|
|
8
12
|
from __future__ import annotations
|
|
9
13
|
|
|
10
14
|
import logging
|
|
11
|
-
from typing import Any, Dict, Iterable,
|
|
15
|
+
from typing import Any, Dict, Iterable, Literal
|
|
12
16
|
|
|
13
17
|
import numpy as np
|
|
14
18
|
import torch
|
|
@@ -18,26 +22,25 @@ from torch.utils.data import DataLoader, IterableDataset
|
|
|
18
22
|
from torch.utils.data.distributed import DistributedSampler
|
|
19
23
|
|
|
20
24
|
from nextrec.basic.loggers import colorize
|
|
21
|
-
|
|
22
|
-
KNOWN_NONLINEARITIES: Set[str] = {
|
|
23
|
-
"linear",
|
|
24
|
-
"conv1d",
|
|
25
|
-
"conv2d",
|
|
26
|
-
"conv3d",
|
|
27
|
-
"conv_transpose1d",
|
|
28
|
-
"conv_transpose2d",
|
|
29
|
-
"conv_transpose3d",
|
|
30
|
-
"sigmoid",
|
|
31
|
-
"tanh",
|
|
32
|
-
"relu",
|
|
33
|
-
"leaky_relu",
|
|
34
|
-
"selu",
|
|
35
|
-
"gelu",
|
|
36
|
-
}
|
|
25
|
+
from nextrec.utils.types import OptimizerName, SchedulerName
|
|
37
26
|
|
|
38
27
|
|
|
39
28
|
def resolve_nonlinearity(activation: str) -> str:
|
|
40
|
-
if activation in
|
|
29
|
+
if activation in [
|
|
30
|
+
"linear",
|
|
31
|
+
"conv1d",
|
|
32
|
+
"conv2d",
|
|
33
|
+
"conv3d",
|
|
34
|
+
"conv_transpose1d",
|
|
35
|
+
"conv_transpose2d",
|
|
36
|
+
"conv_transpose3d",
|
|
37
|
+
"sigmoid",
|
|
38
|
+
"tanh",
|
|
39
|
+
"relu",
|
|
40
|
+
"leaky_relu",
|
|
41
|
+
"selu",
|
|
42
|
+
"gelu",
|
|
43
|
+
]:
|
|
41
44
|
return activation
|
|
42
45
|
return "linear"
|
|
43
46
|
|
|
@@ -53,8 +56,30 @@ def resolve_gain(activation: str, param: Dict[str, Any]) -> float:
|
|
|
53
56
|
|
|
54
57
|
|
|
55
58
|
def get_initializer(
|
|
56
|
-
init_type:
|
|
57
|
-
|
|
59
|
+
init_type: Literal[
|
|
60
|
+
"xavier_uniform",
|
|
61
|
+
"xavier_normal",
|
|
62
|
+
"kaiming_uniform",
|
|
63
|
+
"kaiming_normal",
|
|
64
|
+
"orthogonal",
|
|
65
|
+
"normal",
|
|
66
|
+
"uniform",
|
|
67
|
+
] = "normal",
|
|
68
|
+
activation: Literal[
|
|
69
|
+
"linear",
|
|
70
|
+
"conv1d",
|
|
71
|
+
"conv2d",
|
|
72
|
+
"conv3d",
|
|
73
|
+
"conv_transpose1d",
|
|
74
|
+
"conv_transpose2d",
|
|
75
|
+
"conv_transpose3d",
|
|
76
|
+
"sigmoid",
|
|
77
|
+
"tanh",
|
|
78
|
+
"relu",
|
|
79
|
+
"leaky_relu",
|
|
80
|
+
"selu",
|
|
81
|
+
"gelu",
|
|
82
|
+
] = "linear",
|
|
58
83
|
param: Dict[str, Any] | None = None,
|
|
59
84
|
):
|
|
60
85
|
param = param or {}
|
|
@@ -89,47 +114,14 @@ def get_initializer(
|
|
|
89
114
|
return initializer_fn
|
|
90
115
|
|
|
91
116
|
|
|
92
|
-
def
|
|
93
|
-
if torch.cuda.is_available():
|
|
94
|
-
return "cuda"
|
|
95
|
-
if torch.backends.mps.is_available():
|
|
96
|
-
import platform
|
|
97
|
-
|
|
98
|
-
mac_ver = platform.mac_ver()[0]
|
|
99
|
-
try:
|
|
100
|
-
major, _ = (int(x) for x in mac_ver.split(".")[:2])
|
|
101
|
-
except Exception:
|
|
102
|
-
major, _ = 0, 0
|
|
103
|
-
if major >= 14:
|
|
104
|
-
return "mps"
|
|
105
|
-
return "cpu"
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
def get_device_info() -> dict:
|
|
109
|
-
info = {
|
|
110
|
-
"cuda_available": torch.cuda.is_available(),
|
|
111
|
-
"cuda_device_count": (
|
|
112
|
-
torch.cuda.device_count() if torch.cuda.is_available() else 0
|
|
113
|
-
),
|
|
114
|
-
"mps_available": torch.backends.mps.is_available(),
|
|
115
|
-
"current_device": resolve_device(),
|
|
116
|
-
}
|
|
117
|
-
|
|
118
|
-
if torch.cuda.is_available():
|
|
119
|
-
info["cuda_device_name"] = torch.cuda.get_device_name(0)
|
|
120
|
-
info["cuda_capability"] = torch.cuda.get_device_capability(0)
|
|
121
|
-
|
|
122
|
-
return info
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
def configure_device(
|
|
117
|
+
def get_device(
|
|
126
118
|
distributed: bool, local_rank: int, base_device: torch.device | str = "cpu"
|
|
127
119
|
) -> torch.device:
|
|
128
120
|
try:
|
|
129
121
|
device = torch.device(base_device)
|
|
130
122
|
except Exception:
|
|
131
123
|
logging.warning(
|
|
132
|
-
"[
|
|
124
|
+
"[get_device Warning] Invalid base_device, falling back to CPU."
|
|
133
125
|
)
|
|
134
126
|
return torch.device("cpu")
|
|
135
127
|
|
|
@@ -158,7 +150,7 @@ def configure_device(
|
|
|
158
150
|
|
|
159
151
|
|
|
160
152
|
def get_optimizer(
|
|
161
|
-
optimizer:
|
|
153
|
+
optimizer: OptimizerName | torch.optim.Optimizer = "adam",
|
|
162
154
|
params: Iterable[torch.nn.Parameter] | None = None,
|
|
163
155
|
**optimizer_params,
|
|
164
156
|
):
|
|
@@ -191,7 +183,7 @@ def get_optimizer(
|
|
|
191
183
|
|
|
192
184
|
def get_scheduler(
|
|
193
185
|
scheduler: (
|
|
194
|
-
|
|
186
|
+
SchedulerName
|
|
195
187
|
| torch.optim.lr_scheduler._LRScheduler
|
|
196
188
|
| torch.optim.lr_scheduler.LRScheduler
|
|
197
189
|
| type[torch.optim.lr_scheduler._LRScheduler]
|
|
@@ -241,51 +233,6 @@ def to_tensor(
|
|
|
241
233
|
return tensor
|
|
242
234
|
|
|
243
235
|
|
|
244
|
-
def stack_tensors(tensors: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
|
|
245
|
-
if not tensors:
|
|
246
|
-
raise ValueError("[Tensor Utils Error] Cannot stack empty list of tensors.")
|
|
247
|
-
return torch.stack(tensors, dim=dim)
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
def concat_tensors(tensors: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
|
|
251
|
-
if not tensors:
|
|
252
|
-
raise ValueError(
|
|
253
|
-
"[Tensor Utils Error] Cannot concatenate empty list of tensors."
|
|
254
|
-
)
|
|
255
|
-
return torch.cat(tensors, dim=dim)
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
def pad_sequence_tensors(
|
|
259
|
-
tensors: list[torch.Tensor],
|
|
260
|
-
max_len: int | None = None,
|
|
261
|
-
padding_value: float = 0.0,
|
|
262
|
-
padding_side: str = "right",
|
|
263
|
-
) -> torch.Tensor:
|
|
264
|
-
if not tensors:
|
|
265
|
-
raise ValueError("[Tensor Utils Error] Cannot pad empty list of tensors.")
|
|
266
|
-
if max_len is None:
|
|
267
|
-
max_len = max(t.size(0) for t in tensors)
|
|
268
|
-
batch_size = len(tensors)
|
|
269
|
-
padded = torch.full(
|
|
270
|
-
(batch_size, max_len),
|
|
271
|
-
padding_value,
|
|
272
|
-
dtype=tensors[0].dtype,
|
|
273
|
-
device=tensors[0].device,
|
|
274
|
-
)
|
|
275
|
-
|
|
276
|
-
for i, tensor in enumerate(tensors):
|
|
277
|
-
length = min(tensor.size(0), max_len)
|
|
278
|
-
if padding_side == "right":
|
|
279
|
-
padded[i, :length] = tensor[:length]
|
|
280
|
-
elif padding_side == "left":
|
|
281
|
-
padded[i, -length:] = tensor[:length]
|
|
282
|
-
else:
|
|
283
|
-
raise ValueError(
|
|
284
|
-
f"[Tensor Utils Error] padding_side must be 'right' or 'left', got {padding_side}"
|
|
285
|
-
)
|
|
286
|
-
return padded
|
|
287
|
-
|
|
288
|
-
|
|
289
236
|
def init_process_group(
|
|
290
237
|
distributed: bool, rank: int, world_size: int, device_id: int | None = None
|
|
291
238
|
) -> None:
|
|
@@ -350,7 +297,7 @@ def add_distributed_sampler(
|
|
|
350
297
|
# return if already has DistributedSampler
|
|
351
298
|
if isinstance(loader.sampler, DistributedSampler):
|
|
352
299
|
return loader, loader.sampler
|
|
353
|
-
dataset =
|
|
300
|
+
dataset = loader.dataset
|
|
354
301
|
if dataset is None:
|
|
355
302
|
return loader, None
|
|
356
303
|
if isinstance(dataset, IterableDataset):
|
|
@@ -379,25 +326,23 @@ def add_distributed_sampler(
|
|
|
379
326
|
"collate_fn": loader.collate_fn,
|
|
380
327
|
"drop_last": drop_last,
|
|
381
328
|
}
|
|
382
|
-
if
|
|
329
|
+
if loader.pin_memory:
|
|
383
330
|
loader_kwargs["pin_memory"] = True
|
|
384
|
-
pin_memory_device =
|
|
331
|
+
pin_memory_device = loader.pin_memory_device
|
|
385
332
|
if pin_memory_device:
|
|
386
333
|
loader_kwargs["pin_memory_device"] = pin_memory_device
|
|
387
|
-
timeout =
|
|
334
|
+
timeout = loader.timeout
|
|
388
335
|
if timeout:
|
|
389
336
|
loader_kwargs["timeout"] = timeout
|
|
390
|
-
worker_init_fn =
|
|
337
|
+
worker_init_fn = loader.worker_init_fn
|
|
391
338
|
if worker_init_fn is not None:
|
|
392
339
|
loader_kwargs["worker_init_fn"] = worker_init_fn
|
|
393
|
-
generator =
|
|
340
|
+
generator = loader.generator
|
|
394
341
|
if generator is not None:
|
|
395
342
|
loader_kwargs["generator"] = generator
|
|
396
343
|
if loader.num_workers > 0:
|
|
397
|
-
loader_kwargs["persistent_workers"] =
|
|
398
|
-
|
|
399
|
-
)
|
|
400
|
-
prefetch_factor = getattr(loader, "prefetch_factor", None)
|
|
344
|
+
loader_kwargs["persistent_workers"] = loader.persistent_workers
|
|
345
|
+
prefetch_factor = loader.prefetch_factor
|
|
401
346
|
if prefetch_factor is not None:
|
|
402
347
|
loader_kwargs["prefetch_factor"] = prefetch_factor
|
|
403
348
|
distributed_loader = DataLoader(dataset, **loader_kwargs)
|
nextrec/utils/types.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Shared type aliases for NextRec.
|
|
3
|
+
|
|
4
|
+
Keep Literal-based public string options centralized to avoid drift.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Literal
|
|
8
|
+
|
|
9
|
+
OptimizerName = Literal["adam", "sgd", "adamw", "adagrad", "rmsprop"]
|
|
10
|
+
|
|
11
|
+
SchedulerName = Literal["step", "cosine"]
|
|
12
|
+
|
|
13
|
+
LossName = Literal[
|
|
14
|
+
"bce",
|
|
15
|
+
"binary_crossentropy",
|
|
16
|
+
"weighted_bce",
|
|
17
|
+
"focal",
|
|
18
|
+
"focal_loss",
|
|
19
|
+
"cb_focal",
|
|
20
|
+
"class_balanced_focal",
|
|
21
|
+
"crossentropy",
|
|
22
|
+
"ce",
|
|
23
|
+
"mse",
|
|
24
|
+
"mae",
|
|
25
|
+
"bpr",
|
|
26
|
+
"hinge",
|
|
27
|
+
"triplet",
|
|
28
|
+
"sampled_softmax",
|
|
29
|
+
"softmax",
|
|
30
|
+
"infonce",
|
|
31
|
+
"listnet",
|
|
32
|
+
"listmle",
|
|
33
|
+
"approx_ndcg",
|
|
34
|
+
]
|
|
35
|
+
|
|
36
|
+
ActivationName = Literal[
|
|
37
|
+
"dice",
|
|
38
|
+
"relu",
|
|
39
|
+
"relu6",
|
|
40
|
+
"elu",
|
|
41
|
+
"selu",
|
|
42
|
+
"leaky_relu",
|
|
43
|
+
"prelu",
|
|
44
|
+
"gelu",
|
|
45
|
+
"sigmoid",
|
|
46
|
+
"tanh",
|
|
47
|
+
"softplus",
|
|
48
|
+
"softsign",
|
|
49
|
+
"hardswish",
|
|
50
|
+
"mish",
|
|
51
|
+
"silu",
|
|
52
|
+
"swish",
|
|
53
|
+
"hardsigmoid",
|
|
54
|
+
"tanhshrink",
|
|
55
|
+
"softshrink",
|
|
56
|
+
"none",
|
|
57
|
+
"linear",
|
|
58
|
+
"identity",
|
|
59
|
+
]
|
|
60
|
+
|
|
61
|
+
TrainingModeName = Literal["pointwise", "pairwise", "listwise"]
|
|
62
|
+
|
|
63
|
+
TaskTypeName = Literal["binary", "regression"]
|