nextrec 0.2.6__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +4 -8
  3. nextrec/basic/callback.py +1 -1
  4. nextrec/basic/features.py +33 -25
  5. nextrec/basic/layers.py +164 -601
  6. nextrec/basic/loggers.py +3 -4
  7. nextrec/basic/metrics.py +39 -115
  8. nextrec/basic/model.py +248 -174
  9. nextrec/basic/session.py +1 -5
  10. nextrec/data/__init__.py +12 -0
  11. nextrec/data/data_utils.py +3 -27
  12. nextrec/data/dataloader.py +26 -34
  13. nextrec/data/preprocessor.py +2 -1
  14. nextrec/loss/listwise.py +6 -4
  15. nextrec/loss/loss_utils.py +10 -6
  16. nextrec/loss/pairwise.py +5 -3
  17. nextrec/loss/pointwise.py +7 -13
  18. nextrec/models/match/mind.py +110 -1
  19. nextrec/models/multi_task/esmm.py +46 -27
  20. nextrec/models/multi_task/mmoe.py +48 -30
  21. nextrec/models/multi_task/ple.py +156 -141
  22. nextrec/models/multi_task/poso.py +413 -0
  23. nextrec/models/multi_task/share_bottom.py +43 -26
  24. nextrec/models/ranking/__init__.py +2 -0
  25. nextrec/models/ranking/autoint.py +1 -1
  26. nextrec/models/ranking/dcn.py +20 -1
  27. nextrec/models/ranking/dcn_v2.py +84 -0
  28. nextrec/models/ranking/deepfm.py +44 -18
  29. nextrec/models/ranking/dien.py +130 -27
  30. nextrec/models/ranking/masknet.py +13 -67
  31. nextrec/models/ranking/widedeep.py +39 -18
  32. nextrec/models/ranking/xdeepfm.py +34 -1
  33. nextrec/utils/common.py +26 -1
  34. nextrec-0.3.1.dist-info/METADATA +306 -0
  35. nextrec-0.3.1.dist-info/RECORD +56 -0
  36. {nextrec-0.2.6.dist-info → nextrec-0.3.1.dist-info}/WHEEL +1 -1
  37. nextrec-0.2.6.dist-info/METADATA +0 -281
  38. nextrec-0.2.6.dist-info/RECORD +0 -54
  39. {nextrec-0.2.6.dist-info → nextrec-0.3.1.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/layers.py CHANGED
@@ -1,24 +1,22 @@
 """
 Layer implementations used across NextRec models.
 
-Date: create on 27/10/2025, update on 19/11/2025
-Author: Yang Zhou,zyaztec@gmail.com
+Date: create on 27/10/2025
+Checkpoint: edit on 29/11/2025
+Author: Yang Zhou, zyaztec@gmail.com
 """
-
 from __future__ import annotations
 
 from itertools import combinations
-from typing import Iterable, Sequence, Union
+from collections import OrderedDict
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from nextrec.basic.activation import activation_layer
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 from nextrec.utils.initializer import get_initializer
-
-Feature = Union[DenseFeature, SparseFeature, SequenceFeature]
+from nextrec.basic.activation import activation_layer
 
 __all__ = [
     "PredictionLayer",
@@ -30,57 +28,38 @@ __all__ = [
     "SumPooling",
     "MLP",
     "FM",
-    "FFM",
-    "CEN",
-    "CIN",
     "CrossLayer",
-    "CrossNetwork",
-    "CrossNetV2",
-    "CrossNetMix",
     "SENETLayer",
     "BiLinearInteractionLayer",
-    "MultiInterestSA",
-    "CapsuleNetwork",
     "MultiHeadSelfAttention",
     "AttentionPoolingLayer",
-    "DynamicGRU",
-    "AUGRU",
 ]
 
-
 class PredictionLayer(nn.Module):
     def __init__(
         self,
-        task_type: Union[str, Sequence[str]] = "binary",
-        task_dims: Union[int, Sequence[int], None] = None,
+        task_type: str | list[str] = "binary",
+        task_dims: int | list[int] | None = None,
         use_bias: bool = True,
         return_logits: bool = False,
     ):
         super().__init__()
-
        if isinstance(task_type, str):
             self.task_types = [task_type]
         else:
             self.task_types = list(task_type)
-
         if len(self.task_types) == 0:
             raise ValueError("At least one task_type must be specified.")
-
         if task_dims is None:
             dims = [1] * len(self.task_types)
         elif isinstance(task_dims, int):
             dims = [task_dims]
         else:
             dims = list(task_dims)
-
         if len(dims) not in (1, len(self.task_types)):
-            raise ValueError(
-                "task_dims must be None, a single int (shared), or a sequence of the same length as task_type."
-            )
-
+            raise ValueError("[PredictionLayer Error]: task_dims must be None, a single int (shared), or a sequence of the same length as task_type.")
         if len(dims) == 1 and len(self.task_types) > 1:
             dims = dims * len(self.task_types)
-
         self.task_dims = dims
         self.total_dim = sum(self.task_dims)
         self.return_logits = return_logits
@@ -93,7 +72,6 @@ class PredictionLayer(nn.Module):
                 raise ValueError("Each task dimension must be >= 1.")
             self._task_slices.append((start, start + dim))
             start += dim
-
         if use_bias:
             self.bias = nn.Parameter(torch.zeros(self.total_dim))
         else:
@@ -101,25 +79,18 @@ class PredictionLayer(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if x.dim() == 1:
-            x = x.unsqueeze(-1)
-
+            x = x.unsqueeze(0) # (1 * total_dim)
         if x.shape[-1] != self.total_dim:
-            raise ValueError(
-                f"Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim})."
-            )
-
+            raise ValueError(f"[PredictionLayer Error]: Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim}).")
         logits = x if self.bias is None else x + self.bias
-        outputs: list[torch.Tensor] = []
-
+        outputs = []
         for task_type, (start, end) in zip(self.task_types, self._task_slices):
-            task_logits = logits[..., start:end]
+            task_logits = logits[..., start:end] # Extract logits for the current task
             if self.return_logits:
                 outputs.append(task_logits)
                 continue
-
             activation = self._get_activation(task_type)
             outputs.append(activation(task_logits))
-
         result = torch.cat(outputs, dim=-1)
         if result.shape[-1] == 1:
             result = result.squeeze(-1)
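As a quick orientation for the refactored multi-task head above, here is a minimal usage sketch of `PredictionLayer` (illustrative only, not part of the package diff; the batch size and task mix are assumptions):

```python
import torch

from nextrec.basic.layers import PredictionLayer

# Two heads sharing one logit tensor: a binary CTR task (1 logit) and a
# 3-way classification task (3 logits), so total_dim = 1 + 3 = 4.
head = PredictionLayer(task_type=["binary", "multiclass"], task_dims=[1, 3])

logits = torch.randn(8, 4)   # (batch_size, total_dim)
probs = head(logits)         # sigmoid on the first slice, softmax on the last three
print(probs.shape)           # torch.Size([8, 4])

# return_logits=True skips the per-task activations and returns the raw slices.
raw_head = PredictionLayer(task_type="binary", return_logits=True)
raw = raw_head(torch.randn(8, 1))  # squeezed to shape (8,)
```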
@@ -127,17 +98,16 @@ class PredictionLayer(nn.Module):
 
     def _get_activation(self, task_type: str):
         task = task_type.lower()
-        if task in ['binary','multiclass']:
+        if task == 'binary':
             return torch.sigmoid
-        if task in ['regression']:
+        if task == 'regression':
             return lambda x: x
-        if task in ['multiclass']:
+        if task == 'multiclass':
             return lambda x: torch.softmax(x, dim=-1)
-        raise ValueError(f"Unsupported task_type '{task_type}'.")
-
+        raise ValueError(f"[PredictionLayer Error]: Unsupported task_type '{task_type}'.")
 
 class EmbeddingLayer(nn.Module):
-    def __init__(self, features: Sequence[Feature]):
+    def __init__(self, features: list):
         super().__init__()
         self.features = list(features)
         self.embed_dict = nn.ModuleDict()
@@ -148,23 +118,22 @@ class EmbeddingLayer(nn.Module):
             if isinstance(feature, (SparseFeature, SequenceFeature)):
                 if feature.embedding_name in self.embed_dict:
                     continue
-
-                embedding = nn.Embedding(
-                    num_embeddings=feature.vocab_size,
-                    embedding_dim=feature.embedding_dim,
-                    padding_idx=feature.padding_idx,
-                )
-                embedding.weight.requires_grad = feature.trainable
-
-                initialization = get_initializer(
-                    init_type=feature.init_type,
-                    activation="linear",
-                    param=feature.init_params,
-                )
-                initialization(embedding.weight)
+                if getattr(feature, "pretrained_weight", None) is not None:
+                    weight = feature.pretrained_weight # type: ignore[assignment]
+                    if weight.shape != (feature.vocab_size, feature.embedding_dim): # type: ignore[assignment]
+                        raise ValueError(f"[EmbeddingLayer Error]: Pretrained weight for '{feature.embedding_name}' has shape {weight.shape}, expected ({feature.vocab_size}, {feature.embedding_dim}).") # type: ignore[assignment]
+                    embedding = nn.Embedding.from_pretrained(embeddings=weight, freeze=feature.freeze_pretrained, padding_idx=feature.padding_idx) # type: ignore[assignment]
+                    embedding.weight.requires_grad = feature.trainable and not feature.freeze_pretrained # type: ignore[assignment]
+                else:
+                    embedding = nn.Embedding(num_embeddings=feature.vocab_size, embedding_dim=feature.embedding_dim, padding_idx=feature.padding_idx)
+                    embedding.weight.requires_grad = feature.trainable
+                    initialization = get_initializer(init_type=feature.init_type, activation="linear", param=feature.init_params)
+                    initialization(embedding.weight)
                 self.embed_dict[feature.embedding_name] = embedding
-
             elif isinstance(feature, DenseFeature):
+                if not feature.use_embedding:
+                    self.dense_input_dims[feature.name] = max(int(getattr(feature, "input_dim", 1)), 1)
+                    continue
                 if feature.name in self.dense_transforms:
                     continue
                 in_dim = max(int(getattr(feature, "input_dim", 1)), 1)
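The new pretrained-weight branch above ultimately delegates to `torch.nn.Embedding.from_pretrained`. A standalone sketch of what that branch does (the weight matrix and padding index below are made-up values, not taken from the package):

```python
import torch
import torch.nn as nn

vocab_size, embedding_dim = 10, 4
pretrained = torch.randn(vocab_size, embedding_dim)  # e.g. exported from another model

# freeze=True keeps the table fixed, mirroring feature.freeze_pretrained=True.
frozen = nn.Embedding.from_pretrained(pretrained, freeze=True, padding_idx=0)
print(frozen.weight.requires_grad)  # False

# With freeze=False the table stays trainable, as when feature.trainable is True
# and feature.freeze_pretrained is False in the code above.
trainable = nn.Embedding.from_pretrained(pretrained, freeze=False, padding_idx=0)
ids = torch.tensor([[1, 2, 0]])
print(trainable(ids).shape)  # torch.Size([1, 3, 4])
```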
@@ -174,15 +143,14 @@ class EmbeddingLayer(nn.Module):
                 nn.init.zeros_(dense_linear.bias)
                 self.dense_transforms[feature.name] = dense_linear
                 self.dense_input_dims[feature.name] = in_dim
-
             else:
-                raise TypeError(f"Unsupported feature type: {type(feature)}")
+                raise TypeError(f"[EmbeddingLayer Error]: Unsupported feature type: {type(feature)}")
         self.output_dim = self._compute_output_dim()
 
     def forward(
         self,
         x: dict[str, torch.Tensor],
-        features: Sequence[Feature],
+        features: list[object],
         squeeze_dim: bool = False,
     ) -> torch.Tensor:
         sparse_embeds: list[torch.Tensor] = []
@@ -208,8 +176,7 @@ class EmbeddingLayer(nn.Module):
                 elif feature.combiner == "concat":
                     pooling_layer = ConcatPooling()
                 else:
-                    raise ValueError(f"Unknown combiner for {feature.name}: {feature.combiner}")
-
+                    raise ValueError(f"[EmbeddingLayer Error]: Unknown combiner for {feature.name}: {feature.combiner}")
                 feature_mask = InputMask()(x, feature, seq_input)
                 sparse_embeds.append(pooling_layer(seq_emb, feature_mask).unsqueeze(1))
 
@@ -223,107 +190,116 @@ class EmbeddingLayer(nn.Module):
                 pieces.append(torch.cat(flattened_sparse, dim=1))
             if dense_embeds:
                 pieces.append(torch.cat(dense_embeds, dim=1))
-
             if not pieces:
-                raise ValueError("No input features found for EmbeddingLayer.")
-
+                raise ValueError("[EmbeddingLayer Error]: No input features found for EmbeddingLayer.")
             return pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=1)
-
+
         # squeeze_dim=False requires embeddings with identical last dimension
         output_embeddings = list(sparse_embeds)
         if dense_embeds:
-            target_dim = None
             if output_embeddings:
                 target_dim = output_embeddings[0].shape[-1]
-            elif len({emb.shape[-1] for emb in dense_embeds}) == 1:
-                target_dim = dense_embeds[0].shape[-1]
-
-            if target_dim is not None:
-                aligned_dense = [
-                    emb.unsqueeze(1) for emb in dense_embeds if emb.shape[-1] == target_dim
-                ]
-                output_embeddings.extend(aligned_dense)
-
+                for emb in dense_embeds:
+                    if emb.shape[-1] != target_dim:
+                        raise ValueError(f"[EmbeddingLayer Error]: squeeze_dim=False requires all dense feature dimensions to match the embedding dimension of sparse/sequence features ({target_dim}), but got {emb.shape[-1]}.")
+                output_embeddings.extend(emb.unsqueeze(1) for emb in dense_embeds)
+            else:
+                dims = {emb.shape[-1] for emb in dense_embeds}
+                if len(dims) != 1:
+                    raise ValueError(f"[EmbeddingLayer Error]: squeeze_dim=False requires all dense features to have identical dimensions when no sparse/sequence features are present, but got dimensions {dims}.")
+                output_embeddings = [emb.unsqueeze(1) for emb in dense_embeds]
         if not output_embeddings:
-            raise ValueError(
-                "squeeze_dim=False requires at least one sparse/sequence feature or "
-                "dense features with identical projected dimensions."
-            )
-
+            raise ValueError("[EmbeddingLayer Error]: squeeze_dim=False requires at least one sparse/sequence feature or dense features with identical projected dimensions.")
         return torch.cat(output_embeddings, dim=1)
 
     def _project_dense(self, feature: DenseFeature, x: dict[str, torch.Tensor]) -> torch.Tensor:
         if feature.name not in x:
-            raise KeyError(f"Dense feature '{feature.name}' is missing from input.")
-
+            raise KeyError(f"[EmbeddingLayer Error]:Dense feature '{feature.name}' is missing from input.")
         value = x[feature.name].float()
         if value.dim() == 1:
             value = value.unsqueeze(-1)
         else:
             value = value.view(value.size(0), -1)
-
-        dense_layer = self.dense_transforms[feature.name]
-        expected_in_dim = self.dense_input_dims[feature.name]
+        expected_in_dim = self.dense_input_dims.get(feature.name, max(int(getattr(feature, "input_dim", 1)), 1))
         if value.shape[1] != expected_in_dim:
-            raise ValueError(
-                f"Dense feature '{feature.name}' expects {expected_in_dim} inputs but "
-                f"got {value.shape[1]}."
-            )
-
+            raise ValueError(f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {expected_in_dim} inputs but got {value.shape[1]}.")
+        if not feature.use_embedding:
+            return value
+        dense_layer = self.dense_transforms[feature.name]
         return dense_layer(value)
 
-    def _compute_output_dim(self):
-        return
+    def _compute_output_dim(self, features: list[DenseFeature | SequenceFeature | SparseFeature] | None = None) -> int:
+        """
+        Compute flattened embedding dimension for provided features or all tracked features.
+        Deduplicates by feature name to avoid double-counting shared embeddings.
+        """
+        candidates = list(features) if features is not None else self.features
+        unique_feats = OrderedDict((feat.name, feat) for feat in candidates) # type: ignore[assignment]
+        dim = 0
+        for feat in unique_feats.values():
+            if isinstance(feat, DenseFeature):
+                in_dim = max(int(getattr(feat, "input_dim", 1)), 1)
+                emb_dim = getattr(feat, "embedding_dim", None)
+                out_dim = max(int(emb_dim), 1) if emb_dim else in_dim
+                dim += out_dim
+            elif isinstance(feat, SequenceFeature) and feat.combiner == "concat":
+                dim += feat.embedding_dim * feat.max_len
+            else:
+                dim += feat.embedding_dim # type: ignore[assignment]
+        return dim
+
+    def get_input_dim(self, features: list[object] | None = None) -> int:
+        return self._compute_output_dim(features) # type: ignore[assignment]
+
+    @property
+    def input_dim(self) -> int:
+        return self.output_dim
 
 class InputMask(nn.Module):
     """Utility module to build sequence masks for pooling layers."""
-
     def __init__(self):
         super().__init__()
 
-    def forward(self, x, fea, seq_tensor=None):
-        values = seq_tensor if seq_tensor is not None else x[fea.name]
-        if fea.padding_idx is not None:
-            mask = (values.long() != fea.padding_idx)
+    def forward(self, x: dict[str, torch.Tensor], feature: SequenceFeature, seq_tensor: torch.Tensor | None = None):
+        values = seq_tensor if seq_tensor is not None else x[feature.name]
+        if feature.padding_idx is not None:
+            mask = (values.long() != feature.padding_idx)
         else:
             mask = (values.long() != 0)
         if mask.dim() == 1:
             mask = mask.unsqueeze(-1)
         return mask.unsqueeze(1).float()
 
-
 class LR(nn.Module):
     """Wide component from Wide&Deep (Cheng et al., 2016)."""
-
-    def __init__(self, input_dim, sigmoid=False):
+    def __init__(
        self,
        input_dim: int,
        sigmoid: bool = False):
         super().__init__()
         self.sigmoid = sigmoid
         self.fc = nn.Linear(input_dim, 1, bias=True)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.sigmoid:
             return torch.sigmoid(self.fc(x))
         else:
             return self.fc(x)
 
-
 class ConcatPooling(nn.Module):
     """Concatenates sequence embeddings along the temporal dimension."""
-
     def __init__(self):
         super().__init__()
 
-    def forward(self, x, mask=None):
+    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
         return x.flatten(start_dim=1, end_dim=2)
 
-
 class AveragePooling(nn.Module):
     """Mean pooling with optional padding mask."""
-
     def __init__(self):
         super().__init__()
 
-    def forward(self, x, mask=None):
+    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
         if mask is None:
             return torch.mean(x, dim=1)
         else:
@@ -331,24 +307,26 @@ class AveragePooling(nn.Module):
             non_padding_length = mask.sum(dim=-1)
             return sum_pooling_matrix / (non_padding_length.float() + 1e-16)
 
-
 class SumPooling(nn.Module):
     """Sum pooling with optional padding mask."""
-
     def __init__(self):
         super().__init__()
 
-    def forward(self, x, mask=None):
+    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
         if mask is None:
             return torch.sum(x, dim=1)
         else:
             return torch.bmm(mask, x).squeeze(1)
 
-
 class MLP(nn.Module):
     """Stacked fully connected layers used in the deep component."""
-
-    def __init__(self, input_dim, output_layer=True, dims=None, dropout=0, activation="relu"):
+    def __init__(
        self,
        input_dim: int,
        output_layer: bool = True,
        dims: list[int] | None = None,
        dropout: float = 0.0,
        activation: str = "relu"):
         super().__init__()
         if dims is None:
             dims = []
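The pooling layers above consume the `(batch, 1, seq_len)` float mask produced by `InputMask`. A small sketch of how `SumPooling` and `AveragePooling` use it (the tensor values are invented for illustration, not part of the diff):

```python
import torch

from nextrec.basic.layers import AveragePooling, SumPooling

seq_emb = torch.randn(2, 3, 4)            # (batch, seq_len, embed_dim)
mask = torch.tensor([[[1., 1., 0.]],      # first sequence: 2 valid steps
                     [[1., 0., 0.]]])     # second sequence: 1 valid step

summed = SumPooling()(seq_emb, mask)      # bmm((2,1,3), (2,3,4)) -> (2, 4)
averaged = AveragePooling()(seq_emb, mask)  # masked sum divided by valid-step count
print(summed.shape, averaged.shape)       # torch.Size([2, 4]) torch.Size([2, 4])
```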
@@ -366,15 +344,13 @@ class MLP(nn.Module):
     def forward(self, x):
         return self.mlp(x)
 
-
 class FM(nn.Module):
     """Factorization Machine (Rendle, 2010) second-order interaction term."""
-
-    def __init__(self, reduce_sum=True):
+    def __init__(self, reduce_sum: bool = True):
         super().__init__()
         self.reduce_sum = reduce_sum
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         square_of_sum = torch.sum(x, dim=1)**2
         sum_of_square = torch.sum(x**2, dim=1)
         ix = square_of_sum - sum_of_square
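The retained `FM` block computes the classic second-order identity sum_{i<j} <e_i, e_j> = 0.5 * ((sum_i e_i)^2 - sum_i e_i^2), summed over the embedding dimension when `reduce_sum=True`. A quick self-check against the brute-force pairwise sum (shapes here are arbitrary, for illustration only):

```python
import torch
from itertools import combinations

from nextrec.basic.layers import FM

emb = torch.randn(2, 5, 8)  # (batch, num_fields, embed_dim)

# Brute force: sum of inner products over all field pairs i < j.
pairwise = sum((emb[:, i] * emb[:, j]).sum(dim=-1) for i, j in combinations(range(5), 2))

fm_term = FM(reduce_sum=True)(emb).squeeze(-1)  # (batch,)
print(torch.allclose(fm_term, pairwise, atol=1e-5))  # True
```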
@@ -382,157 +358,30 @@ class FM(nn.Module):
             ix = torch.sum(ix, dim=1, keepdim=True)
         return 0.5 * ix
 
-
-class CIN(nn.Module):
-    """Compressed Interaction Network from xDeepFM (Lian et al., 2018)."""
-
-    def __init__(self, input_dim, cin_size, split_half=True):
-        super().__init__()
-        self.num_layers = len(cin_size)
-        self.split_half = split_half
-        self.conv_layers = torch.nn.ModuleList()
-        prev_dim, fc_input_dim = input_dim, 0
-        for i in range(self.num_layers):
-            cross_layer_size = cin_size[i]
-            self.conv_layers.append(torch.nn.Conv1d(input_dim * prev_dim, cross_layer_size, 1, stride=1, dilation=1, bias=True))
-            if self.split_half and i != self.num_layers - 1:
-                cross_layer_size //= 2
-            prev_dim = cross_layer_size
-            fc_input_dim += prev_dim
-        self.fc = torch.nn.Linear(fc_input_dim, 1)
-
-    def forward(self, x):
-        xs = list()
-        x0, h = x.unsqueeze(2), x
-        for i in range(self.num_layers):
-            x = x0 * h.unsqueeze(1)
-            batch_size, f0_dim, fin_dim, embed_dim = x.shape
-            x = x.view(batch_size, f0_dim * fin_dim, embed_dim)
-            x = F.relu(self.conv_layers[i](x))
-            if self.split_half and i != self.num_layers - 1:
-                x, h = torch.split(x, x.shape[1] // 2, dim=1)
-            else:
-                h = x
-            xs.append(x)
-        return self.fc(torch.sum(torch.cat(xs, dim=1), 2))
-
 class CrossLayer(nn.Module):
     """Single cross layer used in DCN (Wang et al., 2017)."""
-
-    def __init__(self, input_dim):
+    def __init__(self, input_dim: int):
         super(CrossLayer, self).__init__()
         self.w = torch.nn.Linear(input_dim, 1, bias=False)
         self.b = torch.nn.Parameter(torch.zeros(input_dim))
 
-    def forward(self, x_0, x_i):
+    def forward(self, x_0: torch.Tensor, x_i: torch.Tensor) -> torch.Tensor:
         x = self.w(x_i) * x_0 + self.b
         return x
 
-
-class CrossNetwork(nn.Module):
-    """Stacked Cross Layers from DCN (Wang et al., 2017)."""
-
-    def __init__(self, input_dim, num_layers):
-        super().__init__()
-        self.num_layers = num_layers
-        self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers)])
-        self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])
-
-    def forward(self, x):
-        """
-        :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
-        """
-        x0 = x
-        for i in range(self.num_layers):
-            xw = self.w[i](x)
-            x = x0 * xw + self.b[i] + x
-        return x
-
-class CrossNetV2(nn.Module):
-    """Vector-wise cross network proposed in DCN V2 (Wang et al., 2021)."""
-    def __init__(self, input_dim, num_layers):
-        super().__init__()
-        self.num_layers = num_layers
-        self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, input_dim, bias=False) for _ in range(num_layers)])
-        self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])
-
-
-    def forward(self, x):
-        x0 = x
-        for i in range(self.num_layers):
-            x =x0*self.w[i](x) + self.b[i] + x
-        return x
-
-class CrossNetMix(nn.Module):
-    """Mixture of low-rank cross experts from DCN V2 (Wang et al., 2021)."""
-
-    def __init__(self, input_dim, num_layers=2, low_rank=32, num_experts=4):
-        super(CrossNetMix, self).__init__()
-        self.num_layers = num_layers
-        self.num_experts = num_experts
-
-        # U: (input_dim, low_rank)
-        self.u_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
-            torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
-        # V: (input_dim, low_rank)
-        self.v_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
-            torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
-        # C: (low_rank, low_rank)
-        self.c_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
-            torch.empty(num_experts, low_rank, low_rank))) for i in range(self.num_layers)])
-        self.gating = nn.ModuleList([nn.Linear(input_dim, 1, bias=False) for i in range(self.num_experts)])
-
-        self.bias = torch.nn.ParameterList([nn.Parameter(nn.init.zeros_(
-            torch.empty(input_dim, 1))) for i in range(self.num_layers)])
-
-    def forward(self, x):
-        x_0 = x.unsqueeze(2) # (bs, in_features, 1)
-        x_l = x_0
-        for i in range(self.num_layers):
-            output_of_experts = []
-            gating_score_experts = []
-            for expert_id in range(self.num_experts):
-                # (1) G(x_l)
-                # compute the gating score by x_l
-                gating_score_experts.append(self.gating[expert_id](x_l.squeeze(2)))
-
-                # (2) E(x_l)
-                # project the input x_l to $\mathbb{R}^{r}$
-                v_x = torch.matmul(self.v_list[i][expert_id].t(), x_l) # (bs, low_rank, 1)
-
-                # nonlinear activation in low rank space
-                v_x = torch.tanh(v_x)
-                v_x = torch.matmul(self.c_list[i][expert_id], v_x)
-                v_x = torch.tanh(v_x)
-
-                # project back to $\mathbb{R}^{d}$
-                uv_x = torch.matmul(self.u_list[i][expert_id], v_x) # (bs, in_features, 1)
-
-                dot_ = uv_x + self.bias[i]
-                dot_ = x_0 * dot_ # Hadamard-product
-
-                output_of_experts.append(dot_.squeeze(2))
-
-            # (3) mixture of low-rank experts
-            output_of_experts = torch.stack(output_of_experts, 2) # (bs, in_features, num_experts)
-            gating_score_experts = torch.stack(gating_score_experts, 1) # (bs, num_experts, 1)
-            moe_out = torch.matmul(output_of_experts, gating_score_experts.softmax(1))
-            x_l = moe_out + x_l # (bs, in_features, 1)
-
-        x_l = x_l.squeeze() # (bs, in_features)
-        return x_l
-
 class SENETLayer(nn.Module):
     """Squeeze-and-Excitation block adopted by FiBiNET (Huang et al., 2019)."""
-
-    def __init__(self, num_fields, reduction_ratio=3):
+    def __init__(
        self,
        num_fields: int,
        reduction_ratio: int = 3):
         super(SENETLayer, self).__init__()
         reduced_size = max(1, int(num_fields/ reduction_ratio))
         self.mlp = nn.Sequential(nn.Linear(num_fields, reduced_size, bias=False),
                                  nn.ReLU(),
                                  nn.Linear(reduced_size, num_fields, bias=False),
                                  nn.ReLU())
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         z = torch.mean(x, dim=-1, out=None)
         a = self.mlp(z)
         v = x*a.unsqueeze(-1)
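`CrossNetwork`, `CrossNetV2`, and `CrossNetMix` are gone from this module, but the single `CrossLayer` primitive stays. A hypothetical stack built on it, showing the DCN recurrence x_{l+1} = x_0 * w_l(x_l) + b_l + x_l (the `TinyCrossNet` wrapper below is purely illustrative and not part of nextrec):

```python
import torch
import torch.nn as nn

from nextrec.basic.layers import CrossLayer

class TinyCrossNet(nn.Module):
    """Illustrative DCN-style stack of the retained CrossLayer primitive."""

    def __init__(self, input_dim: int, num_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList([CrossLayer(input_dim) for _ in range(num_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x0 = x
        for layer in self.layers:
            # CrossLayer returns w_l(x_l) * x_0 + b_l; the residual term is added here.
            x = layer(x0, x) + x
        return x

out = TinyCrossNet(input_dim=16)(torch.randn(4, 16))
print(out.shape)  # torch.Size([4, 16])
```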
@@ -540,8 +389,11 @@ class SENETLayer(nn.Module):
 
 class BiLinearInteractionLayer(nn.Module):
     """Bilinear feature interaction from FiBiNET (Huang et al., 2019)."""
-
-    def __init__(self, input_dim, num_fields, bilinear_type = "field_interaction"):
+    def __init__(
        self,
        input_dim: int,
        num_fields: int,
        bilinear_type: str = "field_interaction"):
         super(BiLinearInteractionLayer, self).__init__()
         self.bilinear_type = bilinear_type
         if self.bilinear_type == "field_all":
@@ -553,263 +405,96 @@ class BiLinearInteractionLayer(nn.Module):
         else:
             raise NotImplementedError()
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         feature_emb = torch.split(x, 1, dim=1)
         if self.bilinear_type == "field_all":
             bilinear_list = [self.bilinear_layer(v_i)*v_j for v_i, v_j in combinations(feature_emb, 2)]
         elif self.bilinear_type == "field_each":
-            bilinear_list = [self.bilinear_layer[i](feature_emb[i])*feature_emb[j] for i,j in combinations(range(len(feature_emb)), 2)]
+            bilinear_list = [self.bilinear_layer[i](feature_emb[i])*feature_emb[j] for i,j in combinations(range(len(feature_emb)), 2)] # type: ignore[assignment]
         elif self.bilinear_type == "field_interaction":
-            bilinear_list = [self.bilinear_layer[i](v[0])*v[1] for i,v in enumerate(combinations(feature_emb, 2))]
+            bilinear_list = [self.bilinear_layer[i](v[0])*v[1] for i,v in enumerate(combinations(feature_emb, 2))] # type: ignore[assignment]
         return torch.cat(bilinear_list, dim=1)
 
-
-class MultiInterestSA(nn.Module):
-    """Multi-interest self-attention extractor from MIND (Li et al., 2019)."""
-
-    def __init__(self, embedding_dim, interest_num, hidden_dim=None):
-        super(MultiInterestSA, self).__init__()
-        self.embedding_dim = embedding_dim
-        self.interest_num = interest_num
-        if hidden_dim == None:
-            self.hidden_dim = self.embedding_dim * 4
-        self.W1 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True)
-        self.W2 = torch.nn.Parameter(torch.rand(self.hidden_dim, self.interest_num), requires_grad=True)
-        self.W3 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True)
-
-    def forward(self, seq_emb, mask=None):
-        H = torch.einsum('bse, ed -> bsd', seq_emb, self.W1).tanh()
-        if mask != None:
-            A = torch.einsum('bsd, dk -> bsk', H, self.W2) + -1.e9 * (1 - mask.float())
-            A = F.softmax(A, dim=1)
-        else:
-            A = F.softmax(torch.einsum('bsd, dk -> bsk', H, self.W2), dim=1)
-        A = A.permute(0, 2, 1)
-        multi_interest_emb = torch.matmul(A, seq_emb)
-        return multi_interest_emb
-
-
-class CapsuleNetwork(nn.Module):
-    """Dynamic routing capsule network used in MIND (Li et al., 2019)."""
-
-    def __init__(self, embedding_dim, seq_len, bilinear_type=2, interest_num=4, routing_times=3, relu_layer=False):
-        super(CapsuleNetwork, self).__init__()
-        self.embedding_dim = embedding_dim # h
-        self.seq_len = seq_len # s
-        self.bilinear_type = bilinear_type
-        self.interest_num = interest_num
-        self.routing_times = routing_times
-
-        self.relu_layer = relu_layer
-        self.stop_grad = True
-        self.relu = nn.Sequential(nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU())
-        if self.bilinear_type == 0: # MIND
-            self.linear = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
-        elif self.bilinear_type == 1:
-            self.linear = nn.Linear(self.embedding_dim, self.embedding_dim * self.interest_num, bias=False)
-        else:
-            self.w = nn.Parameter(torch.Tensor(1, self.seq_len, self.interest_num * self.embedding_dim, self.embedding_dim))
-            nn.init.xavier_uniform_(self.w)
-
-    def forward(self, item_eb, mask):
-        if self.bilinear_type == 0:
-            item_eb_hat = self.linear(item_eb)
-            item_eb_hat = item_eb_hat.repeat(1, 1, self.interest_num)
-        elif self.bilinear_type == 1:
-            item_eb_hat = self.linear(item_eb)
-        else:
-            u = torch.unsqueeze(item_eb, dim=2)
-            item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3)
-
-        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim))
-        item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous()
-        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim))
-
-        if self.stop_grad:
-            item_eb_hat_iter = item_eb_hat.detach()
-        else:
-            item_eb_hat_iter = item_eb_hat
-
-        if self.bilinear_type > 0:
-            capsule_weight = torch.zeros(item_eb_hat.shape[0],
-                                         self.interest_num,
-                                         self.seq_len,
-                                         device=item_eb.device,
-                                         requires_grad=False)
-        else:
-            capsule_weight = torch.randn(item_eb_hat.shape[0],
-                                         self.interest_num,
-                                         self.seq_len,
-                                         device=item_eb.device,
-                                         requires_grad=False)
-
-        for i in range(self.routing_times): # three rounds of dynamic routing propagation
-            atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)
-            paddings = torch.zeros_like(atten_mask, dtype=torch.float)
-
-            capsule_softmax_weight = F.softmax(capsule_weight, dim=-1)
-            capsule_softmax_weight = torch.where(torch.eq(atten_mask, 0), paddings, capsule_softmax_weight)
-            capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)
-
-            if i < 2:
-                interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat_iter)
-                cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
-                scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
-                interest_capsule = scalar_factor * interest_capsule
-
-                delta_weight = torch.matmul(item_eb_hat_iter, torch.transpose(interest_capsule, 2, 3).contiguous())
-                delta_weight = torch.reshape(delta_weight, (-1, self.interest_num, self.seq_len))
-                capsule_weight = capsule_weight + delta_weight
-            else:
-                interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat)
-                cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
-                scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
-                interest_capsule = scalar_factor * interest_capsule
-
-        interest_capsule = torch.reshape(interest_capsule, (-1, self.interest_num, self.embedding_dim))
-
-        if self.relu_layer:
-            interest_capsule = self.relu(interest_capsule)
-
-        return interest_capsule
-
-
-class FFM(nn.Module):
-    """Field-aware Factorization Machine (Juan et al., 2016)."""
-
-    def __init__(self, num_fields, reduce_sum=True):
-        super().__init__()
-        self.num_fields = num_fields
-        self.reduce_sum = reduce_sum
-
-    def forward(self, x):
-        # compute (non-redundant) second order field-aware feature crossings
-        crossed_embeddings = []
-        for i in range(self.num_fields-1):
-            for j in range(i+1, self.num_fields):
-                crossed_embeddings.append(x[:, i, j, :] * x[:, j, i, :])
-        crossed_embeddings = torch.stack(crossed_embeddings, dim=1)
-
-        # if reduce_sum is true, the crossing operation is effectively inner product, other wise Hadamard-product
-        if self.reduce_sum:
-            crossed_embeddings = torch.sum(crossed_embeddings, dim=-1, keepdim=True)
-        return crossed_embeddings
-
-
-class CEN(nn.Module):
-    """Field-attentive interaction network from FAT-DeepFFM (Wang et al., 2020)."""
-
-    def __init__(self, embed_dim, num_field_crosses, reduction_ratio):
-        super().__init__()
-
-        # convolution weight (Eq.7 FAT-DeepFFM)
-        self.u = torch.nn.Parameter(torch.rand(num_field_crosses, embed_dim), requires_grad=True)
-
-        # two FC layers that computes the field attention
-        self.mlp_att = MLP(num_field_crosses, dims=[num_field_crosses//reduction_ratio, num_field_crosses], output_layer=False, activation="relu")
-
-
-    def forward(self, em):
-        # compute descriptor vector (Eq.7 FAT-DeepFFM), output shape [batch_size, num_field_crosses]
-        d = F.relu((self.u.squeeze(0) * em).sum(-1))
-
-        # compute field attention (Eq.9), output shape [batch_size, num_field_crosses]
-        s = self.mlp_att(d)
-
-        # rescale original embedding with field attention (Eq.10), output shape [batch_size, num_field_crosses, embed_dim]
-        aem = s.unsqueeze(-1) * em
-        return aem.flatten(start_dim=1)
-
-
 class MultiHeadSelfAttention(nn.Module):
     """Multi-head self-attention layer from AutoInt (Song et al., 2019)."""
-
-    def __init__(self, embedding_dim, num_heads=2, dropout=0.0, use_residual=True):
+    def __init__(
        self,
        embedding_dim: int,
        num_heads: int = 2,
        dropout: float = 0.0,
        use_residual: bool = True):
         super().__init__()
         if embedding_dim % num_heads != 0:
-            raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})")
-
+            raise ValueError(f"[MultiHeadSelfAttention Error]: embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})")
         self.embedding_dim = embedding_dim
         self.num_heads = num_heads
         self.head_dim = embedding_dim // num_heads
         self.use_residual = use_residual
-
         self.W_Q = nn.Linear(embedding_dim, embedding_dim, bias=False)
         self.W_K = nn.Linear(embedding_dim, embedding_dim, bias=False)
         self.W_V = nn.Linear(embedding_dim, embedding_dim, bias=False)
-
         if self.use_residual:
             self.W_Res = nn.Linear(embedding_dim, embedding_dim, bias=False)
-
         self.dropout = nn.Dropout(dropout)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            x: [batch_size, num_fields, embedding_dim]
+            x (torch.Tensor): Tensor of shape (batch_size, num_fields, embedding_dim)
+
         Returns:
-            output: [batch_size, num_fields, embedding_dim]
+            torch.Tensor: Output tensor of shape (batch_size, num_fields, embedding_dim)
         """
         batch_size, num_fields, _ = x.shape
-
-        # Linear projections
         Q = self.W_Q(x) # [batch_size, num_fields, embedding_dim]
         K = self.W_K(x)
         V = self.W_V(x)
-
         # Split into multiple heads: [batch_size, num_heads, num_fields, head_dim]
         Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
         K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
         V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
-
         # Attention scores
         scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
         attention_weights = F.softmax(scores, dim=-1)
         attention_weights = self.dropout(attention_weights)
-
-        # Apply attention to values
         attention_output = torch.matmul(attention_weights, V) # [batch_size, num_heads, num_fields, head_dim]
-
         # Concatenate heads
         attention_output = attention_output.transpose(1, 2).contiguous()
         attention_output = attention_output.view(batch_size, num_fields, self.embedding_dim)
-
         # Residual connection
         if self.use_residual:
             output = attention_output + self.W_Res(x)
         else:
             output = attention_output
-
         output = F.relu(output)
-
         return output
 
-
 class AttentionPoolingLayer(nn.Module):
     """
     Attention pooling layer for DIN/DIEN
     Computes attention weights between query (candidate item) and keys (user behavior sequence)
     """
-
-    def __init__(self, embedding_dim, hidden_units=[80, 40], activation='sigmoid', use_softmax=True):
+    def __init__(
        self,
        embedding_dim: int,
        hidden_units: list = [80, 40],
        activation: str ='sigmoid',
        use_softmax: bool = True):
         super().__init__()
         self.embedding_dim = embedding_dim
         self.use_softmax = use_softmax
-
         # Build attention network
         # Input: [query, key, query-key, query*key] -> 4 * embedding_dim
         input_dim = 4 * embedding_dim
         layers = []
-
         for hidden_unit in hidden_units:
             layers.append(nn.Linear(input_dim, hidden_unit))
             layers.append(activation_layer(activation))
             input_dim = hidden_unit
-
         layers.append(nn.Linear(input_dim, 1))
         self.attention_net = nn.Sequential(*layers)
 
-    def forward(self, query, keys, keys_length=None, mask=None):
+    def forward(self, query: torch.Tensor, keys: torch.Tensor, keys_length: torch.Tensor | None = None, mask: torch.Tensor | None = None):
         """
         Args:
             query: [batch_size, embedding_dim] - candidate item embedding
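For the AutoInt-style `MultiHeadSelfAttention` block above, a minimal shape check (the field count and dimensions are arbitrary; `embedding_dim` must stay divisible by `num_heads`):

```python
import torch

from nextrec.basic.layers import MultiHeadSelfAttention

fields = torch.randn(32, 10, 16)  # (batch_size, num_fields, embedding_dim)
attn = MultiHeadSelfAttention(embedding_dim=16, num_heads=4, dropout=0.1)
out = attn(fields)
print(out.shape)  # torch.Size([32, 10, 16]) -- same shape, ReLU output with residual
```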
@@ -819,162 +504,40 @@ class AttentionPoolingLayer(nn.Module):
         Returns:
             output: [batch_size, embedding_dim] - attention pooled representation
         """
-        batch_size, seq_len, emb_dim = keys.shape
-
-        # Expand query to match sequence length: [batch_size, seq_len, embedding_dim]
-        query_expanded = query.unsqueeze(1).expand(-1, seq_len, -1)
-
-        # Compute attention features: [query, key, query-key, query*key]
-        attention_input = torch.cat([
-            query_expanded,
-            keys,
-            query_expanded - keys,
-            query_expanded * keys
-        ], dim=-1) # [batch_size, seq_len, 4*embedding_dim]
-
-        # Compute attention scores
-        attention_scores = self.attention_net(attention_input) # [batch_size, seq_len, 1]
-
-        # Apply mask if provided
+        batch_size, sequence_length, embedding_dim = keys.shape
+        assert query.shape == (batch_size, embedding_dim), f"query shape {query.shape} != ({batch_size}, {embedding_dim})"
+        if mask is None and keys_length is not None:
+            # keys_length: (batch_size,)
+            device = keys.device
+            seq_range = torch.arange(sequence_length, device=device).unsqueeze(0) # (1, sequence_length)
+            mask = (seq_range < keys_length.unsqueeze(1)).unsqueeze(-1).float()
+        if mask is not None:
+            if mask.dim() == 2:
+                # (B, L)
+                mask = mask.unsqueeze(-1)
+            elif mask.dim() == 3 and mask.shape[1] == 1 and mask.shape[2] == sequence_length:
+                # (B, 1, L) -> (B, L, 1)
+                mask = mask.transpose(1, 2)
+            elif mask.dim() == 3 and mask.shape[1] == sequence_length and mask.shape[2] == 1:
+                pass
+            else:
+                raise ValueError(f"[AttentionPoolingLayer Error]: Unsupported mask shape: {mask.shape}")
+            mask = mask.to(keys.dtype)
+        # Expand query to (B, L, D)
+        query_expanded = query.unsqueeze(1).expand(-1, sequence_length, -1)
+        # [query, key, query-key, query*key] -> (B, L, 4D)
+        attention_input = torch.cat([query_expanded, keys, query_expanded - keys, query_expanded * keys], dim=-1,)
+        attention_scores = self.attention_net(attention_input)
         if mask is not None:
             attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
-
-        # Apply softmax to get attention weights
+        # Get attention weights
         if self.use_softmax:
-            attention_weights = F.softmax(attention_scores, dim=1) # [batch_size, seq_len, 1]
+            # softmax over seq_len
+            attention_weights = F.softmax(attention_scores, dim=1) # (B, L, 1)
         else:
-            attention_weights = attention_scores
-
-        # Weighted sum of keys
-        output = torch.sum(attention_weights * keys, dim=1) # [batch_size, embedding_dim]
-
+            attention_weights = torch.sigmoid(attention_scores)
+            if mask is not None:
+                attention_weights = attention_weights * mask
+        # Weighted sum over keys: (B, L, 1) * (B, L, D) -> (B, D)
+        output = torch.sum(attention_weights * keys, dim=1)
         return output
-
-
-class DynamicGRU(nn.Module):
-    """Dynamic GRU unit with auxiliary loss path from DIEN (Zhou et al., 2019)."""
-    """
-    GRU with dynamic routing for DIEN
-    """
-
-    def __init__(self, input_size, hidden_size, bias=True):
-        super().__init__()
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-
-        # GRU parameters
-        self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size))
-        self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
-        if bias:
-            self.bias_ih = nn.Parameter(torch.randn(3 * hidden_size))
-            self.bias_hh = nn.Parameter(torch.randn(3 * hidden_size))
-        else:
-            self.register_parameter('bias_ih', None)
-            self.register_parameter('bias_hh', None)
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        std = 1.0 / (self.hidden_size) ** 0.5
-        for weight in self.parameters():
-            weight.data.uniform_(-std, std)
-
-    def forward(self, x, att_scores=None):
-        """
-        Args:
-            x: [batch_size, seq_len, input_size]
-            att_scores: [batch_size, seq_len] - attention scores for auxiliary loss
-        Returns:
-            output: [batch_size, seq_len, hidden_size]
-            hidden: [batch_size, hidden_size] - final hidden state
-        """
-        batch_size, seq_len, _ = x.shape
-
-        # Initialize hidden state
-        h = torch.zeros(batch_size, self.hidden_size, device=x.device)
-
-        outputs = []
-        for t in range(seq_len):
-            x_t = x[:, t, :] # [batch_size, input_size]
-
-            # GRU computation
-            gi = F.linear(x_t, self.weight_ih, self.bias_ih)
-            gh = F.linear(h, self.weight_hh, self.bias_hh)
-            i_r, i_i, i_n = gi.chunk(3, 1)
-            h_r, h_i, h_n = gh.chunk(3, 1)
-
-            resetgate = torch.sigmoid(i_r + h_r)
-            inputgate = torch.sigmoid(i_i + h_i)
-            newgate = torch.tanh(i_n + resetgate * h_n)
-            h = newgate + inputgate * (h - newgate)
-
-            outputs.append(h.unsqueeze(1))
-
-        output = torch.cat(outputs, dim=1) # [batch_size, seq_len, hidden_size]
-
-        return output, h
-
-
-class AUGRU(nn.Module):
-    """Attention-aware GRU update gate used in DIEN (Zhou et al., 2019)."""
-    """
-    Attention-based GRU for DIEN
-    Uses attention scores to weight the update of hidden states
-    """
-
-    def __init__(self, input_size, hidden_size, bias=True):
-        super().__init__()
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-
-        self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size))
-        self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
-        if bias:
-            self.bias_ih = nn.Parameter(torch.randn(3 * hidden_size))
-            self.bias_hh = nn.Parameter(torch.randn(3 * hidden_size))
-        else:
-            self.register_parameter('bias_ih', None)
-            self.register_parameter('bias_hh', None)
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        std = 1.0 / (self.hidden_size) ** 0.5
-        for weight in self.parameters():
-            weight.data.uniform_(-std, std)
-
-    def forward(self, x, att_scores):
-        """
-        Args:
-            x: [batch_size, seq_len, input_size]
-            att_scores: [batch_size, seq_len, 1] - attention scores
-        Returns:
-            output: [batch_size, seq_len, hidden_size]
-            hidden: [batch_size, hidden_size] - final hidden state
-        """
-        batch_size, seq_len, _ = x.shape
-
-        h = torch.zeros(batch_size, self.hidden_size, device=x.device)
-
-        outputs = []
-        for t in range(seq_len):
-            x_t = x[:, t, :] # [batch_size, input_size]
-            att_t = att_scores[:, t, :] # [batch_size, 1]
-
-            gi = F.linear(x_t, self.weight_ih, self.bias_ih)
-            gh = F.linear(h, self.weight_hh, self.bias_hh)
-            i_r, i_i, i_n = gi.chunk(3, 1)
-            h_r, h_i, h_n = gh.chunk(3, 1)
-
-            resetgate = torch.sigmoid(i_r + h_r)
-            inputgate = torch.sigmoid(i_i + h_i)
-            newgate = torch.tanh(i_n + resetgate * h_n)
-
-            # Use attention score to control update
-            h = (1 - att_t) * h + att_t * newgate
-
-            outputs.append(h.unsqueeze(1))
-
-        output = torch.cat(outputs, dim=1)
-
-        return output, h
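Finally, a sketch of the reworked DIN-style pooling above: when `mask` is omitted, the new `keys_length` path builds a `(batch, seq_len, 1)` mask internally before the weighted sum (the tensor values below are illustrative assumptions, not part of the diff):

```python
import torch

from nextrec.basic.layers import AttentionPoolingLayer

batch, seq_len, dim = 4, 6, 8
query = torch.randn(batch, dim)            # candidate item embedding
keys = torch.randn(batch, seq_len, dim)    # user behavior sequence
keys_length = torch.tensor([6, 3, 1, 4])   # valid steps per sequence

pool = AttentionPoolingLayer(embedding_dim=dim)
pooled = pool(query, keys, keys_length=keys_length)
print(pooled.shape)  # torch.Size([4, 8])
```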