nextrec 0.1.4__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. nextrec/__init__.py +4 -4
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +9 -10
  4. nextrec/basic/callback.py +0 -1
  5. nextrec/basic/dataloader.py +127 -168
  6. nextrec/basic/features.py +27 -24
  7. nextrec/basic/layers.py +159 -328
  8. nextrec/basic/loggers.py +37 -50
  9. nextrec/basic/metrics.py +147 -255
  10. nextrec/basic/model.py +462 -817
  11. nextrec/data/__init__.py +5 -5
  12. nextrec/data/data_utils.py +12 -16
  13. nextrec/data/preprocessor.py +252 -276
  14. nextrec/loss/__init__.py +12 -12
  15. nextrec/loss/loss_utils.py +22 -30
  16. nextrec/loss/match_losses.py +83 -116
  17. nextrec/models/match/__init__.py +5 -5
  18. nextrec/models/match/dssm.py +61 -70
  19. nextrec/models/match/dssm_v2.py +51 -61
  20. nextrec/models/match/mind.py +71 -89
  21. nextrec/models/match/sdm.py +81 -93
  22. nextrec/models/match/youtube_dnn.py +53 -62
  23. nextrec/models/multi_task/esmm.py +43 -49
  24. nextrec/models/multi_task/mmoe.py +56 -65
  25. nextrec/models/multi_task/ple.py +65 -92
  26. nextrec/models/multi_task/share_bottom.py +42 -48
  27. nextrec/models/ranking/__init__.py +7 -7
  28. nextrec/models/ranking/afm.py +30 -39
  29. nextrec/models/ranking/autoint.py +57 -70
  30. nextrec/models/ranking/dcn.py +35 -43
  31. nextrec/models/ranking/deepfm.py +28 -34
  32. nextrec/models/ranking/dien.py +79 -115
  33. nextrec/models/ranking/din.py +60 -84
  34. nextrec/models/ranking/fibinet.py +35 -51
  35. nextrec/models/ranking/fm.py +26 -28
  36. nextrec/models/ranking/masknet.py +31 -31
  37. nextrec/models/ranking/pnn.py +31 -30
  38. nextrec/models/ranking/widedeep.py +31 -36
  39. nextrec/models/ranking/xdeepfm.py +39 -46
  40. nextrec/utils/__init__.py +9 -9
  41. nextrec/utils/embedding.py +1 -1
  42. nextrec/utils/initializer.py +15 -23
  43. nextrec/utils/optimizer.py +10 -14
  44. {nextrec-0.1.4.dist-info → nextrec-0.1.7.dist-info}/METADATA +16 -7
  45. nextrec-0.1.7.dist-info/RECORD +51 -0
  46. nextrec-0.1.4.dist-info/RECORD +0 -51
  47. {nextrec-0.1.4.dist-info → nextrec-0.1.7.dist-info}/WHEEL +0 -0
  48. {nextrec-0.1.4.dist-info → nextrec-0.1.7.dist-info}/licenses/LICENSE +0 -0
@@ -6,7 +6,6 @@ Reference:
6
6
  [1] Covington P, Adams J, Sargin E. Deep neural networks for youtube recommendations[C]
7
7
  //Proceedings of the 10th ACM conference on recommender systems. 2016: 191-198.
8
8
  """
9
-
10
9
  import torch
11
10
  import torch.nn as nn
12
11
  from typing import Literal
@@ -19,42 +18,40 @@ from nextrec.basic.layers import MLP, EmbeddingLayer, AveragePooling
19
18
  class YoutubeDNN(BaseMatchModel):
20
19
  """
21
20
  YouTube Deep Neural Network for Recommendations
22
-
21
+
23
22
  用户塔:历史行为序列 + 用户特征 -> 用户embedding
24
23
  物品塔:物品特征 -> 物品embedding
25
24
  训练:sampled softmax loss (listwise)
26
25
  """
27
-
26
+
28
27
  @property
29
28
  def model_name(self) -> str:
30
29
  return "YouTubeDNN"
31
-
32
- def __init__(
33
- self,
34
- user_dense_features: list[DenseFeature] | None = None,
35
- user_sparse_features: list[SparseFeature] | None = None,
36
- user_sequence_features: list[SequenceFeature] | None = None,
37
- item_dense_features: list[DenseFeature] | None = None,
38
- item_sparse_features: list[SparseFeature] | None = None,
39
- item_sequence_features: list[SequenceFeature] | None = None,
40
- user_dnn_hidden_units: list[int] = [256, 128, 64],
41
- item_dnn_hidden_units: list[int] = [256, 128, 64],
42
- embedding_dim: int = 64,
43
- dnn_activation: str = "relu",
44
- dnn_dropout: float = 0.0,
45
- training_mode: Literal["pointwise", "pairwise", "listwise"] = "listwise",
46
- num_negative_samples: int = 100,
47
- temperature: float = 1.0,
48
- similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
49
- device: str = "cpu",
50
- embedding_l1_reg: float = 0.0,
51
- dense_l1_reg: float = 0.0,
52
- embedding_l2_reg: float = 0.0,
53
- dense_l2_reg: float = 0.0,
54
- early_stop_patience: int = 20,
55
- model_id: str = "youtube_dnn",
56
- ):
57
-
30
+
31
+ def __init__(self,
32
+ user_dense_features: list[DenseFeature] | None = None,
33
+ user_sparse_features: list[SparseFeature] | None = None,
34
+ user_sequence_features: list[SequenceFeature] | None = None,
35
+ item_dense_features: list[DenseFeature] | None = None,
36
+ item_sparse_features: list[SparseFeature] | None = None,
37
+ item_sequence_features: list[SequenceFeature] | None = None,
38
+ user_dnn_hidden_units: list[int] = [256, 128, 64],
39
+ item_dnn_hidden_units: list[int] = [256, 128, 64],
40
+ embedding_dim: int = 64,
41
+ dnn_activation: str = 'relu',
42
+ dnn_dropout: float = 0.0,
43
+ training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'listwise',
44
+ num_negative_samples: int = 100,
45
+ temperature: float = 1.0,
46
+ similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
47
+ device: str = 'cpu',
48
+ embedding_l1_reg: float = 0.0,
49
+ dense_l1_reg: float = 0.0,
50
+ embedding_l2_reg: float = 0.0,
51
+ dense_l2_reg: float = 0.0,
52
+ early_stop_patience: int = 20,
53
+ model_id: str = 'youtube_dnn'):
54
+
58
55
  super(YoutubeDNN, self).__init__(
59
56
  user_dense_features=user_dense_features,
60
57
  user_sparse_features=user_sparse_features,
@@ -72,13 +69,13 @@ class YoutubeDNN(BaseMatchModel):
72
69
  embedding_l2_reg=embedding_l2_reg,
73
70
  dense_l2_reg=dense_l2_reg,
74
71
  early_stop_patience=early_stop_patience,
75
- model_id=model_id,
72
+ model_id=model_id
76
73
  )
77
-
74
+
78
75
  self.embedding_dim = embedding_dim
79
76
  self.user_dnn_hidden_units = user_dnn_hidden_units
80
77
  self.item_dnn_hidden_units = item_dnn_hidden_units
81
-
78
+
82
79
  # User tower
83
80
  user_features = []
84
81
  if user_dense_features:
@@ -87,10 +84,10 @@ class YoutubeDNN(BaseMatchModel):
87
84
  user_features.extend(user_sparse_features)
88
85
  if user_sequence_features:
89
86
  user_features.extend(user_sequence_features)
90
-
87
+
91
88
  if len(user_features) > 0:
92
89
  self.user_embedding = EmbeddingLayer(user_features)
93
-
90
+
94
91
  user_input_dim = 0
95
92
  for feat in user_dense_features or []:
96
93
  user_input_dim += 1
@@ -99,16 +96,16 @@ class YoutubeDNN(BaseMatchModel):
99
96
  for feat in user_sequence_features or []:
100
97
  # 序列特征通过平均池化聚合
101
98
  user_input_dim += feat.embedding_dim
102
-
99
+
103
100
  user_dnn_units = user_dnn_hidden_units + [embedding_dim]
104
101
  self.user_dnn = MLP(
105
102
  input_dim=user_input_dim,
106
103
  dims=user_dnn_units,
107
104
  output_layer=False,
108
105
  dropout=dnn_dropout,
109
- activation=dnn_activation,
106
+ activation=dnn_activation
110
107
  )
111
-
108
+
112
109
  # Item tower
113
110
  item_features = []
114
111
  if item_dense_features:
@@ -117,10 +114,10 @@ class YoutubeDNN(BaseMatchModel):
117
114
  item_features.extend(item_sparse_features)
118
115
  if item_sequence_features:
119
116
  item_features.extend(item_sequence_features)
120
-
117
+
121
118
  if len(item_features) > 0:
122
119
  self.item_embedding = EmbeddingLayer(item_features)
123
-
120
+
124
121
  item_input_dim = 0
125
122
  for feat in item_dense_features or []:
126
123
  item_input_dim += 1
@@ -128,54 +125,48 @@ class YoutubeDNN(BaseMatchModel):
128
125
  item_input_dim += feat.embedding_dim
129
126
  for feat in item_sequence_features or []:
130
127
  item_input_dim += feat.embedding_dim
131
-
128
+
132
129
  item_dnn_units = item_dnn_hidden_units + [embedding_dim]
133
130
  self.item_dnn = MLP(
134
131
  input_dim=item_input_dim,
135
132
  dims=item_dnn_units,
136
133
  output_layer=False,
137
134
  dropout=dnn_dropout,
138
- activation=dnn_activation,
135
+ activation=dnn_activation
139
136
  )
140
-
137
+
141
138
  self._register_regularization_weights(
142
- embedding_attr="user_embedding", include_modules=["user_dnn"]
139
+ embedding_attr='user_embedding',
140
+ include_modules=['user_dnn']
143
141
  )
144
142
  self._register_regularization_weights(
145
- embedding_attr="item_embedding", include_modules=["item_dnn"]
143
+ embedding_attr='item_embedding',
144
+ include_modules=['item_dnn']
146
145
  )
147
-
146
+
148
147
  self.to(device)
149
-
148
+
150
149
  def user_tower(self, user_input: dict) -> torch.Tensor:
151
150
  """
152
151
  User tower
153
152
  处理用户历史行为序列和其他用户特征
154
153
  """
155
- all_user_features = (
156
- self.user_dense_features
157
- + self.user_sparse_features
158
- + self.user_sequence_features
159
- )
154
+ all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
160
155
  user_emb = self.user_embedding(user_input, all_user_features, squeeze_dim=True)
161
156
  user_emb = self.user_dnn(user_emb)
162
-
157
+
163
158
  # L2 normalization
164
159
  user_emb = torch.nn.functional.normalize(user_emb, p=2, dim=1)
165
-
160
+
166
161
  return user_emb
167
-
162
+
168
163
  def item_tower(self, item_input: dict) -> torch.Tensor:
169
164
  """Item tower"""
170
- all_item_features = (
171
- self.item_dense_features
172
- + self.item_sparse_features
173
- + self.item_sequence_features
174
- )
165
+ all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
175
166
  item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
176
167
  item_emb = self.item_dnn(item_emb)
177
-
168
+
178
169
  # L2 normalization
179
170
  item_emb = torch.nn.functional.normalize(item_emb, p=2, dim=1)
180
-
171
+
181
172
  return item_emb
@@ -17,15 +17,15 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
17
17
  class ESMM(BaseModel):
18
18
  """
19
19
  Entire Space Multi-Task Model
20
-
20
+
21
21
  ESMM is designed for CVR (Conversion Rate) prediction. It models two related tasks:
22
22
  - CTR task: P(click | impression)
23
23
  - CVR task: P(conversion | click)
24
24
  - CTCVR task (auxiliary): P(click & conversion | impression) = P(click) * P(conversion | click)
25
-
25
+
26
26
  This design addresses the sample selection bias and data sparsity issues in CVR modeling.
27
27
  """
28
-
28
+
29
29
  @property
30
30
  def model_name(self):
31
31
  return "ESMM"
@@ -33,34 +33,30 @@ class ESMM(BaseModel):
33
33
  @property
34
34
  def task_type(self):
35
35
  # ESMM has fixed task types: CTR (binary) and CVR (binary)
36
- return ["binary", "binary"]
37
-
38
- def __init__(
39
- self,
40
- dense_features: list[DenseFeature],
41
- sparse_features: list[SparseFeature],
42
- sequence_features: list[SequenceFeature],
43
- ctr_params: dict,
44
- cvr_params: dict,
45
- target: list[str] = ["ctr", "ctcvr"], # Note: ctcvr = ctr * cvr
46
- task: str | list[str] = "binary",
47
- optimizer: str = "adam",
48
- optimizer_params: dict = {},
49
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
50
- device: str = "cpu",
51
- model_id: str = "baseline",
52
- embedding_l1_reg=1e-6,
53
- dense_l1_reg=1e-5,
54
- embedding_l2_reg=1e-5,
55
- dense_l2_reg=1e-4,
56
- ):
57
-
36
+ return ['binary', 'binary']
37
+
38
+ def __init__(self,
39
+ dense_features: list[DenseFeature],
40
+ sparse_features: list[SparseFeature],
41
+ sequence_features: list[SequenceFeature],
42
+ ctr_params: dict,
43
+ cvr_params: dict,
44
+ target: list[str] = ['ctr', 'ctcvr'], # Note: ctcvr = ctr * cvr
45
+ task: str | list[str] = 'binary',
46
+ optimizer: str = "adam",
47
+ optimizer_params: dict = {},
48
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
49
+ device: str = 'cpu',
50
+ model_id: str = "baseline",
51
+ embedding_l1_reg=1e-6,
52
+ dense_l1_reg=1e-5,
53
+ embedding_l2_reg=1e-5,
54
+ dense_l2_reg=1e-4):
55
+
58
56
  # ESMM requires exactly 2 targets: ctr and ctcvr
59
57
  if len(target) != 2:
60
- raise ValueError(
61
- f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}"
62
- )
63
-
58
+ raise ValueError(f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}")
59
+
64
60
  super(ESMM, self).__init__(
65
61
  dense_features=dense_features,
66
62
  sparse_features=sparse_features,
@@ -73,13 +69,13 @@ class ESMM(BaseModel):
73
69
  embedding_l2_reg=embedding_l2_reg,
74
70
  dense_l2_reg=dense_l2_reg,
75
71
  early_stop_patience=20,
76
- model_id=model_id,
72
+ model_id=model_id
77
73
  )
78
74
 
79
75
  self.loss = loss
80
76
  if self.loss is None:
81
77
  self.loss = "bce"
82
-
78
+
83
79
  # All features
84
80
  self.all_features = dense_features + sparse_features + sequence_features
85
81
 
@@ -87,48 +83,46 @@ class ESMM(BaseModel):
87
83
  self.embedding = EmbeddingLayer(features=self.all_features)
88
84
 
89
85
  # Calculate input dimension
90
- emb_dim_total = sum(
91
- [
92
- f.embedding_dim
93
- for f in self.all_features
94
- if not isinstance(f, DenseFeature)
95
- ]
96
- )
97
- dense_input_dim = sum(
98
- [getattr(f, "embedding_dim", 1) or 1 for f in dense_features]
99
- )
86
+ emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
87
+ dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
100
88
  input_dim = emb_dim_total + dense_input_dim
101
-
89
+
102
90
  # CTR tower
103
91
  self.ctr_tower = MLP(input_dim=input_dim, output_layer=True, **ctr_params)
104
-
92
+
105
93
  # CVR tower
106
94
  self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
107
95
  self.prediction_layer = PredictionLayer(
108
- task_type=self.task_type, task_dims=[1, 1]
96
+ task_type=self.task_type,
97
+ task_dims=[1, 1]
109
98
  )
110
99
 
111
100
  # Register regularization weights
112
101
  self._register_regularization_weights(
113
- embedding_attr="embedding", include_modules=["ctr_tower", "cvr_tower"]
102
+ embedding_attr='embedding',
103
+ include_modules=['ctr_tower', 'cvr_tower']
114
104
  )
115
105
 
116
- self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss)
106
+ self.compile(
107
+ optimizer=optimizer,
108
+ optimizer_params=optimizer_params,
109
+ loss=loss
110
+ )
117
111
 
118
112
  def forward(self, x):
119
113
  # Get all embeddings and flatten
120
114
  input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
121
-
115
+
122
116
  # CTR prediction: P(click | impression)
123
117
  ctr_logit = self.ctr_tower(input_flat) # [B, 1]
124
118
  cvr_logit = self.cvr_tower(input_flat) # [B, 1]
125
119
  logits = torch.cat([ctr_logit, cvr_logit], dim=1)
126
120
  preds = self.prediction_layer(logits)
127
121
  ctr, cvr = preds.chunk(2, dim=1)
128
-
122
+
129
123
  # CTCVR prediction: P(click & conversion | impression) = P(click) * P(conversion | click)
130
124
  ctcvr = ctr * cvr # [B, 1]
131
-
125
+
132
126
  # Output: [CTR, CTCVR]
133
127
  # Note: We supervise CTR with click labels and CTCVR with conversion labels
134
128
  y = torch.cat([ctr, ctcvr], dim=1) # [B, 2]
@@ -17,13 +17,13 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
17
17
  class MMOE(BaseModel):
18
18
  """
19
19
  Multi-gate Mixture-of-Experts
20
-
20
+
21
21
  MMOE improves upon shared-bottom architecture by using multiple expert networks
22
22
  and task-specific gating networks. Each task has its own gate that learns to
23
23
  weight the contributions of different experts, allowing for both task-specific
24
24
  and shared representations.
25
25
  """
26
-
26
+
27
27
  @property
28
28
  def model_name(self):
29
29
  return "MMOE"
@@ -31,28 +31,26 @@ class MMOE(BaseModel):
31
31
  @property
32
32
  def task_type(self):
33
33
  return self.task if isinstance(self.task, list) else [self.task]
34
-
35
- def __init__(
36
- self,
37
- dense_features: list[DenseFeature] = [],
38
- sparse_features: list[SparseFeature] = [],
39
- sequence_features: list[SequenceFeature] = [],
40
- expert_params: dict = {},
41
- num_experts: int = 3,
42
- tower_params_list: list[dict] = [],
43
- target: list[str] = [],
44
- task: str | list[str] = "binary",
45
- optimizer: str = "adam",
46
- optimizer_params: dict = {},
47
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
48
- device: str = "cpu",
49
- model_id: str = "baseline",
50
- embedding_l1_reg=1e-6,
51
- dense_l1_reg=1e-5,
52
- embedding_l2_reg=1e-5,
53
- dense_l2_reg=1e-4,
54
- ):
55
-
34
+
35
+ def __init__(self,
36
+ dense_features: list[DenseFeature]=[],
37
+ sparse_features: list[SparseFeature]=[],
38
+ sequence_features: list[SequenceFeature]=[],
39
+ expert_params: dict={},
40
+ num_experts: int=3,
41
+ tower_params_list: list[dict]=[],
42
+ target: list[str]=[],
43
+ task: str | list[str] = 'binary',
44
+ optimizer: str = "adam",
45
+ optimizer_params: dict = {},
46
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
47
+ device: str = 'cpu',
48
+ model_id: str = "baseline",
49
+ embedding_l1_reg=1e-6,
50
+ dense_l1_reg=1e-5,
51
+ embedding_l2_reg=1e-5,
52
+ dense_l2_reg=1e-4):
53
+
56
54
  super(MMOE, self).__init__(
57
55
  dense_features=dense_features,
58
56
  sparse_features=sparse_features,
@@ -65,22 +63,20 @@ class MMOE(BaseModel):
65
63
  embedding_l2_reg=embedding_l2_reg,
66
64
  dense_l2_reg=dense_l2_reg,
67
65
  early_stop_patience=20,
68
- model_id=model_id,
66
+ model_id=model_id
69
67
  )
70
68
 
71
69
  self.loss = loss
72
70
  if self.loss is None:
73
71
  self.loss = "bce"
74
-
72
+
75
73
  # Number of tasks and experts
76
74
  self.num_tasks = len(target)
77
75
  self.num_experts = num_experts
78
-
76
+
79
77
  if len(tower_params_list) != self.num_tasks:
80
- raise ValueError(
81
- f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
82
- )
83
-
78
+ raise ValueError(f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")
79
+
84
80
  # All features
85
81
  self.all_features = dense_features + sparse_features + sequence_features
86
82
 
@@ -88,83 +84,78 @@ class MMOE(BaseModel):
88
84
  self.embedding = EmbeddingLayer(features=self.all_features)
89
85
 
90
86
  # Calculate input dimension
91
- emb_dim_total = sum(
92
- [
93
- f.embedding_dim
94
- for f in self.all_features
95
- if not isinstance(f, DenseFeature)
96
- ]
97
- )
98
- dense_input_dim = sum(
99
- [getattr(f, "embedding_dim", 1) or 1 for f in dense_features]
100
- )
87
+ emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
88
+ dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
101
89
  input_dim = emb_dim_total + dense_input_dim
102
-
90
+
103
91
  # Expert networks (shared by all tasks)
104
92
  self.experts = nn.ModuleList()
105
93
  for _ in range(num_experts):
106
94
  expert = MLP(input_dim=input_dim, output_layer=False, **expert_params)
107
95
  self.experts.append(expert)
108
-
96
+
109
97
  # Get expert output dimension
110
- if "dims" in expert_params and len(expert_params["dims"]) > 0:
111
- expert_output_dim = expert_params["dims"][-1]
98
+ if 'dims' in expert_params and len(expert_params['dims']) > 0:
99
+ expert_output_dim = expert_params['dims'][-1]
112
100
  else:
113
101
  expert_output_dim = input_dim
114
-
102
+
115
103
  # Task-specific gates
116
104
  self.gates = nn.ModuleList()
117
105
  for _ in range(self.num_tasks):
118
- gate = nn.Sequential(nn.Linear(input_dim, num_experts), nn.Softmax(dim=1))
106
+ gate = nn.Sequential(
107
+ nn.Linear(input_dim, num_experts),
108
+ nn.Softmax(dim=1)
109
+ )
119
110
  self.gates.append(gate)
120
-
111
+
121
112
  # Task-specific towers
122
113
  self.towers = nn.ModuleList()
123
114
  for tower_params in tower_params_list:
124
115
  tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
125
116
  self.towers.append(tower)
126
117
  self.prediction_layer = PredictionLayer(
127
- task_type=self.task_type, task_dims=[1] * self.num_tasks
118
+ task_type=self.task_type,
119
+ task_dims=[1] * self.num_tasks
128
120
  )
129
121
 
130
122
  # Register regularization weights
131
123
  self._register_regularization_weights(
132
- embedding_attr="embedding", include_modules=["experts", "gates", "towers"]
124
+ embedding_attr='embedding',
125
+ include_modules=['experts', 'gates', 'towers']
133
126
  )
134
127
 
135
- self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss)
128
+ self.compile(
129
+ optimizer=optimizer,
130
+ optimizer_params=optimizer_params,
131
+ loss=loss
132
+ )
136
133
 
137
134
  def forward(self, x):
138
135
  # Get all embeddings and flatten
139
136
  input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
140
-
137
+
141
138
  # Expert outputs: [num_experts, B, expert_dim]
142
139
  expert_outputs = [expert(input_flat) for expert in self.experts]
143
- expert_outputs = torch.stack(
144
- expert_outputs, dim=0
145
- ) # [num_experts, B, expert_dim]
146
-
140
+ expert_outputs = torch.stack(expert_outputs, dim=0) # [num_experts, B, expert_dim]
141
+
147
142
  # Task-specific processing
148
143
  task_outputs = []
149
144
  for task_idx in range(self.num_tasks):
150
145
  # Gate weights for this task: [B, num_experts]
151
146
  gate_weights = self.gates[task_idx](input_flat) # [B, num_experts]
152
-
147
+
153
148
  # Weighted sum of expert outputs
154
149
  # gate_weights: [B, num_experts, 1]
155
150
  # expert_outputs: [num_experts, B, expert_dim]
156
151
  gate_weights = gate_weights.unsqueeze(2) # [B, num_experts, 1]
157
- expert_outputs_t = expert_outputs.permute(
158
- 1, 0, 2
159
- ) # [B, num_experts, expert_dim]
160
- gated_output = torch.sum(
161
- gate_weights * expert_outputs_t, dim=1
162
- ) # [B, expert_dim]
163
-
152
+ expert_outputs_t = expert_outputs.permute(1, 0, 2) # [B, num_experts, expert_dim]
153
+ gated_output = torch.sum(gate_weights * expert_outputs_t, dim=1) # [B, expert_dim]
154
+
164
155
  # Tower output
165
156
  tower_output = self.towers[task_idx](gated_output) # [B, 1]
166
157
  task_outputs.append(tower_output)
167
-
158
+
168
159
  # Stack outputs: [B, num_tasks]
169
160
  y = torch.cat(task_outputs, dim=1)
170
161
  return self.prediction_layer(y)