nextrec 0.4.11__py3-none-any.whl → 0.4.13__py3-none-any.whl

This diff shows the content changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
@@ -0,0 +1,232 @@
+ """
+ GradNorm loss weighting for multi-task learning.
+
+ Date: create on 27/10/2025
+ Checkpoint: edit on 20/12/2025
+ Author: Yang Zhou,zyaztec@gmail.com
+
+ Reference:
+     Chen, Zhao, et al. "GradNorm: Gradient Normalization for Adaptive Loss Balancing
+     in Deep Multitask Networks." ICML 2018.
+ """
+
+ from __future__ import annotations
+
+ from typing import Iterable
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ def get_grad_norm_shared_params(
+     model: torch.nn.Module,
+     shared_modules: Iterable[str] | None = None,
+ ) -> list[torch.nn.Parameter]:
+     if not shared_modules:
+         return [p for p in model.parameters() if p.requires_grad]
+     shared_params = []
+     seen = set()
+     for name in shared_modules:
+         module = getattr(model, name, None)
+         if module is None:
+             continue
+         for param in module.parameters():
+             if param.requires_grad and id(param) not in seen:
+                 shared_params.append(param)
+                 seen.add(id(param))
+     if not shared_params:
+         return [p for p in model.parameters() if p.requires_grad]
+     return shared_params
+
+
+ class GradNormLossWeighting:
+     """
+     Adaptive multi-task loss weighting with GradNorm.
+
+     Args:
+         num_tasks: Number of tasks.
+         alpha: GradNorm balancing strength.
+         lr: Learning rate for the weight optimizer.
+         init_weights: Optional initial weights per task.
+         device: Torch device for weights.
+         ema_decay: Optional EMA decay for smoothing loss ratios.
+         init_ema_steps: Number of steps to build EMA for initial losses.
+         init_ema_decay: EMA decay for initial losses when init_ema_steps > 0.
+         eps: Small value for numerical stability.
+     """
+
+     def __init__(
+         self,
+         num_tasks: int,
+         alpha: float = 1.5,
+         lr: float = 0.025,
+         init_weights: Iterable[float] | None = None,
+         device: torch.device | str | None = None,
+         ema_decay: float | None = None,
+         init_ema_steps: int = 0,
+         init_ema_decay: float = 0.9,
+         eps: float = 1e-8,
+     ) -> None:
+         if num_tasks <= 1:
+             raise ValueError("GradNorm requires num_tasks > 1.")
+         self.num_tasks = num_tasks
+         self.alpha = alpha
+         self.eps = eps
+         if ema_decay is not None:
+             ema_decay = ema_decay
+             if ema_decay < 0.0 or ema_decay >= 1.0:
+                 raise ValueError("ema_decay must be in [0.0, 1.0).")
+         self.ema_decay = ema_decay
+         self.init_ema_steps = init_ema_steps
+         if self.init_ema_steps < 0:
+             raise ValueError("init_ema_steps must be >= 0.")
+         self.init_ema_decay = init_ema_decay
+         if self.init_ema_decay < 0.0 or self.init_ema_decay >= 1.0:
+             raise ValueError("init_ema_decay must be in [0.0, 1.0).")
+         self.init_ema_count = 0
+
+         if init_weights is None:
+             weights = torch.ones(self.num_tasks, dtype=torch.float32)
+         else:
+             weights = torch.tensor(list(init_weights), dtype=torch.float32)
+             if weights.numel() != self.num_tasks:
+                 raise ValueError(
+                     "init_weights length must match num_tasks for GradNorm."
+                 )
+         if device is not None:
+             weights = weights.to(device)
+         self.weights = nn.Parameter(weights)
+         self.optimizer = torch.optim.Adam([self.weights], lr=float(lr))
+
+         self.initial_losses = None
+         self.initial_losses_ema = None
+         self.loss_ema = None
+         self.pending_grad = None
+
+     def to(self, device):
+         device = torch.device(device)
+         self.weights.data = self.weights.data.to(device)
+         if self.initial_losses is not None:
+             self.initial_losses = self.initial_losses.to(device)
+         if self.initial_losses_ema is not None:
+             self.initial_losses_ema = self.initial_losses_ema.to(device)
+         if self.loss_ema is not None:
+             self.loss_ema = self.loss_ema.to(device)
+         return self
+
+     def compute_weighted_loss(
+         self,
+         task_losses: list[torch.Tensor],
+         shared_params: Iterable[torch.nn.Parameter],
+     ) -> torch.Tensor:
+         """
+         Return weighted total loss and update task weights with GradNorm.
+         """
+         if len(task_losses) != self.num_tasks:
+             raise ValueError(
+                 f"Expected {self.num_tasks} task losses, got {len(task_losses)}."
+             )
+         shared_params = [p for p in shared_params if p.requires_grad]
+         if not shared_params:
+             return torch.stack(task_losses).sum()
+
+         with torch.no_grad():
+             loss_values = torch.tensor(
+                 [loss.item() for loss in task_losses], device=self.weights.device
+             )
+             if self.initial_losses is None:
+                 if self.init_ema_steps > 0:
+                     if self.initial_losses_ema is None:
+                         self.initial_losses_ema = loss_values
+                     else:
+                         self.initial_losses_ema = (
+                             self.init_ema_decay * self.initial_losses_ema
+                             + (1.0 - self.init_ema_decay) * loss_values
+                         )
+                     self.init_ema_count += 1
+                     if self.init_ema_count >= self.init_ema_steps:
+                         self.initial_losses = self.initial_losses_ema.clone()
+                 else:
+                     self.initial_losses = loss_values
+
+         weights_detached = self.weights.detach()
+         weighted_losses = [
+             weights_detached[i] * task_losses[i] for i in range(self.num_tasks)
+         ]
+         total_loss = torch.stack(weighted_losses).sum()
+
+         grad_norms = self.compute_grad_norms(task_losses, shared_params)
+         with torch.no_grad():
+             if self.ema_decay is not None:
+                 if self.loss_ema is None:
+                     self.loss_ema = loss_values
+                 else:
+                     self.loss_ema = (
+                         self.ema_decay * self.loss_ema
+                         + (1.0 - self.ema_decay) * loss_values
+                     )
+                 ratio_source = self.loss_ema
+             else:
+                 ratio_source = loss_values
+             if self.initial_losses is not None:
+                 base_initial = self.initial_losses
+             elif self.initial_losses_ema is not None:
+                 base_initial = self.initial_losses_ema
+             else:
+                 base_initial = loss_values
+             loss_ratios = ratio_source / (base_initial + self.eps)
+             inv_rate = loss_ratios / (loss_ratios.mean() + self.eps)
+             target = grad_norms.mean() * (inv_rate**self.alpha)
+
+         grad_norm_loss = F.l1_loss(grad_norms, target.detach(), reduction="sum")
+         grad_w = torch.autograd.grad(grad_norm_loss, self.weights, retain_graph=True)[0]
+         self.pending_grad = grad_w.detach()
+
+         return total_loss
+
+     def compute_grad_norms(self, task_losses, shared_params):
+         grad_norms = []
+         for i, task_loss in enumerate(task_losses):
+             grads = torch.autograd.grad(
+                 self.weights[i] * task_loss,
+                 shared_params,
+                 retain_graph=True,
+                 create_graph=True,
+                 allow_unused=True,
+             )
+
+             sq_sum = torch.zeros((), device=self.weights.device)
+             any_used = False
+             for g in grads:
+                 if g is not None:
+                     any_used = True
+                     sq_sum = sq_sum + g.pow(2).sum()
+
+             if not any_used:
+                 total_norm = torch.tensor(self.eps, device=self.weights.device)
+             else:
+                 total_norm = torch.sqrt(sq_sum + self.eps)
+
+             grad_norms.append(total_norm)
+
+         return torch.stack(grad_norms)
+
+     def step(self) -> None:
+         if self.pending_grad is None:
+             return
+
+         self.optimizer.zero_grad(set_to_none=True)
+
+         if self.weights.grad is None:
+             self.weights.grad = torch.zeros_like(self.weights)
+         self.weights.grad.copy_(self.pending_grad)
+
+         self.optimizer.step()
+
+         with torch.no_grad():
+             w = self.weights.clamp(min=self.eps)
+             w = w * self.num_tasks / (w.sum() + self.eps)
+             self.weights.copy_(w)
+
+         self.pending_grad = None
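For context, here is a minimal, hypothetical wiring sketch (not part of the package) showing how GradNormLossWeighting and get_grad_norm_shared_params above could be driven from a training loop; the model hunks further down expose a grad_norm_shared_modules attribute for exactly this purpose. The toy model, tensor shapes, and losses are illustrative, and the import path for the new module is not shown in this diff.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumes GradNormLossWeighting and get_grad_norm_shared_params are imported
# from wherever the new module lands inside nextrec (path not shown in this diff).


class ToyTwoTaskModel(nn.Module):
    """Illustrative stand-in for a nextrec multi-task model."""

    def __init__(self):
        super().__init__()
        self.embedding = nn.Linear(16, 32)  # stands in for a shared embedding layer
        self.tower_a = nn.Linear(32, 1)
        self.tower_b = nn.Linear(32, 1)
        self.grad_norm_shared_modules = ["embedding"]  # same convention as the models below

    def forward(self, x):
        h = torch.relu(self.embedding(x))
        return self.tower_a(h), self.tower_b(h)


model = ToyTwoTaskModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
gradnorm = GradNormLossWeighting(num_tasks=2, alpha=1.5)
shared_params = get_grad_norm_shared_params(model, model.grad_norm_shared_modules)

x = torch.randn(8, 16)
y_a = torch.randint(0, 2, (8, 1)).float()
y_b = torch.randn(8, 1)

logits_a, pred_b = model(x)
task_losses = [
    F.binary_cross_entropy_with_logits(logits_a, y_a),
    F.mse_loss(pred_b, y_b),
]

# compute_weighted_loss returns the weighted total loss for the model update
# and caches the gradient of the GradNorm objective w.r.t. the task weights.
total_loss = gradnorm.compute_weighted_loss(task_losses, shared_params)
optimizer.zero_grad(set_to_none=True)
total_loss.backward()
optimizer.step()
gradnorm.step()  # applies the cached weight update and renormalizes the weights

Note that compute_weighted_loss only caches the weight gradient; the task weights themselves move when step() is called after the model update, and are then renormalized so that they sum to num_tasks.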
@@ -60,7 +60,7 @@ def build_cb_focal(kw):
      return ClassBalancedFocalLoss(**kw)


- def get_loss_fn(loss: LossType | nn.Module | None = None, **kw) -> nn.Module:
+ def get_loss_fn(loss=None, **kw) -> nn.Module:
      """
      Get loss function by name or return the provided loss module.

@@ -138,6 +138,7 @@ class ESMM(BaseModel):

          # CVR tower
          self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
+         self.grad_norm_shared_modules = ["embedding"]
          self.prediction_layer = PredictionLayer(
              task_type=self.default_task, task_dims=[1, 1]
          )
@@ -165,6 +165,7 @@ class MMOE(BaseModel):
          for _ in range(self.num_tasks):
              gate = nn.Sequential(nn.Linear(input_dim, num_experts), nn.Softmax(dim=1))
              self.gates.append(gate)
+         self.grad_norm_shared_modules = ["embedding", "experts", "gates"]

          # Task-specific towers
          self.towers = nn.ModuleList()
@@ -295,6 +295,7 @@ class PLE(BaseModel):
              )
              self.cgc_layers.append(cgc_layer)
              expert_output_dim = cgc_layer.output_dim
+         self.grad_norm_shared_modules = ["embedding", "cgc_layers"]

          # Task-specific towers
          self.towers = nn.ModuleList()
@@ -483,6 +483,10 @@ class POSO(BaseModel):
              ]
          )
          self.tower_heads = None
+         if self.architecture == "mlp":
+             self.grad_norm_shared_modules = ["embedding"]
+         else:
+             self.grad_norm_shared_modules = ["embedding", "mmoe"]
          self.prediction_layer = PredictionLayer(
              task_type=self.default_task,
              task_dims=[1] * self.num_tasks,
@@ -129,6 +129,7 @@ class ShareBottom(BaseModel):

          # Shared bottom network
          self.bottom = MLP(input_dim=input_dim, output_layer=False, **bottom_params)
+         self.grad_norm_shared_modules = ["embedding", "bottom"]

          # Get bottom output dimension
          if "dims" in bottom_params and len(bottom_params["dims"]) > 0:
@@ -41,7 +41,8 @@ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
  from nextrec.basic.layers import LR, EmbeddingLayer, PredictionLayer
  from nextrec.basic.model import BaseModel

- class EulerInteractionLayerPaper(nn.Module):
+
+ class EulerInteractionLayer(nn.Module):
      """
      Paper-aligned Euler Interaction Layer.

@@ -102,24 +103,32 @@ class EulerInteractionLayerPaper(nn.Module):
          self.bn = None
          self.ln = None

-     def forward(self, r: torch.Tensor, p: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+     def forward(
+         self, r: torch.Tensor, p: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor]:
          """
          r, p: [B, m, d]
          return r_out, p_out: [B, n, d]
          """
          B, m, d = r.shape
-         assert m == self.m and d == self.d, f"Expected [B,{self.m},{self.d}] got {r.shape}"
+         assert (
+             m == self.m and d == self.d
+         ), f"Expected [B,{self.m},{self.d}] got {r.shape}"

          # Euler Transformation: rectangular -> polar
-         lam = torch.sqrt(r * r + p * p + self.eps)  # [B,m,d]
-         theta = torch.atan2(p, r)  # [B,m,d]
-         log_lam = torch.log(lam + self.eps)  # [B,m,d]
+         lam = torch.sqrt(r * r + p * p + self.eps)  # [B,m,d]
+         theta = torch.atan2(p, r)  # [B,m,d]
+         log_lam = torch.log(lam + self.eps)  # [B,m,d]

          # Generalized Multi-order Transformation
          # psi_k = sum_j alpha_{k,j} * theta_j + delta_k
          # l_k = exp(sum_j alpha_{k,j} * log(lam_j) + delta'_k)
-         psi = torch.einsum("bmd,nmd->bnd", theta, self.alpha) + self.delta_phase  # [B,n,d]
-         log_l = torch.einsum("bmd,nmd->bnd", log_lam, self.alpha) + self.delta_logmod  # [B,n,d]
+         psi = (
+             torch.einsum("bmd,nmd->bnd", theta, self.alpha) + self.delta_phase
+         )  # [B,n,d]
+         log_l = (
+             torch.einsum("bmd,nmd->bnd", log_lam, self.alpha) + self.delta_logmod
+         )  # [B,n,d]
          l = torch.exp(log_l)  # [B,n,d]

          # Inverse Euler Transformation
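The comments in this hunk describe the Euler transformation from rectangular coordinates (r, p) to polar coordinates (lam, theta) and back. As a quick standalone illustration (not package code), the round trip can be checked numerically; the eps stabilizer mirrors the one used above and only perturbs the result near the origin.

import torch

eps = 1e-8
r = torch.randn(2, 3, 4)  # arbitrary "real" part, shape [B, m, d]
p = torch.randn(2, 3, 4)  # arbitrary "imaginary" part

# rectangular -> polar, as in the forward pass above
lam = torch.sqrt(r * r + p * p + eps)  # modulus
theta = torch.atan2(p, r)              # phase

# inverse Euler transformation: polar -> rectangular
r_back = lam * torch.cos(theta)
p_back = lam * torch.sin(theta)

assert torch.allclose(r_back, r, atol=1e-3)
assert torch.allclose(p_back, p, atol=1e-3)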
@@ -153,7 +162,7 @@ class EulerInteractionLayerPaper(nn.Module):
          return r_out, p_out


- class ComplexSpaceMappingPaper(nn.Module):
+ class ComplexSpaceMapping(nn.Module):
      """
      Map real embeddings e_j to complex features via Euler's formula (Eq.6-7).
      For each field j:
@@ -174,63 +183,6 @@ class ComplexSpaceMappingPaper(nn.Module):
          r = mu * torch.cos(e)
          p = mu * torch.sin(e)
          return r, p
-
- class EulerNetPaper(nn.Module):
-     """
-     Paper-aligned EulerNet core (embedding -> mapping -> L Euler layers -> linear regression).
-     """
-
-     def __init__(
-         self,
-         *,
-         embedding_dim: int,
-         num_fields: int,
-         num_layers: int = 2,
-         num_orders: int = 8,  # n in paper
-         use_implicit: bool = True,
-         norm: str | None = "ln",  # None | "bn" | "ln"
-     ):
-         super().__init__()
-         self.d = embedding_dim
-         self.m = num_fields
-         self.L = num_layers
-         self.n = num_orders
-
-         self.mapping = ComplexSpaceMappingPaper(embedding_dim, num_fields)
-
-         self.layers = nn.ModuleList([
-             EulerInteractionLayerPaper(
-                 embedding_dim=embedding_dim,
-                 num_fields=(num_fields if i == 0 else num_orders),  # stack: m -> n -> n ...
-                 num_orders=num_orders,
-                 use_implicit=use_implicit,
-                 norm=norm,
-             )
-             for i in range(num_layers)
-         ])
-
-         # Output regression (Eq.16-17)
-         # After last layer: r,p are [B,n,d]. Concatenate to [B, n*d] each, then regress.
-         self.w = nn.Linear(self.n * self.d, 1, bias=False)  # for real
-         self.w_im = nn.Linear(self.n * self.d, 1, bias=False)  # for imag
-
-     def forward(self, field_emb: torch.Tensor) -> torch.Tensor:
-         """
-         field_emb: [B, m, d] real embeddings e_j
-         return: logits, shape [B,1]
-         """
-         r, p = self.mapping(field_emb)  # [B,m,d]
-
-         # stack Euler interaction layers
-         for layer in self.layers:
-             r, p = layer(r, p)  # -> [B,n,d]
-
-         r_flat = r.reshape(r.size(0), self.n * self.d)
-         p_flat = p.reshape(p.size(0), self.n * self.d)
-
-         z_re = self.w(r_flat)
-         z_im = self.w_im(p_flat)
-         return z_re + z_im  # Eq.17 logits


  class EulerNet(BaseModel):
@@ -313,14 +265,23 @@ class EulerNet(BaseModel):
                  "All interaction features must share the same embedding_dim in EulerNet."
              )

-         self.euler = EulerNetPaper(
-             embedding_dim=self.embedding_dim,
-             num_fields=self.num_fields,
-             num_layers=num_layers,
-             num_orders=num_orders,
-             use_implicit=use_implicit,
-             norm=norm,
+         self.num_layers = num_layers
+         self.num_orders = num_orders
+         self.mapping = ComplexSpaceMapping(self.embedding_dim, self.num_fields)
+         self.layers = nn.ModuleList(
+             [
+                 EulerInteractionLayer(
+                     embedding_dim=self.embedding_dim,
+                     num_fields=(self.num_fields if i == 0 else self.num_orders),
+                     num_orders=self.num_orders,
+                     use_implicit=use_implicit,
+                     norm=norm,
+                 )
+                 for i in range(self.num_layers)
+             ]
          )
+         self.w = nn.Linear(self.num_orders * self.embedding_dim, 1, bias=False)
+         self.w_im = nn.Linear(self.num_orders * self.embedding_dim, 1, bias=False)

          if self.use_linear:
              if len(self.linear_features) == 0:
@@ -336,7 +297,7 @@

          self.prediction_layer = PredictionLayer(task_type=self.task)

-         modules = ["euler"]
+         modules = ["mapping", "layers", "w", "w_im"]
          if self.use_linear:
              modules.append("linear")
          self.register_regularization_weights(
@@ -354,7 +315,7 @@
          field_emb = self.embedding(
              x=x, features=self.interaction_features, squeeze_dim=False
          )
-         y_euler = self.euler(field_emb)
+         y_euler = self.euler_forward(field_emb)

          if self.use_linear and self.linear is not None:
              linear_input = self.embedding(
@@ -363,3 +324,11 @@
              y_euler = y_euler + self.linear(linear_input)

          return self.prediction_layer(y_euler)
+
+     def euler_forward(self, field_emb: torch.Tensor) -> torch.Tensor:
+         r, p = self.mapping(field_emb)
+         for layer in self.layers:
+             r, p = layer(r, p)
+         r_flat = r.reshape(r.size(0), self.num_orders * self.embedding_dim)
+         p_flat = p.reshape(p.size(0), self.num_orders * self.embedding_dim)
+         return self.w(r_flat) + self.w_im(p_flat)
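To make the refactor concrete, here is a hypothetical shape walk-through of the inlined Euler stack, assuming ComplexSpaceMapping and EulerInteractionLayer are importable from the EulerNet module (the exact path is not shown in this diff). Constructor arguments mirror the code above; the concrete sizes are illustrative.

import torch
import torch.nn as nn

# Assumes: from <EulerNet module in nextrec> import ComplexSpaceMapping, EulerInteractionLayer

B, m, d, n = 8, 5, 16, 4  # batch, fields, embedding dim, orders (illustrative sizes)

mapping = ComplexSpaceMapping(d, m)
layers = nn.ModuleList(
    [
        EulerInteractionLayer(
            embedding_dim=d,
            num_fields=(m if i == 0 else n),  # stack: m -> n -> n ...
            num_orders=n,
            use_implicit=True,
            norm="ln",
        )
        for i in range(2)
    ]
)
w = nn.Linear(n * d, 1, bias=False)
w_im = nn.Linear(n * d, 1, bias=False)

field_emb = torch.randn(B, m, d)
r, p = mapping(field_emb)  # real / imaginary parts, each [B, m, d]
for layer in layers:
    r, p = layer(r, p)     # each layer maps to [B, n, d]
logits = w(r.reshape(B, n * d)) + w_im(p.reshape(B, n * d))  # [B, 1]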