nextrec 0.4.11__py3-none-any.whl → 0.4.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/callback.py +44 -54
- nextrec/basic/features.py +35 -22
- nextrec/basic/layers.py +64 -68
- nextrec/basic/loggers.py +2 -2
- nextrec/basic/metrics.py +9 -5
- nextrec/basic/model.py +208 -110
- nextrec/cli.py +17 -5
- nextrec/data/preprocessor.py +4 -4
- nextrec/loss/__init__.py +3 -0
- nextrec/loss/grad_norm.py +232 -0
- nextrec/loss/loss_utils.py +1 -1
- nextrec/models/multi_task/esmm.py +1 -0
- nextrec/models/multi_task/mmoe.py +1 -0
- nextrec/models/multi_task/ple.py +1 -0
- nextrec/models/multi_task/poso.py +4 -0
- nextrec/models/multi_task/share_bottom.py +1 -0
- nextrec/models/ranking/eulernet.py +44 -75
- nextrec/models/ranking/ffm.py +275 -0
- nextrec/models/ranking/lr.py +1 -3
- nextrec/utils/__init__.py +2 -1
- nextrec/utils/console.py +9 -1
- nextrec/utils/model.py +14 -0
- {nextrec-0.4.11.dist-info → nextrec-0.4.13.dist-info}/METADATA +7 -7
- {nextrec-0.4.11.dist-info → nextrec-0.4.13.dist-info}/RECORD +28 -27
- {nextrec-0.4.11.dist-info → nextrec-0.4.13.dist-info}/WHEEL +0 -0
- {nextrec-0.4.11.dist-info → nextrec-0.4.13.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.11.dist-info → nextrec-0.4.13.dist-info}/licenses/LICENSE +0 -0
nextrec/loss/grad_norm.py
ADDED
@@ -0,0 +1,232 @@
+"""
+GradNorm loss weighting for multi-task learning.
+
+Date: create on 27/10/2025
+Checkpoint: edit on 20/12/2025
+Author: Yang Zhou,zyaztec@gmail.com
+
+Reference:
+    Chen, Zhao, et al. "GradNorm: Gradient Normalization for Adaptive Loss Balancing
+    in Deep Multitask Networks." ICML 2018.
+"""
+
+from __future__ import annotations
+
+from typing import Iterable
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def get_grad_norm_shared_params(
+    model: torch.nn.Module,
+    shared_modules: Iterable[str] | None = None,
+) -> list[torch.nn.Parameter]:
+    if not shared_modules:
+        return [p for p in model.parameters() if p.requires_grad]
+    shared_params = []
+    seen = set()
+    for name in shared_modules:
+        module = getattr(model, name, None)
+        if module is None:
+            continue
+        for param in module.parameters():
+            if param.requires_grad and id(param) not in seen:
+                shared_params.append(param)
+                seen.add(id(param))
+    if not shared_params:
+        return [p for p in model.parameters() if p.requires_grad]
+    return shared_params
+
+
+class GradNormLossWeighting:
+    """
+    Adaptive multi-task loss weighting with GradNorm.
+
+    Args:
+        num_tasks: Number of tasks.
+        alpha: GradNorm balancing strength.
+        lr: Learning rate for the weight optimizer.
+        init_weights: Optional initial weights per task.
+        device: Torch device for weights.
+        ema_decay: Optional EMA decay for smoothing loss ratios.
+        init_ema_steps: Number of steps to build EMA for initial losses.
+        init_ema_decay: EMA decay for initial losses when init_ema_steps > 0.
+        eps: Small value for numerical stability.
+    """
+
+    def __init__(
+        self,
+        num_tasks: int,
+        alpha: float = 1.5,
+        lr: float = 0.025,
+        init_weights: Iterable[float] | None = None,
+        device: torch.device | str | None = None,
+        ema_decay: float | None = None,
+        init_ema_steps: int = 0,
+        init_ema_decay: float = 0.9,
+        eps: float = 1e-8,
+    ) -> None:
+        if num_tasks <= 1:
+            raise ValueError("GradNorm requires num_tasks > 1.")
+        self.num_tasks = num_tasks
+        self.alpha = alpha
+        self.eps = eps
+        if ema_decay is not None:
+            ema_decay = ema_decay
+            if ema_decay < 0.0 or ema_decay >= 1.0:
+                raise ValueError("ema_decay must be in [0.0, 1.0).")
+        self.ema_decay = ema_decay
+        self.init_ema_steps = init_ema_steps
+        if self.init_ema_steps < 0:
+            raise ValueError("init_ema_steps must be >= 0.")
+        self.init_ema_decay = init_ema_decay
+        if self.init_ema_decay < 0.0 or self.init_ema_decay >= 1.0:
+            raise ValueError("init_ema_decay must be in [0.0, 1.0).")
+        self.init_ema_count = 0
+
+        if init_weights is None:
+            weights = torch.ones(self.num_tasks, dtype=torch.float32)
+        else:
+            weights = torch.tensor(list(init_weights), dtype=torch.float32)
+            if weights.numel() != self.num_tasks:
+                raise ValueError(
+                    "init_weights length must match num_tasks for GradNorm."
+                )
+        if device is not None:
+            weights = weights.to(device)
+        self.weights = nn.Parameter(weights)
+        self.optimizer = torch.optim.Adam([self.weights], lr=float(lr))
+
+        self.initial_losses = None
+        self.initial_losses_ema = None
+        self.loss_ema = None
+        self.pending_grad = None
+
+    def to(self, device):
+        device = torch.device(device)
+        self.weights.data = self.weights.data.to(device)
+        if self.initial_losses is not None:
+            self.initial_losses = self.initial_losses.to(device)
+        if self.initial_losses_ema is not None:
+            self.initial_losses_ema = self.initial_losses_ema.to(device)
+        if self.loss_ema is not None:
+            self.loss_ema = self.loss_ema.to(device)
+        return self
+
+    def compute_weighted_loss(
+        self,
+        task_losses: list[torch.Tensor],
+        shared_params: Iterable[torch.nn.Parameter],
+    ) -> torch.Tensor:
+        """
+        Return weighted total loss and update task weights with GradNorm.
+        """
+        if len(task_losses) != self.num_tasks:
+            raise ValueError(
+                f"Expected {self.num_tasks} task losses, got {len(task_losses)}."
+            )
+        shared_params = [p for p in shared_params if p.requires_grad]
+        if not shared_params:
+            return torch.stack(task_losses).sum()
+
+        with torch.no_grad():
+            loss_values = torch.tensor(
+                [loss.item() for loss in task_losses], device=self.weights.device
+            )
+            if self.initial_losses is None:
+                if self.init_ema_steps > 0:
+                    if self.initial_losses_ema is None:
+                        self.initial_losses_ema = loss_values
+                    else:
+                        self.initial_losses_ema = (
+                            self.init_ema_decay * self.initial_losses_ema
+                            + (1.0 - self.init_ema_decay) * loss_values
+                        )
+                    self.init_ema_count += 1
+                    if self.init_ema_count >= self.init_ema_steps:
+                        self.initial_losses = self.initial_losses_ema.clone()
+                else:
+                    self.initial_losses = loss_values
+
+        weights_detached = self.weights.detach()
+        weighted_losses = [
+            weights_detached[i] * task_losses[i] for i in range(self.num_tasks)
+        ]
+        total_loss = torch.stack(weighted_losses).sum()
+
+        grad_norms = self.compute_grad_norms(task_losses, shared_params)
+        with torch.no_grad():
+            if self.ema_decay is not None:
+                if self.loss_ema is None:
+                    self.loss_ema = loss_values
+                else:
+                    self.loss_ema = (
+                        self.ema_decay * self.loss_ema
+                        + (1.0 - self.ema_decay) * loss_values
+                    )
+                ratio_source = self.loss_ema
+            else:
+                ratio_source = loss_values
+            if self.initial_losses is not None:
+                base_initial = self.initial_losses
+            elif self.initial_losses_ema is not None:
+                base_initial = self.initial_losses_ema
+            else:
+                base_initial = loss_values
+            loss_ratios = ratio_source / (base_initial + self.eps)
+            inv_rate = loss_ratios / (loss_ratios.mean() + self.eps)
+            target = grad_norms.mean() * (inv_rate**self.alpha)
+
+        grad_norm_loss = F.l1_loss(grad_norms, target.detach(), reduction="sum")
+        grad_w = torch.autograd.grad(grad_norm_loss, self.weights, retain_graph=True)[0]
+        self.pending_grad = grad_w.detach()
+
+        return total_loss
+
+    def compute_grad_norms(self, task_losses, shared_params):
+        grad_norms = []
+        for i, task_loss in enumerate(task_losses):
+            grads = torch.autograd.grad(
+                self.weights[i] * task_loss,
+                shared_params,
+                retain_graph=True,
+                create_graph=True,
+                allow_unused=True,
+            )
+
+            sq_sum = torch.zeros((), device=self.weights.device)
+            any_used = False
+            for g in grads:
+                if g is not None:
+                    any_used = True
+                    sq_sum = sq_sum + g.pow(2).sum()
+
+            if not any_used:
+                total_norm = torch.tensor(self.eps, device=self.weights.device)
+            else:
+                total_norm = torch.sqrt(sq_sum + self.eps)
+
+            grad_norms.append(total_norm)
+
+        return torch.stack(grad_norms)
+
+    def step(self) -> None:
+        if self.pending_grad is None:
+            return
+
+        self.optimizer.zero_grad(set_to_none=True)
+
+        if self.weights.grad is None:
+            self.weights.grad = torch.zeros_like(self.weights)
+        self.weights.grad.copy_(self.pending_grad)
+
+        self.optimizer.step()
+
+        with torch.no_grad():
+            w = self.weights.clamp(min=self.eps)
+            w = w * self.num_tasks / (w.sum() + self.eps)
+            self.weights.copy_(w)
+
+        self.pending_grad = None
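For context, the sketch below shows one way the new weighting class could be driven in a training step. It is a hypothetical, minimal example: only GradNormLossWeighting, get_grad_norm_shared_params, and the grad_norm_shared_modules convention come from this diff; the toy model, data, and optimizer are stand-ins, and the real trainer integration lives in nextrec/basic/model.py, which is not excerpted here.

import torch
import torch.nn as nn

from nextrec.loss.grad_norm import GradNormLossWeighting, get_grad_norm_shared_params


class TwoTaskToy(nn.Module):
    """Toy stand-in for a nextrec multi-task model (illustration only)."""

    def __init__(self):
        super().__init__()
        self.embedding = nn.Linear(8, 16)  # plays the role of the shared encoder
        self.head_a = nn.Linear(16, 1)
        self.head_b = nn.Linear(16, 1)
        # Same attribute convention the multi_task model hunks below introduce.
        self.grad_norm_shared_modules = ["embedding"]

    def forward(self, x):
        h = torch.relu(self.embedding(x))
        return self.head_a(h), self.head_b(h)


model = TwoTaskToy()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
weighting = GradNormLossWeighting(num_tasks=2, alpha=1.5, lr=0.025)
shared = get_grad_norm_shared_params(model, model.grad_norm_shared_modules)

x = torch.randn(32, 8)
target_a, target_b = torch.rand(32, 1), torch.rand(32, 1)
out_a, out_b = model(x)
task_losses = [
    nn.functional.mse_loss(out_a, target_a),
    nn.functional.mse_loss(out_b, target_b),
]

# compute_weighted_loss returns the weighted sum and stashes the GradNorm
# gradient for the task weights; step() applies it after the model update.
total = weighting.compute_weighted_loss(task_losses, shared)
optimizer.zero_grad()
total.backward()
optimizer.step()
weighting.step()

Calling weighting.step() after optimizer.step() matches the pending_grad handshake in the class above: the weight gradient is computed inside compute_weighted_loss and only applied when step() is called.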
nextrec/loss/loss_utils.py
CHANGED
@@ -60,7 +60,7 @@ def build_cb_focal(kw):
     return ClassBalancedFocalLoss(**kw)


-def get_loss_fn(loss
+def get_loss_fn(loss=None, **kw) -> nn.Module:
     """
     Get loss function by name or return the provided loss module.

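As a hedged illustration of the new keyword signature: only the signature and docstring above are taken from the diff, and the import path assumes the function stays in nextrec.loss.loss_utils.

import torch.nn as nn

from nextrec.loss.loss_utils import get_loss_fn

# Per the docstring, passing an existing loss module should return it unchanged;
# passing a name plus keyword arguments is the other documented path, but the
# set of valid names is not shown in this hunk.
loss_fn = get_loss_fn(nn.BCELoss())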
nextrec/models/multi_task/esmm.py
CHANGED
@@ -138,6 +138,7 @@ class ESMM(BaseModel):

         # CVR tower
         self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
+        self.grad_norm_shared_modules = ["embedding"]
         self.prediction_layer = PredictionLayer(
             task_type=self.default_task, task_dims=[1, 1]
         )
nextrec/models/multi_task/mmoe.py
CHANGED
@@ -165,6 +165,7 @@ class MMOE(BaseModel):
         for _ in range(self.num_tasks):
             gate = nn.Sequential(nn.Linear(input_dim, num_experts), nn.Softmax(dim=1))
             self.gates.append(gate)
+        self.grad_norm_shared_modules = ["embedding", "experts", "gates"]

         # Task-specific towers
         self.towers = nn.ModuleList()
nextrec/models/multi_task/ple.py
CHANGED
nextrec/models/multi_task/poso.py
CHANGED
@@ -483,6 +483,10 @@ class POSO(BaseModel):
             ]
         )
         self.tower_heads = None
+        if self.architecture == "mlp":
+            self.grad_norm_shared_modules = ["embedding"]
+        else:
+            self.grad_norm_shared_modules = ["embedding", "mmoe"]
         self.prediction_layer = PredictionLayer(
             task_type=self.default_task,
             task_dims=[1] * self.num_tasks,
nextrec/models/multi_task/share_bottom.py
CHANGED
@@ -129,6 +129,7 @@ class ShareBottom(BaseModel):

         # Shared bottom network
         self.bottom = MLP(input_dim=input_dim, output_layer=False, **bottom_params)
+        self.grad_norm_shared_modules = ["embedding", "bottom"]

         # Get bottom output dimension
         if "dims" in bottom_params and len(bottom_params["dims"]) > 0:
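The grad_norm_shared_modules attribute added in the four hunks above names the submodules whose parameters GradNorm should measure. How BaseModel consumes it belongs to nextrec/basic/model.py (not excerpted), but a plausible resolution step, using only the helper defined in grad_norm.py, looks like this:

from nextrec.loss.grad_norm import get_grad_norm_shared_params

# `model` is e.g. an ESMM, MMOE, POSO, or ShareBottom instance; the helper falls
# back to all trainable parameters when the attribute is missing or matches nothing.
shared_params = get_grad_norm_shared_params(
    model, getattr(model, "grad_norm_shared_modules", None)
)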
nextrec/models/ranking/eulernet.py
CHANGED
@@ -41,7 +41,8 @@ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 from nextrec.basic.layers import LR, EmbeddingLayer, PredictionLayer
 from nextrec.basic.model import BaseModel

-
+
+class EulerInteractionLayer(nn.Module):
     """
     Paper-aligned Euler Interaction Layer.

@@ -102,24 +103,32 @@ class EulerInteractionLayerPaper(nn.Module):
         self.bn = None
         self.ln = None

-    def forward(
+    def forward(
+        self, r: torch.Tensor, p: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         r, p: [B, m, d]
         return r_out, p_out: [B, n, d]
         """
         B, m, d = r.shape
-        assert
+        assert (
+            m == self.m and d == self.d
+        ), f"Expected [B,{self.m},{self.d}] got {r.shape}"

         # Euler Transformation: rectangular -> polar
-        lam = torch.sqrt(r * r + p * p + self.eps)
-        theta = torch.atan2(p, r)
-        log_lam = torch.log(lam + self.eps)
+        lam = torch.sqrt(r * r + p * p + self.eps)  # [B,m,d]
+        theta = torch.atan2(p, r)  # [B,m,d]
+        log_lam = torch.log(lam + self.eps)  # [B,m,d]

         # Generalized Multi-order Transformation
         # psi_k = sum_j alpha_{k,j} * theta_j + delta_k
         # l_k = exp(sum_j alpha_{k,j} * log(lam_j) + delta'_k)
-        psi =
-
+        psi = (
+            torch.einsum("bmd,nmd->bnd", theta, self.alpha) + self.delta_phase
+        )  # [B,n,d]
+        log_l = (
+            torch.einsum("bmd,nmd->bnd", log_lam, self.alpha) + self.delta_logmod
+        )  # [B,n,d]
         l = torch.exp(log_l)  # [B,n,d]

         # Inverse Euler Transformation
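Written out in LaTeX, the transformation this forward pass implements is the following; the inverse step restores rectangular coordinates via cos/sin, matching the mu*cos(e) / mu*sin(e) mapping later in the file:

\lambda_j = \sqrt{r_j^2 + p_j^2}, \qquad \theta_j = \operatorname{atan2}(p_j, r_j)
\psi_k = \sum_j \alpha_{k,j}\,\theta_j + \delta_k, \qquad l_k = \exp\!\Big(\sum_j \alpha_{k,j}\,\log\lambda_j + \delta'_k\Big)
r^{\mathrm{out}}_k = l_k \cos\psi_k, \qquad p^{\mathrm{out}}_k = l_k \sin\psi_k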
@@ -153,7 +162,7 @@ class EulerInteractionLayerPaper(nn.Module):
         return r_out, p_out


-class
+class ComplexSpaceMapping(nn.Module):
     """
     Map real embeddings e_j to complex features via Euler's formula (Eq.6-7).
     For each field j:
@@ -174,63 +183,6 @@ class ComplexSpaceMappingPaper(nn.Module):
         r = mu * torch.cos(e)
         p = mu * torch.sin(e)
         return r, p
-
-class EulerNetPaper(nn.Module):
-    """
-    Paper-aligned EulerNet core (embedding -> mapping -> L Euler layers -> linear regression).
-    """
-
-    def __init__(
-        self,
-        *,
-        embedding_dim: int,
-        num_fields: int,
-        num_layers: int = 2,
-        num_orders: int = 8,  # n in paper
-        use_implicit: bool = True,
-        norm: str | None = "ln",  # None | "bn" | "ln"
-    ):
-        super().__init__()
-        self.d = embedding_dim
-        self.m = num_fields
-        self.L = num_layers
-        self.n = num_orders
-
-        self.mapping = ComplexSpaceMappingPaper(embedding_dim, num_fields)
-
-        self.layers = nn.ModuleList([
-            EulerInteractionLayerPaper(
-                embedding_dim=embedding_dim,
-                num_fields=(num_fields if i == 0 else num_orders),  # stack: m -> n -> n ...
-                num_orders=num_orders,
-                use_implicit=use_implicit,
-                norm=norm,
-            )
-            for i in range(num_layers)
-        ])
-
-        # Output regression (Eq.16-17)
-        # After last layer: r,p are [B,n,d]. Concatenate to [B, n*d] each, then regress.
-        self.w = nn.Linear(self.n * self.d, 1, bias=False)  # for real
-        self.w_im = nn.Linear(self.n * self.d, 1, bias=False)  # for imag
-
-    def forward(self, field_emb: torch.Tensor) -> torch.Tensor:
-        """
-        field_emb: [B, m, d] real embeddings e_j
-        return: logits, shape [B,1]
-        """
-        r, p = self.mapping(field_emb)  # [B,m,d]
-
-        # stack Euler interaction layers
-        for layer in self.layers:
-            r, p = layer(r, p)  # -> [B,n,d]
-
-        r_flat = r.reshape(r.size(0), self.n * self.d)
-        p_flat = p.reshape(p.size(0), self.n * self.d)
-
-        z_re = self.w(r_flat)
-        z_im = self.w_im(p_flat)
-        return z_re + z_im  # Eq.17 logits


 class EulerNet(BaseModel):
@@ -313,14 +265,23 @@ class EulerNet(BaseModel):
                 "All interaction features must share the same embedding_dim in EulerNet."
             )

-        self.
-
-
-
-
-
-
+        self.num_layers = num_layers
+        self.num_orders = num_orders
+        self.mapping = ComplexSpaceMapping(self.embedding_dim, self.num_fields)
+        self.layers = nn.ModuleList(
+            [
+                EulerInteractionLayer(
+                    embedding_dim=self.embedding_dim,
+                    num_fields=(self.num_fields if i == 0 else self.num_orders),
+                    num_orders=self.num_orders,
+                    use_implicit=use_implicit,
+                    norm=norm,
+                )
+                for i in range(self.num_layers)
+            ]
         )
+        self.w = nn.Linear(self.num_orders * self.embedding_dim, 1, bias=False)
+        self.w_im = nn.Linear(self.num_orders * self.embedding_dim, 1, bias=False)

         if self.use_linear:
             if len(self.linear_features) == 0:
@@ -336,7 +297,7 @@ class EulerNet(BaseModel):

         self.prediction_layer = PredictionLayer(task_type=self.task)

-        modules = ["
+        modules = ["mapping", "layers", "w", "w_im"]
         if self.use_linear:
             modules.append("linear")
         self.register_regularization_weights(
@@ -354,7 +315,7 @@ class EulerNet(BaseModel):
         field_emb = self.embedding(
             x=x, features=self.interaction_features, squeeze_dim=False
         )
-        y_euler = self.
+        y_euler = self.euler_forward(field_emb)

         if self.use_linear and self.linear is not None:
             linear_input = self.embedding(
@@ -363,3 +324,11 @@ class EulerNet(BaseModel):
             y_euler = y_euler + self.linear(linear_input)

         return self.prediction_layer(y_euler)
+
+    def euler_forward(self, field_emb: torch.Tensor) -> torch.Tensor:
+        r, p = self.mapping(field_emb)
+        for layer in self.layers:
+            r, p = layer(r, p)
+        r_flat = r.reshape(r.size(0), self.num_orders * self.embedding_dim)
+        p_flat = p.reshape(p.size(0), self.num_orders * self.embedding_dim)
+        return self.w(r_flat) + self.w_im(p_flat)
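The new euler_forward method keeps the output regression of the removed EulerNetPaper class (the "Eq.16-17" comments above): the final r and p, each of shape [B, n, d], are flattened to n*d vectors and combined by two bias-free linear maps, i.e.

z = \mathbf{w}_{\mathrm{re}}^{\top}\,\operatorname{vec}(r) + \mathbf{w}_{\mathrm{im}}^{\top}\,\operatorname{vec}(p)

which yields the [B, 1] logit that the prediction layer then transforms.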