nextrec 0.4.12-py3-none-any.whl → 0.4.13-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/model.py +48 -6
- nextrec/cli.py +1 -0
- nextrec/loss/__init__.py +3 -0
- nextrec/loss/grad_norm.py +232 -0
- nextrec/loss/loss_utils.py +1 -1
- nextrec/models/multi_task/esmm.py +1 -0
- nextrec/models/multi_task/mmoe.py +1 -0
- nextrec/models/multi_task/ple.py +1 -0
- nextrec/models/multi_task/poso.py +4 -0
- nextrec/models/multi_task/share_bottom.py +1 -0
- {nextrec-0.4.12.dist-info → nextrec-0.4.13.dist-info}/METADATA +6 -6
- {nextrec-0.4.12.dist-info → nextrec-0.4.13.dist-info}/RECORD +16 -15
- {nextrec-0.4.12.dist-info → nextrec-0.4.13.dist-info}/WHEEL +0 -0
- {nextrec-0.4.12.dist-info → nextrec-0.4.13.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.12.dist-info → nextrec-0.4.13.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.4.12"
+__version__ = "0.4.13"
nextrec/basic/model.py
CHANGED
@@ -50,12 +50,14 @@ from nextrec.data.dataloader import (
 )
 from nextrec.loss import (
     BPRLoss,
+    GradNormLossWeighting,
     HingeLoss,
     InfoNCELoss,
     SampledSoftmaxLoss,
     TripletLoss,
     get_loss_fn,
 )
+from nextrec.loss.grad_norm import get_grad_norm_shared_params
 from nextrec.utils.console import display_metrics_table, progress
 from nextrec.utils.torch_utils import (
     add_distributed_sampler,
@@ -177,6 +179,8 @@ class BaseModel(FeatureSet, nn.Module):
         self.logger_initialized = False
         self.training_logger = None
         self.callbacks = CallbackList(callbacks) if callbacks else CallbackList()
+        self.grad_norm: GradNormLossWeighting | None = None
+        self.grad_norm_shared_params: list[torch.nn.Parameter] | None = None
 
     def register_regularization_weights(
         self,
@@ -377,7 +381,7 @@ class BaseModel(FeatureSet, nn.Module):
         scheduler_params: dict | None = None,
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
-        loss_weights: int | float | list[int | float] | None = None,
+        loss_weights: int | float | list[int | float] | dict | str | None = None,
         callbacks: list[Callback] | None = None,
     ):
         """
@@ -390,6 +394,7 @@ class BaseModel(FeatureSet, nn.Module):
             loss: Loss function name, instance, or list for multi-task. e.g., 'bce', 'mse', or torch.nn.BCELoss(), you can also use custom loss functions.
             loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
             loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
+                Use "grad_norm" or {"method": "grad_norm", ...} to enable GradNorm for multi-task loss balancing.
             callbacks: Additional callbacks to add to the existing callback list. e.g., [EarlyStopper(), CheckpointSaver()].
         """
         if loss_params is None:
@@ -443,7 +448,31 @@ class BaseModel(FeatureSet, nn.Module):
             for i in range(self.nums_task)
         ]
 
-        if loss_weights is None:
+        self.grad_norm = None
+        self.grad_norm_shared_params = None
+        if isinstance(loss_weights, str) and loss_weights.lower() == "grad_norm":
+            if self.nums_task == 1:
+                raise ValueError(
+                    "[BaseModel-compile Error] GradNorm requires multi-task setup."
+                )
+            self.grad_norm = GradNormLossWeighting(
+                num_tasks=self.nums_task, device=self.device
+            )
+            self.loss_weights = None
+        elif (
+            isinstance(loss_weights, dict) and loss_weights.get("method") == "grad_norm"
+        ):
+            if self.nums_task == 1:
+                raise ValueError(
+                    "[BaseModel-compile Error] GradNorm requires multi-task setup."
+                )
+            grad_norm_params = dict(loss_weights)
+            grad_norm_params.pop("method", None)
+            self.grad_norm = GradNormLossWeighting(
+                num_tasks=self.nums_task, device=self.device, **grad_norm_params
+            )
+            self.loss_weights = None
+        elif loss_weights is None:
             self.loss_weights = None
         elif self.nums_task == 1:
             if isinstance(loss_weights, (list, tuple)):
@@ -508,9 +537,20 @@ class BaseModel(FeatureSet, nn.Module):
             y_pred_i = y_pred[:, start:end]
             y_true_i = y_true[:, start:end]
             task_loss = self.loss_fn[i](y_pred_i, y_true_i)
-            if isinstance(self.loss_weights, (list, tuple)):
-                task_loss *= self.loss_weights[i]
             task_losses.append(task_loss)
+        if self.grad_norm is not None:
+            if self.grad_norm_shared_params is None:
+                self.grad_norm_shared_params = get_grad_norm_shared_params(
+                    self, getattr(self, "grad_norm_shared_modules", None)
+                )
+            return self.grad_norm.compute_weighted_loss(
+                task_losses, self.grad_norm_shared_params
+            )
+        if isinstance(self.loss_weights, (list, tuple)):
+            task_losses = [
+                task_loss * self.loss_weights[i]
+                for i, task_loss in enumerate(task_losses)
+            ]
         return torch.stack(task_losses).sum()
 
     def prepare_data_loader(
@@ -1053,6 +1093,8 @@ class BaseModel(FeatureSet, nn.Module):
                 params = model.parameters() if self.ddp_model is not None else self.parameters()  # type: ignore # ddp model parameters or self parameters
                 nn.utils.clip_grad_norm_(params, self.max_gradient_norm)
             self.optimizer_fn.step()
+            if self.grad_norm is not None:
+                self.grad_norm.step()
             accumulated_loss += loss.item()
 
             if (
@@ -1637,7 +1679,7 @@ class BaseModel(FeatureSet, nn.Module):
             add_timestamp=add_timestamp,
         )
         model_path = Path(target_path)
-
+
         ddp_model = getattr(self, "ddp_model", None)
         if ddp_model is not None:
             model_to_save = ddp_model.module
@@ -2067,7 +2109,7 @@ class BaseMatchModel(BaseModel):
         scheduler_params: dict | None = None,
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
-        loss_weights: int | float | list[int | float] | None = None,
+        loss_weights: int | float | list[int | float] | dict | str | None = None,
         callbacks: list[Callback] | None = None,
     ):
         """
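Taken together, the model.py changes mean GradNorm is enabled entirely through the `loss_weights` argument of `compile()`: a bare `"grad_norm"` string uses the defaults, while a dict with `"method": "grad_norm"` forwards the remaining keys to `GradNormLossWeighting`. A sketch of the call shapes only (the `model` object and the loss choices are placeholders; building a real multi-task model needs feature definitions outside this diff):

```python
# Illustrative call shapes; `model` stands for any multi-task BaseModel subclass
# (nums_task > 1). compile() raises a ValueError for single-task models.
model.compile(loss=["bce", "bce"], loss_weights="grad_norm")  # defaults: alpha=1.5, lr=0.025

model.compile(
    loss=["bce", "bce"],
    loss_weights={"method": "grad_norm", "alpha": 1.0, "lr": 0.01, "ema_decay": 0.9},
)
```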
nextrec/cli.py
CHANGED
@@ -380,6 +380,7 @@ def train_model(train_config_path: str) -> None:
         optimizer_params=train_cfg.get("optimizer_params", {}),
         loss=train_cfg.get("loss", "focal"),
         loss_params=train_cfg.get("loss_params", {}),
+        loss_weights=train_cfg.get("loss_weights"),
     )
 
     model.fit(
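Because the CLI simply forwards the raw config value, the same switch is now available from a training config. A hypothetical slice of the parsed `train_cfg` dict (key names mirror the `.get(...)` calls above; the values are examples, not package defaults):

```python
# Hypothetical train_cfg fragment; train_model() passes these straight into model.compile().
train_cfg = {
    "loss": ["bce", "bce"],
    "loss_params": {},
    "optimizer_params": {"lr": 1e-3},
    "loss_weights": {"method": "grad_norm", "alpha": 1.5, "lr": 0.025},
}
```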
nextrec/loss/__init__.py
CHANGED
@@ -5,6 +5,7 @@ from nextrec.loss.listwise import (
     ListNetLoss,
     SampledSoftmaxLoss,
 )
+from nextrec.loss.grad_norm import GradNormLossWeighting
 from nextrec.loss.loss_utils import VALID_TASK_TYPES, get_loss_fn, get_loss_kwargs
 from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
 from nextrec.loss.pointwise import (
@@ -30,6 +31,8 @@ __all__ = [
     "ListNetLoss",
     "ListMLELoss",
     "ApproxNDCGLoss",
+    # Multi-task weighting
+    "GradNormLossWeighting",
     # Utilities
     "get_loss_fn",
     "get_loss_kwargs",
nextrec/loss/grad_norm.py
ADDED
@@ -0,0 +1,232 @@
+"""
+GradNorm loss weighting for multi-task learning.
+
+Date: create on 27/10/2025
+Checkpoint: edit on 20/12/2025
+Author: Yang Zhou,zyaztec@gmail.com
+
+Reference:
+    Chen, Zhao, et al. "GradNorm: Gradient Normalization for Adaptive Loss Balancing
+    in Deep Multitask Networks." ICML 2018.
+"""
+
+from __future__ import annotations
+
+from typing import Iterable
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def get_grad_norm_shared_params(
+    model: torch.nn.Module,
+    shared_modules: Iterable[str] | None = None,
+) -> list[torch.nn.Parameter]:
+    if not shared_modules:
+        return [p for p in model.parameters() if p.requires_grad]
+    shared_params = []
+    seen = set()
+    for name in shared_modules:
+        module = getattr(model, name, None)
+        if module is None:
+            continue
+        for param in module.parameters():
+            if param.requires_grad and id(param) not in seen:
+                shared_params.append(param)
+                seen.add(id(param))
+    if not shared_params:
+        return [p for p in model.parameters() if p.requires_grad]
+    return shared_params
+
+
+class GradNormLossWeighting:
+    """
+    Adaptive multi-task loss weighting with GradNorm.
+
+    Args:
+        num_tasks: Number of tasks.
+        alpha: GradNorm balancing strength.
+        lr: Learning rate for the weight optimizer.
+        init_weights: Optional initial weights per task.
+        device: Torch device for weights.
+        ema_decay: Optional EMA decay for smoothing loss ratios.
+        init_ema_steps: Number of steps to build EMA for initial losses.
+        init_ema_decay: EMA decay for initial losses when init_ema_steps > 0.
+        eps: Small value for numerical stability.
+    """
+
+    def __init__(
+        self,
+        num_tasks: int,
+        alpha: float = 1.5,
+        lr: float = 0.025,
+        init_weights: Iterable[float] | None = None,
+        device: torch.device | str | None = None,
+        ema_decay: float | None = None,
+        init_ema_steps: int = 0,
+        init_ema_decay: float = 0.9,
+        eps: float = 1e-8,
+    ) -> None:
+        if num_tasks <= 1:
+            raise ValueError("GradNorm requires num_tasks > 1.")
+        self.num_tasks = num_tasks
+        self.alpha = alpha
+        self.eps = eps
+        if ema_decay is not None:
+            ema_decay = ema_decay
+            if ema_decay < 0.0 or ema_decay >= 1.0:
+                raise ValueError("ema_decay must be in [0.0, 1.0).")
+        self.ema_decay = ema_decay
+        self.init_ema_steps = init_ema_steps
+        if self.init_ema_steps < 0:
+            raise ValueError("init_ema_steps must be >= 0.")
+        self.init_ema_decay = init_ema_decay
+        if self.init_ema_decay < 0.0 or self.init_ema_decay >= 1.0:
+            raise ValueError("init_ema_decay must be in [0.0, 1.0).")
+        self.init_ema_count = 0
+
+        if init_weights is None:
+            weights = torch.ones(self.num_tasks, dtype=torch.float32)
+        else:
+            weights = torch.tensor(list(init_weights), dtype=torch.float32)
+            if weights.numel() != self.num_tasks:
+                raise ValueError(
+                    "init_weights length must match num_tasks for GradNorm."
+                )
+        if device is not None:
+            weights = weights.to(device)
+        self.weights = nn.Parameter(weights)
+        self.optimizer = torch.optim.Adam([self.weights], lr=float(lr))
+
+        self.initial_losses = None
+        self.initial_losses_ema = None
+        self.loss_ema = None
+        self.pending_grad = None
+
+    def to(self, device):
+        device = torch.device(device)
+        self.weights.data = self.weights.data.to(device)
+        if self.initial_losses is not None:
+            self.initial_losses = self.initial_losses.to(device)
+        if self.initial_losses_ema is not None:
+            self.initial_losses_ema = self.initial_losses_ema.to(device)
+        if self.loss_ema is not None:
+            self.loss_ema = self.loss_ema.to(device)
+        return self
+
+    def compute_weighted_loss(
+        self,
+        task_losses: list[torch.Tensor],
+        shared_params: Iterable[torch.nn.Parameter],
+    ) -> torch.Tensor:
+        """
+        Return weighted total loss and update task weights with GradNorm.
+        """
+        if len(task_losses) != self.num_tasks:
+            raise ValueError(
+                f"Expected {self.num_tasks} task losses, got {len(task_losses)}."
+            )
+        shared_params = [p for p in shared_params if p.requires_grad]
+        if not shared_params:
+            return torch.stack(task_losses).sum()
+
+        with torch.no_grad():
+            loss_values = torch.tensor(
+                [loss.item() for loss in task_losses], device=self.weights.device
+            )
+            if self.initial_losses is None:
+                if self.init_ema_steps > 0:
+                    if self.initial_losses_ema is None:
+                        self.initial_losses_ema = loss_values
+                    else:
+                        self.initial_losses_ema = (
+                            self.init_ema_decay * self.initial_losses_ema
+                            + (1.0 - self.init_ema_decay) * loss_values
+                        )
+                    self.init_ema_count += 1
+                    if self.init_ema_count >= self.init_ema_steps:
+                        self.initial_losses = self.initial_losses_ema.clone()
+                else:
+                    self.initial_losses = loss_values
+
+        weights_detached = self.weights.detach()
+        weighted_losses = [
+            weights_detached[i] * task_losses[i] for i in range(self.num_tasks)
+        ]
+        total_loss = torch.stack(weighted_losses).sum()
+
+        grad_norms = self.compute_grad_norms(task_losses, shared_params)
+        with torch.no_grad():
+            if self.ema_decay is not None:
+                if self.loss_ema is None:
+                    self.loss_ema = loss_values
+                else:
+                    self.loss_ema = (
+                        self.ema_decay * self.loss_ema
+                        + (1.0 - self.ema_decay) * loss_values
+                    )
+                ratio_source = self.loss_ema
+            else:
+                ratio_source = loss_values
+            if self.initial_losses is not None:
+                base_initial = self.initial_losses
+            elif self.initial_losses_ema is not None:
+                base_initial = self.initial_losses_ema
+            else:
+                base_initial = loss_values
+            loss_ratios = ratio_source / (base_initial + self.eps)
+            inv_rate = loss_ratios / (loss_ratios.mean() + self.eps)
+            target = grad_norms.mean() * (inv_rate**self.alpha)
+
+        grad_norm_loss = F.l1_loss(grad_norms, target.detach(), reduction="sum")
+        grad_w = torch.autograd.grad(grad_norm_loss, self.weights, retain_graph=True)[0]
+        self.pending_grad = grad_w.detach()
+
+        return total_loss
+
+    def compute_grad_norms(self, task_losses, shared_params):
+        grad_norms = []
+        for i, task_loss in enumerate(task_losses):
+            grads = torch.autograd.grad(
+                self.weights[i] * task_loss,
+                shared_params,
+                retain_graph=True,
+                create_graph=True,
+                allow_unused=True,
+            )
+
+            sq_sum = torch.zeros((), device=self.weights.device)
+            any_used = False
+            for g in grads:
+                if g is not None:
+                    any_used = True
+                    sq_sum = sq_sum + g.pow(2).sum()
+
+            if not any_used:
+                total_norm = torch.tensor(self.eps, device=self.weights.device)
+            else:
+                total_norm = torch.sqrt(sq_sum + self.eps)
+
+            grad_norms.append(total_norm)
+
+        return torch.stack(grad_norms)
+
+    def step(self) -> None:
+        if self.pending_grad is None:
+            return
+
+        self.optimizer.zero_grad(set_to_none=True)
+
+        if self.weights.grad is None:
+            self.weights.grad = torch.zeros_like(self.weights)
+        self.weights.grad.copy_(self.pending_grad)
+
+        self.optimizer.step()
+
+        with torch.no_grad():
+            w = self.weights.clamp(min=self.eps)
+            w = w * self.num_tasks / (w.sum() + self.eps)
+            self.weights.copy_(w)
+
+        self.pending_grad = None
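Note the split update: compute_weighted_loss() returns the weighted sum (with detached weights) and buffers the gradient of the GradNorm objective in pending_grad, and step() applies that gradient and renormalizes the weights to sum to num_tasks. A minimal self-contained sketch of using the class outside BaseModel (the toy network, data, and hyperparameters below are illustrative only, not part of the package):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from nextrec.loss import GradNormLossWeighting
from nextrec.loss.grad_norm import get_grad_norm_shared_params


class TwoTaskNet(nn.Module):
    """Toy two-task network: a shared trunk and two task heads."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.head_a = nn.Linear(16, 1)
        self.head_b = nn.Linear(16, 1)

    def forward(self, x):
        h = torch.relu(self.backbone(x))
        return self.head_a(h), self.head_b(h)


net = TwoTaskNet()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
grad_norm = GradNormLossWeighting(num_tasks=2, alpha=1.5)
# Probe gradients only on the shared trunk, mirroring what BaseModel does
# through the grad_norm_shared_modules attribute.
shared_params = get_grad_norm_shared_params(net, ["backbone"])

x = torch.randn(32, 8)
y_a, y_b = torch.rand(32, 1), torch.rand(32, 1)
for _ in range(3):
    pred_a, pred_b = net(x)
    task_losses = [F.mse_loss(pred_a, y_a), F.mse_loss(pred_b, y_b)]
    # Weighted sum of task losses; also buffers the gradient for the task weights.
    total = grad_norm.compute_weighted_loss(task_losses, shared_params)
    optimizer.zero_grad()
    total.backward()
    optimizer.step()
    grad_norm.step()  # update and renormalize the per-task weights
```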
nextrec/loss/loss_utils.py
CHANGED

nextrec/models/multi_task/esmm.py
CHANGED
@@ -138,6 +138,7 @@ class ESMM(BaseModel):
 
         # CVR tower
         self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
+        self.grad_norm_shared_modules = ["embedding"]
         self.prediction_layer = PredictionLayer(
             task_type=self.default_task, task_dims=[1, 1]
         )
nextrec/models/multi_task/mmoe.py
CHANGED
@@ -165,6 +165,7 @@ class MMOE(BaseModel):
         for _ in range(self.num_tasks):
             gate = nn.Sequential(nn.Linear(input_dim, num_experts), nn.Softmax(dim=1))
             self.gates.append(gate)
+        self.grad_norm_shared_modules = ["embedding", "experts", "gates"]
 
         # Task-specific towers
         self.towers = nn.ModuleList()
nextrec/models/multi_task/ple.py
CHANGED

nextrec/models/multi_task/poso.py
CHANGED
@@ -483,6 +483,10 @@ class POSO(BaseModel):
             ]
         )
         self.tower_heads = None
+        if self.architecture == "mlp":
+            self.grad_norm_shared_modules = ["embedding"]
+        else:
+            self.grad_norm_shared_modules = ["embedding", "mmoe"]
         self.prediction_layer = PredictionLayer(
             task_type=self.default_task,
             task_dims=[1] * self.num_tasks,
nextrec/models/multi_task/share_bottom.py
CHANGED
@@ -129,6 +129,7 @@ class ShareBottom(BaseModel):
 
         # Shared bottom network
         self.bottom = MLP(input_dim=input_dim, output_layer=False, **bottom_params)
+        self.grad_norm_shared_modules = ["embedding", "bottom"]
 
         # Get bottom output dimension
         if "dims" in bottom_params and len(bottom_params["dims"]) > 0:
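Each built-in multi-task model now pins GradNorm's gradient probes to its shared layers by setting a grad_norm_shared_modules attribute; models without the attribute fall back to every trainable parameter (see get_grad_norm_shared_params above). A hedged sketch of opting a custom subclass into the same behaviour (the class body and module names are illustrative, not from this diff):

```python
import torch.nn as nn

from nextrec.basic.model import BaseModel


class MyMultiTaskModel(BaseModel):
    """Illustrative only: constructor arguments follow whatever BaseModel expects."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.shared_encoder = nn.Linear(64, 32)  # hypothetical shared trunk
        # Names are resolved with getattr(); unknown names are simply skipped.
        self.grad_norm_shared_modules = ["embedding", "shared_encoder"]
```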
{nextrec-0.4.12.dist-info → nextrec-0.4.13.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nextrec
-Version: 0.4.12
+Version: 0.4.13
 Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
 Project-URL: Homepage, https://github.com/zerolovesea/NextRec
 Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -66,7 +66,7 @@ Description-Content-Type: text/markdown
 
 
 
-
+
 
 中文文档 | [English Version](README_en.md)
 
@@ -99,7 +99,7 @@ NextRec是一个基于PyTorch的现代推荐系统框架,旨在为研究工程
 
 ## NextRec近期进展
 
-- **12/12/2025** 在v0.4.12中加入了[RQ-VAE](/nextrec/models/representation/rqvae.py)模块。配套的[数据集](/dataset/ecommerce_task.csv)和[代码](tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb)已经同步在仓库中
+- **12/12/2025** 在v0.4.13中加入了[RQ-VAE](/nextrec/models/representation/rqvae.py)模块。配套的[数据集](/dataset/ecommerce_task.csv)和[代码](tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb)已经同步在仓库中
 - **07/12/2025** 发布了NextRec CLI命令行工具,它允许用户根据配置文件进行一键训练和推理,我们提供了相关的[教程](/nextrec_cli_preset/NextRec-CLI_zh.md)和[教学代码](/nextrec_cli_preset)
 - **03/12/2025** NextRec获得了100颗🌟!感谢大家的支持
 - **06/12/2025** 在v0.4.1中支持了单机多卡的分布式DDP训练,并且提供了配套的[代码](tutorials/distributed)
@@ -128,7 +128,7 @@ pip install nextrec # or pip install -e .
 - [movielen_ranking_deepfm.py](/tutorials/movielen_ranking_deepfm.py) - movielen 100k数据集上的 DeepFM 模型训练示例
 - [example_ranking_din.py](/tutorials/example_ranking_din.py) - 电商数据集上的DIN 深度兴趣网络训练示例
 - [example_multitask.py](/tutorials/example_multitask.py) - 电商数据集上的ESMM多任务学习训练示例
-- [movielen_match_dssm.py](/tutorials/
+- [movielen_match_dssm.py](/tutorials/movielen_match_dssm.py) - 基于movielen 100k数据集训练的 DSSM 召回模型示例
 
 - [example_distributed_training.py](/tutorials/distributed/example_distributed_training.py) - 使用NextRec进行单机多卡训练的代码示例
 
@@ -240,11 +240,11 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
 nextrec --mode=predict --predict_config=path/to/predict_config.yaml
 ```
 
-> 截止当前版本0.4.12,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
+> 截止当前版本0.4.13,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
 
 ## 兼容平台
 
-当前最新版本为0.4.12,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
+当前最新版本为0.4.13,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
 
 | 平台 | 配置 |
 |------|------|
{nextrec-0.4.12.dist-info → nextrec-0.4.13.dist-info}/RECORD
CHANGED
@@ -1,6 +1,6 @@
 nextrec/__init__.py,sha256=_M3oUqyuvQ5k8Th_3wId6hQ_caclh7M5ad51XN09m98,235
-nextrec/__version__.py,sha256=
-nextrec/cli.py,sha256=
+nextrec/__version__.py,sha256=ARFl7G-gCe12exBb-FIsJnbsUD5V9okxkHUUdQqb0RA,23
+nextrec/cli.py,sha256=6nBY8O8-0931h428eQS8CALkKn1FmizovJme7Q1c_O0,23978
 nextrec/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nextrec/basic/activation.py,sha256=uzTWfCOtBSkbu_Gk9XBNTj8__s241CaYLJk6l8nGX9I,2885
 nextrec/basic/callback.py,sha256=nn1f8FG9c52vJ-gvwteqPbk3-1QuNS1vmhBlkENdb0I,14636
@@ -8,7 +8,7 @@ nextrec/basic/features.py,sha256=GyCUzGPuizUofrZSSOdqHK84YhnX4MGTdu7Cx2OGhUA,465
 nextrec/basic/layers.py,sha256=ZM3Nka3e2cit3e3peL0ukJCMgKZK1ovNFfAWvVOwlos,28556
 nextrec/basic/loggers.py,sha256=Zh1A5DVAFqlGglyaQ4_IMgvFbWAcXX5H3aHbCWA82nE,6524
 nextrec/basic/metrics.py,sha256=saNgM7kuHk9xqDxZF6x33irTaxeXCU-hxYTUQauuGgg,23074
-nextrec/basic/model.py,sha256=
+nextrec/basic/model.py,sha256=b_O81WSv1XxBAS5oQk92DlLdYAtnikr_epaV5T9RSxs,102570
 nextrec/basic/session.py,sha256=UOG_-EgCOxvqZwCkiEd8sgNV2G1sm_HbzKYVQw8yYDI,4483
 nextrec/data/__init__.py,sha256=YZQjpty1pDCM7q_YNmiA2sa5kbujUw26ObLHWjMPjKY,1194
 nextrec/data/batch_utils.py,sha256=0bYGVX7RlhnHv_ZBaUngjDIpBNw-igCk98DgOsF7T6o,2879
@@ -16,19 +16,20 @@ nextrec/data/data_processing.py,sha256=lKXDBszrO5fJMAQetgSPr2mSQuzOluuz1eHV4jp0T
 nextrec/data/data_utils.py,sha256=0Ls1cnG9lBz0ovtyedw5vwp7WegGK_iF-F8e_3DEddo,880
 nextrec/data/dataloader.py,sha256=xTORNbaQVa20sk2S3kyV0SSngscvq8bNqHr0AmYjFqM,18768
 nextrec/data/preprocessor.py,sha256=K-cUP-YdlQx1VJ2m1CXuprncpjDJe2ERVO5xCSoxHKI,44470
-nextrec/loss/__init__.py,sha256=
+nextrec/loss/__init__.py,sha256=ZCgsfyR5YAecv6MdOsnUjkfacvZg2coQVjuKAfPvmRo,923
+nextrec/loss/grad_norm.py,sha256=91Grspx95Xu_639TkL_WZRX1xt5QOTZCzBeJWbUGPiE,8385
 nextrec/loss/listwise.py,sha256=UT9vJCOTOQLogVwaeTV7Z5uxIYnngGdxk-p9e97MGkU,5744
-nextrec/loss/loss_utils.py,sha256=
+nextrec/loss/loss_utils.py,sha256=xMmT_tWcKah_xcU3FzVMmSEzyZfxiMKZWUbwkAspcDg,4579
 nextrec/loss/pairwise.py,sha256=X9yg-8pcPt2IWU0AiUhWAt3_4W_3wIF0uSdDYTdoPFY,3398
 nextrec/loss/pointwise.py,sha256=o9J3OznY0hlbDsUXqn3k-BBzYiuUH5dopz8QBFqS_kQ,7343
 nextrec/models/generative/__init__.py,sha256=0MV3P-_ainPaTxmRBGWKUVCEt14KJvuvEHmRB3OQ1Fs,176
 nextrec/models/generative/tiger.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nextrec/models/multi_task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-nextrec/models/multi_task/esmm.py,sha256=
-nextrec/models/multi_task/mmoe.py,sha256=
-nextrec/models/multi_task/ple.py,sha256=
-nextrec/models/multi_task/poso.py,sha256=
-nextrec/models/multi_task/share_bottom.py,sha256=
+nextrec/models/multi_task/esmm.py,sha256=AqesBZ4tOFNm7POCrHZ90h1zWWSViZAYfydUVOh2dEU,6545
+nextrec/models/multi_task/mmoe.py,sha256=aaQKcx4PL_mAanW3tkjAR886KmMCHTdBuu4p9EIKQJo,8657
+nextrec/models/multi_task/ple.py,sha256=fqkujPFGxxQOO_6nBZEz_UcxLEUoX_vCJsk0YOpxTg4,13084
+nextrec/models/multi_task/poso.py,sha256=J_Btxhm9JpFJMdQQHNNf9mMRHOgO7j1ts6VN5o4qJnk,19193
+nextrec/models/multi_task/share_bottom.py,sha256=DTWm6fpLCLiXimD-qk_0YIKT_9THMFDrnx4GDViXc_g,6583
 nextrec/models/ranking/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nextrec/models/ranking/afm.py,sha256=96jGUPL4yTWobMIVBjHpOxl9AtAzCAGR8yw7Sy2JmdQ,10125
 nextrec/models/ranking/autoint.py,sha256=S6Cxnp1q2OErSYqmIix5P-b4qLWR-0dY6TMStuU6WLg,8109
@@ -70,8 +71,8 @@ nextrec/utils/embedding.py,sha256=akAEc062MG2cD7VIOllHaqtwzAirQR2gq5iW7oKpGAU,14
 nextrec/utils/feature.py,sha256=rsUAv3ELyDpehVw8nPEEsLCCIjuKGTJJZuFaWB_wrPk,633
 nextrec/utils/model.py,sha256=3B85a0IJCggI26dxv25IX8R_5yQPo7wXI0JIAns6bkQ,1727
 nextrec/utils/torch_utils.py,sha256=AKfYbSOJjEw874xsDB5IO3Ote4X7vnqzt_E0jJny0o8,13468
-nextrec-0.4.12.dist-info/METADATA,sha256=
-nextrec-0.4.12.dist-info/WHEEL,sha256=
-nextrec-0.4.12.dist-info/entry_points.txt,sha256=
-nextrec-0.4.12.dist-info/licenses/LICENSE,sha256=
-nextrec-0.4.12.dist-info/RECORD,,
+nextrec-0.4.13.dist-info/METADATA,sha256=BcBFpd0l4OdNRlXtG5R1UT-jMcAdloQJjOAG33E4KRE,20958
+nextrec-0.4.13.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+nextrec-0.4.13.dist-info/entry_points.txt,sha256=NN-dNSdfMRTv86bNXM7d3ZEPW2BQC6bRi7QP7i9cIps,45
+nextrec-0.4.13.dist-info/licenses/LICENSE,sha256=2fQfVKeafywkni7MYHyClC6RGGC3laLTXCNBx-ubtp0,1064
+nextrec-0.4.13.dist-info/RECORD,,
{nextrec-0.4.12.dist-info → nextrec-0.4.13.dist-info}/WHEEL
File without changes

{nextrec-0.4.12.dist-info → nextrec-0.4.13.dist-info}/entry_points.txt
File without changes

{nextrec-0.4.12.dist-info → nextrec-0.4.13.dist-info}/licenses/LICENSE
File without changes