nextrec 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. nextrec/__init__.py +4 -4
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -9
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/dataloader.py +168 -127
  6. nextrec/basic/features.py +24 -27
  7. nextrec/basic/layers.py +328 -159
  8. nextrec/basic/loggers.py +50 -37
  9. nextrec/basic/metrics.py +255 -147
  10. nextrec/basic/model.py +817 -462
  11. nextrec/data/__init__.py +5 -5
  12. nextrec/data/data_utils.py +16 -12
  13. nextrec/data/preprocessor.py +276 -252
  14. nextrec/loss/__init__.py +12 -12
  15. nextrec/loss/loss_utils.py +30 -22
  16. nextrec/loss/match_losses.py +116 -83
  17. nextrec/models/match/__init__.py +5 -5
  18. nextrec/models/match/dssm.py +70 -61
  19. nextrec/models/match/dssm_v2.py +61 -51
  20. nextrec/models/match/mind.py +89 -71
  21. nextrec/models/match/sdm.py +93 -81
  22. nextrec/models/match/youtube_dnn.py +62 -53
  23. nextrec/models/multi_task/esmm.py +49 -43
  24. nextrec/models/multi_task/mmoe.py +65 -56
  25. nextrec/models/multi_task/ple.py +92 -65
  26. nextrec/models/multi_task/share_bottom.py +48 -42
  27. nextrec/models/ranking/__init__.py +7 -7
  28. nextrec/models/ranking/afm.py +39 -30
  29. nextrec/models/ranking/autoint.py +70 -57
  30. nextrec/models/ranking/dcn.py +43 -35
  31. nextrec/models/ranking/deepfm.py +34 -28
  32. nextrec/models/ranking/dien.py +115 -79
  33. nextrec/models/ranking/din.py +84 -60
  34. nextrec/models/ranking/fibinet.py +51 -35
  35. nextrec/models/ranking/fm.py +28 -26
  36. nextrec/models/ranking/masknet.py +31 -31
  37. nextrec/models/ranking/pnn.py +30 -31
  38. nextrec/models/ranking/widedeep.py +36 -31
  39. nextrec/models/ranking/xdeepfm.py +46 -39
  40. nextrec/utils/__init__.py +9 -9
  41. nextrec/utils/embedding.py +1 -1
  42. nextrec/utils/initializer.py +23 -15
  43. nextrec/utils/optimizer.py +14 -10
  44. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/METADATA +6 -40
  45. nextrec-0.1.2.dist-info/RECORD +51 -0
  46. nextrec-0.1.1.dist-info/RECORD +0 -51
  47. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/WHEEL +0 -0
  48. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/licenses/LICENSE +0 -0
@@ -6,6 +6,7 @@ Reference:
6
6
  [1] Covington P, Adams J, Sargin E. Deep neural networks for youtube recommendations[C]
7
7
  //Proceedings of the 10th ACM conference on recommender systems. 2016: 191-198.
8
8
  """
9
+
9
10
  import torch
10
11
  import torch.nn as nn
11
12
  from typing import Literal
@@ -18,40 +19,42 @@ from nextrec.basic.layers import MLP, EmbeddingLayer, AveragePooling
18
19
  class YoutubeDNN(BaseMatchModel):
19
20
  """
20
21
  YouTube Deep Neural Network for Recommendations
21
-
22
+
22
23
  用户塔:历史行为序列 + 用户特征 -> 用户embedding
23
24
  物品塔:物品特征 -> 物品embedding
24
25
  训练:sampled softmax loss (listwise)
25
26
  """
26
-
27
+
27
28
  @property
28
29
  def model_name(self) -> str:
29
30
  return "YouTubeDNN"
30
-
31
- def __init__(self,
32
- user_dense_features: list[DenseFeature] | None = None,
33
- user_sparse_features: list[SparseFeature] | None = None,
34
- user_sequence_features: list[SequenceFeature] | None = None,
35
- item_dense_features: list[DenseFeature] | None = None,
36
- item_sparse_features: list[SparseFeature] | None = None,
37
- item_sequence_features: list[SequenceFeature] | None = None,
38
- user_dnn_hidden_units: list[int] = [256, 128, 64],
39
- item_dnn_hidden_units: list[int] = [256, 128, 64],
40
- embedding_dim: int = 64,
41
- dnn_activation: str = 'relu',
42
- dnn_dropout: float = 0.0,
43
- training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'listwise',
44
- num_negative_samples: int = 100,
45
- temperature: float = 1.0,
46
- similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
47
- device: str = 'cpu',
48
- embedding_l1_reg: float = 0.0,
49
- dense_l1_reg: float = 0.0,
50
- embedding_l2_reg: float = 0.0,
51
- dense_l2_reg: float = 0.0,
52
- early_stop_patience: int = 20,
53
- model_id: str = 'youtube_dnn'):
54
-
31
+
32
+ def __init__(
33
+ self,
34
+ user_dense_features: list[DenseFeature] | None = None,
35
+ user_sparse_features: list[SparseFeature] | None = None,
36
+ user_sequence_features: list[SequenceFeature] | None = None,
37
+ item_dense_features: list[DenseFeature] | None = None,
38
+ item_sparse_features: list[SparseFeature] | None = None,
39
+ item_sequence_features: list[SequenceFeature] | None = None,
40
+ user_dnn_hidden_units: list[int] = [256, 128, 64],
41
+ item_dnn_hidden_units: list[int] = [256, 128, 64],
42
+ embedding_dim: int = 64,
43
+ dnn_activation: str = "relu",
44
+ dnn_dropout: float = 0.0,
45
+ training_mode: Literal["pointwise", "pairwise", "listwise"] = "listwise",
46
+ num_negative_samples: int = 100,
47
+ temperature: float = 1.0,
48
+ similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
49
+ device: str = "cpu",
50
+ embedding_l1_reg: float = 0.0,
51
+ dense_l1_reg: float = 0.0,
52
+ embedding_l2_reg: float = 0.0,
53
+ dense_l2_reg: float = 0.0,
54
+ early_stop_patience: int = 20,
55
+ model_id: str = "youtube_dnn",
56
+ ):
57
+
55
58
  super(YoutubeDNN, self).__init__(
56
59
  user_dense_features=user_dense_features,
57
60
  user_sparse_features=user_sparse_features,
@@ -69,13 +72,13 @@ class YoutubeDNN(BaseMatchModel):
69
72
  embedding_l2_reg=embedding_l2_reg,
70
73
  dense_l2_reg=dense_l2_reg,
71
74
  early_stop_patience=early_stop_patience,
72
- model_id=model_id
75
+ model_id=model_id,
73
76
  )
74
-
77
+
75
78
  self.embedding_dim = embedding_dim
76
79
  self.user_dnn_hidden_units = user_dnn_hidden_units
77
80
  self.item_dnn_hidden_units = item_dnn_hidden_units
78
-
81
+
79
82
  # User tower
80
83
  user_features = []
81
84
  if user_dense_features:
@@ -84,10 +87,10 @@ class YoutubeDNN(BaseMatchModel):
84
87
  user_features.extend(user_sparse_features)
85
88
  if user_sequence_features:
86
89
  user_features.extend(user_sequence_features)
87
-
90
+
88
91
  if len(user_features) > 0:
89
92
  self.user_embedding = EmbeddingLayer(user_features)
90
-
93
+
91
94
  user_input_dim = 0
92
95
  for feat in user_dense_features or []:
93
96
  user_input_dim += 1
@@ -96,16 +99,16 @@ class YoutubeDNN(BaseMatchModel):
96
99
  for feat in user_sequence_features or []:
97
100
  # 序列特征通过平均池化聚合
98
101
  user_input_dim += feat.embedding_dim
99
-
102
+
100
103
  user_dnn_units = user_dnn_hidden_units + [embedding_dim]
101
104
  self.user_dnn = MLP(
102
105
  input_dim=user_input_dim,
103
106
  dims=user_dnn_units,
104
107
  output_layer=False,
105
108
  dropout=dnn_dropout,
106
- activation=dnn_activation
109
+ activation=dnn_activation,
107
110
  )
108
-
111
+
109
112
  # Item tower
110
113
  item_features = []
111
114
  if item_dense_features:
@@ -114,10 +117,10 @@ class YoutubeDNN(BaseMatchModel):
114
117
  item_features.extend(item_sparse_features)
115
118
  if item_sequence_features:
116
119
  item_features.extend(item_sequence_features)
117
-
120
+
118
121
  if len(item_features) > 0:
119
122
  self.item_embedding = EmbeddingLayer(item_features)
120
-
123
+
121
124
  item_input_dim = 0
122
125
  for feat in item_dense_features or []:
123
126
  item_input_dim += 1
@@ -125,48 +128,54 @@ class YoutubeDNN(BaseMatchModel):
125
128
  item_input_dim += feat.embedding_dim
126
129
  for feat in item_sequence_features or []:
127
130
  item_input_dim += feat.embedding_dim
128
-
131
+
129
132
  item_dnn_units = item_dnn_hidden_units + [embedding_dim]
130
133
  self.item_dnn = MLP(
131
134
  input_dim=item_input_dim,
132
135
  dims=item_dnn_units,
133
136
  output_layer=False,
134
137
  dropout=dnn_dropout,
135
- activation=dnn_activation
138
+ activation=dnn_activation,
136
139
  )
137
-
140
+
138
141
  self._register_regularization_weights(
139
- embedding_attr='user_embedding',
140
- include_modules=['user_dnn']
142
+ embedding_attr="user_embedding", include_modules=["user_dnn"]
141
143
  )
142
144
  self._register_regularization_weights(
143
- embedding_attr='item_embedding',
144
- include_modules=['item_dnn']
145
+ embedding_attr="item_embedding", include_modules=["item_dnn"]
145
146
  )
146
-
147
+
147
148
  self.to(device)
148
-
149
+
149
150
  def user_tower(self, user_input: dict) -> torch.Tensor:
150
151
  """
151
152
  User tower
152
153
  处理用户历史行为序列和其他用户特征
153
154
  """
154
- all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
155
+ all_user_features = (
156
+ self.user_dense_features
157
+ + self.user_sparse_features
158
+ + self.user_sequence_features
159
+ )
155
160
  user_emb = self.user_embedding(user_input, all_user_features, squeeze_dim=True)
156
161
  user_emb = self.user_dnn(user_emb)
157
-
162
+
158
163
  # L2 normalization
159
164
  user_emb = torch.nn.functional.normalize(user_emb, p=2, dim=1)
160
-
165
+
161
166
  return user_emb
162
-
167
+
163
168
  def item_tower(self, item_input: dict) -> torch.Tensor:
164
169
  """Item tower"""
165
- all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
170
+ all_item_features = (
171
+ self.item_dense_features
172
+ + self.item_sparse_features
173
+ + self.item_sequence_features
174
+ )
166
175
  item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
167
176
  item_emb = self.item_dnn(item_emb)
168
-
177
+
169
178
  # L2 normalization
170
179
  item_emb = torch.nn.functional.normalize(item_emb, p=2, dim=1)
171
-
180
+
172
181
  return item_emb
@@ -17,15 +17,15 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
17
17
  class ESMM(BaseModel):
18
18
  """
19
19
  Entire Space Multi-Task Model
20
-
20
+
21
21
  ESMM is designed for CVR (Conversion Rate) prediction. It models two related tasks:
22
22
  - CTR task: P(click | impression)
23
23
  - CVR task: P(conversion | click)
24
24
  - CTCVR task (auxiliary): P(click & conversion | impression) = P(click) * P(conversion | click)
25
-
25
+
26
26
  This design addresses the sample selection bias and data sparsity issues in CVR modeling.
27
27
  """
28
-
28
+
29
29
  @property
30
30
  def model_name(self):
31
31
  return "ESMM"
@@ -33,30 +33,34 @@ class ESMM(BaseModel):
33
33
  @property
34
34
  def task_type(self):
35
35
  # ESMM has fixed task types: CTR (binary) and CVR (binary)
36
- return ['binary', 'binary']
37
-
38
- def __init__(self,
39
- dense_features: list[DenseFeature],
40
- sparse_features: list[SparseFeature],
41
- sequence_features: list[SequenceFeature],
42
- ctr_params: dict,
43
- cvr_params: dict,
44
- target: list[str] = ['ctr', 'ctcvr'], # Note: ctcvr = ctr * cvr
45
- task: str | list[str] = 'binary',
46
- optimizer: str = "adam",
47
- optimizer_params: dict = {},
48
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
49
- device: str = 'cpu',
50
- model_id: str = "baseline",
51
- embedding_l1_reg=1e-6,
52
- dense_l1_reg=1e-5,
53
- embedding_l2_reg=1e-5,
54
- dense_l2_reg=1e-4):
55
-
36
+ return ["binary", "binary"]
37
+
38
+ def __init__(
39
+ self,
40
+ dense_features: list[DenseFeature],
41
+ sparse_features: list[SparseFeature],
42
+ sequence_features: list[SequenceFeature],
43
+ ctr_params: dict,
44
+ cvr_params: dict,
45
+ target: list[str] = ["ctr", "ctcvr"], # Note: ctcvr = ctr * cvr
46
+ task: str | list[str] = "binary",
47
+ optimizer: str = "adam",
48
+ optimizer_params: dict = {},
49
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
50
+ device: str = "cpu",
51
+ model_id: str = "baseline",
52
+ embedding_l1_reg=1e-6,
53
+ dense_l1_reg=1e-5,
54
+ embedding_l2_reg=1e-5,
55
+ dense_l2_reg=1e-4,
56
+ ):
57
+
56
58
  # ESMM requires exactly 2 targets: ctr and ctcvr
57
59
  if len(target) != 2:
58
- raise ValueError(f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}")
59
-
60
+ raise ValueError(
61
+ f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}"
62
+ )
63
+
60
64
  super(ESMM, self).__init__(
61
65
  dense_features=dense_features,
62
66
  sparse_features=sparse_features,
@@ -69,13 +73,13 @@ class ESMM(BaseModel):
69
73
  embedding_l2_reg=embedding_l2_reg,
70
74
  dense_l2_reg=dense_l2_reg,
71
75
  early_stop_patience=20,
72
- model_id=model_id
76
+ model_id=model_id,
73
77
  )
74
78
 
75
79
  self.loss = loss
76
80
  if self.loss is None:
77
81
  self.loss = "bce"
78
-
82
+
79
83
  # All features
80
84
  self.all_features = dense_features + sparse_features + sequence_features
81
85
 
@@ -83,46 +87,48 @@ class ESMM(BaseModel):
83
87
  self.embedding = EmbeddingLayer(features=self.all_features)
84
88
 
85
89
  # Calculate input dimension
86
- emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
87
- dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
90
+ emb_dim_total = sum(
91
+ [
92
+ f.embedding_dim
93
+ for f in self.all_features
94
+ if not isinstance(f, DenseFeature)
95
+ ]
96
+ )
97
+ dense_input_dim = sum(
98
+ [getattr(f, "embedding_dim", 1) or 1 for f in dense_features]
99
+ )
88
100
  input_dim = emb_dim_total + dense_input_dim
89
-
101
+
90
102
  # CTR tower
91
103
  self.ctr_tower = MLP(input_dim=input_dim, output_layer=True, **ctr_params)
92
-
104
+
93
105
  # CVR tower
94
106
  self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
95
107
  self.prediction_layer = PredictionLayer(
96
- task_type=self.task_type,
97
- task_dims=[1, 1]
108
+ task_type=self.task_type, task_dims=[1, 1]
98
109
  )
99
110
 
100
111
  # Register regularization weights
101
112
  self._register_regularization_weights(
102
- embedding_attr='embedding',
103
- include_modules=['ctr_tower', 'cvr_tower']
113
+ embedding_attr="embedding", include_modules=["ctr_tower", "cvr_tower"]
104
114
  )
105
115
 
106
- self.compile(
107
- optimizer=optimizer,
108
- optimizer_params=optimizer_params,
109
- loss=loss
110
- )
116
+ self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss)
111
117
 
112
118
  def forward(self, x):
113
119
  # Get all embeddings and flatten
114
120
  input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
115
-
121
+
116
122
  # CTR prediction: P(click | impression)
117
123
  ctr_logit = self.ctr_tower(input_flat) # [B, 1]
118
124
  cvr_logit = self.cvr_tower(input_flat) # [B, 1]
119
125
  logits = torch.cat([ctr_logit, cvr_logit], dim=1)
120
126
  preds = self.prediction_layer(logits)
121
127
  ctr, cvr = preds.chunk(2, dim=1)
122
-
128
+
123
129
  # CTCVR prediction: P(click & conversion | impression) = P(click) * P(conversion | click)
124
130
  ctcvr = ctr * cvr # [B, 1]
125
-
131
+
126
132
  # Output: [CTR, CTCVR]
127
133
  # Note: We supervise CTR with click labels and CTCVR with conversion labels
128
134
  y = torch.cat([ctr, ctcvr], dim=1) # [B, 2]
@@ -17,13 +17,13 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
17
17
  class MMOE(BaseModel):
18
18
  """
19
19
  Multi-gate Mixture-of-Experts
20
-
20
+
21
21
  MMOE improves upon shared-bottom architecture by using multiple expert networks
22
22
  and task-specific gating networks. Each task has its own gate that learns to
23
23
  weight the contributions of different experts, allowing for both task-specific
24
24
  and shared representations.
25
25
  """
26
-
26
+
27
27
  @property
28
28
  def model_name(self):
29
29
  return "MMOE"
@@ -31,26 +31,28 @@ class MMOE(BaseModel):
31
31
  @property
32
32
  def task_type(self):
33
33
  return self.task if isinstance(self.task, list) else [self.task]
34
-
35
- def __init__(self,
36
- dense_features: list[DenseFeature]=[],
37
- sparse_features: list[SparseFeature]=[],
38
- sequence_features: list[SequenceFeature]=[],
39
- expert_params: dict={},
40
- num_experts: int=3,
41
- tower_params_list: list[dict]=[],
42
- target: list[str]=[],
43
- task: str | list[str] = 'binary',
44
- optimizer: str = "adam",
45
- optimizer_params: dict = {},
46
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
47
- device: str = 'cpu',
48
- model_id: str = "baseline",
49
- embedding_l1_reg=1e-6,
50
- dense_l1_reg=1e-5,
51
- embedding_l2_reg=1e-5,
52
- dense_l2_reg=1e-4):
53
-
34
+
35
+ def __init__(
36
+ self,
37
+ dense_features: list[DenseFeature] = [],
38
+ sparse_features: list[SparseFeature] = [],
39
+ sequence_features: list[SequenceFeature] = [],
40
+ expert_params: dict = {},
41
+ num_experts: int = 3,
42
+ tower_params_list: list[dict] = [],
43
+ target: list[str] = [],
44
+ task: str | list[str] = "binary",
45
+ optimizer: str = "adam",
46
+ optimizer_params: dict = {},
47
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
48
+ device: str = "cpu",
49
+ model_id: str = "baseline",
50
+ embedding_l1_reg=1e-6,
51
+ dense_l1_reg=1e-5,
52
+ embedding_l2_reg=1e-5,
53
+ dense_l2_reg=1e-4,
54
+ ):
55
+
54
56
  super(MMOE, self).__init__(
55
57
  dense_features=dense_features,
56
58
  sparse_features=sparse_features,
@@ -63,20 +65,22 @@ class MMOE(BaseModel):
63
65
  embedding_l2_reg=embedding_l2_reg,
64
66
  dense_l2_reg=dense_l2_reg,
65
67
  early_stop_patience=20,
66
- model_id=model_id
68
+ model_id=model_id,
67
69
  )
68
70
 
69
71
  self.loss = loss
70
72
  if self.loss is None:
71
73
  self.loss = "bce"
72
-
74
+
73
75
  # Number of tasks and experts
74
76
  self.num_tasks = len(target)
75
77
  self.num_experts = num_experts
76
-
78
+
77
79
  if len(tower_params_list) != self.num_tasks:
78
- raise ValueError(f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})")
79
-
80
+ raise ValueError(
81
+ f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
82
+ )
83
+
80
84
  # All features
81
85
  self.all_features = dense_features + sparse_features + sequence_features
82
86
 
@@ -84,78 +88,83 @@ class MMOE(BaseModel):
84
88
  self.embedding = EmbeddingLayer(features=self.all_features)
85
89
 
86
90
  # Calculate input dimension
87
- emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
88
- dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
91
+ emb_dim_total = sum(
92
+ [
93
+ f.embedding_dim
94
+ for f in self.all_features
95
+ if not isinstance(f, DenseFeature)
96
+ ]
97
+ )
98
+ dense_input_dim = sum(
99
+ [getattr(f, "embedding_dim", 1) or 1 for f in dense_features]
100
+ )
89
101
  input_dim = emb_dim_total + dense_input_dim
90
-
102
+
91
103
  # Expert networks (shared by all tasks)
92
104
  self.experts = nn.ModuleList()
93
105
  for _ in range(num_experts):
94
106
  expert = MLP(input_dim=input_dim, output_layer=False, **expert_params)
95
107
  self.experts.append(expert)
96
-
108
+
97
109
  # Get expert output dimension
98
- if 'dims' in expert_params and len(expert_params['dims']) > 0:
99
- expert_output_dim = expert_params['dims'][-1]
110
+ if "dims" in expert_params and len(expert_params["dims"]) > 0:
111
+ expert_output_dim = expert_params["dims"][-1]
100
112
  else:
101
113
  expert_output_dim = input_dim
102
-
114
+
103
115
  # Task-specific gates
104
116
  self.gates = nn.ModuleList()
105
117
  for _ in range(self.num_tasks):
106
- gate = nn.Sequential(
107
- nn.Linear(input_dim, num_experts),
108
- nn.Softmax(dim=1)
109
- )
118
+ gate = nn.Sequential(nn.Linear(input_dim, num_experts), nn.Softmax(dim=1))
110
119
  self.gates.append(gate)
111
-
120
+
112
121
  # Task-specific towers
113
122
  self.towers = nn.ModuleList()
114
123
  for tower_params in tower_params_list:
115
124
  tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
116
125
  self.towers.append(tower)
117
126
  self.prediction_layer = PredictionLayer(
118
- task_type=self.task_type,
119
- task_dims=[1] * self.num_tasks
127
+ task_type=self.task_type, task_dims=[1] * self.num_tasks
120
128
  )
121
129
 
122
130
  # Register regularization weights
123
131
  self._register_regularization_weights(
124
- embedding_attr='embedding',
125
- include_modules=['experts', 'gates', 'towers']
132
+ embedding_attr="embedding", include_modules=["experts", "gates", "towers"]
126
133
  )
127
134
 
128
- self.compile(
129
- optimizer=optimizer,
130
- optimizer_params=optimizer_params,
131
- loss=loss
132
- )
135
+ self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss)
133
136
 
134
137
  def forward(self, x):
135
138
  # Get all embeddings and flatten
136
139
  input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
137
-
140
+
138
141
  # Expert outputs: [num_experts, B, expert_dim]
139
142
  expert_outputs = [expert(input_flat) for expert in self.experts]
140
- expert_outputs = torch.stack(expert_outputs, dim=0) # [num_experts, B, expert_dim]
141
-
143
+ expert_outputs = torch.stack(
144
+ expert_outputs, dim=0
145
+ ) # [num_experts, B, expert_dim]
146
+
142
147
  # Task-specific processing
143
148
  task_outputs = []
144
149
  for task_idx in range(self.num_tasks):
145
150
  # Gate weights for this task: [B, num_experts]
146
151
  gate_weights = self.gates[task_idx](input_flat) # [B, num_experts]
147
-
152
+
148
153
  # Weighted sum of expert outputs
149
154
  # gate_weights: [B, num_experts, 1]
150
155
  # expert_outputs: [num_experts, B, expert_dim]
151
156
  gate_weights = gate_weights.unsqueeze(2) # [B, num_experts, 1]
152
- expert_outputs_t = expert_outputs.permute(1, 0, 2) # [B, num_experts, expert_dim]
153
- gated_output = torch.sum(gate_weights * expert_outputs_t, dim=1) # [B, expert_dim]
154
-
157
+ expert_outputs_t = expert_outputs.permute(
158
+ 1, 0, 2
159
+ ) # [B, num_experts, expert_dim]
160
+ gated_output = torch.sum(
161
+ gate_weights * expert_outputs_t, dim=1
162
+ ) # [B, expert_dim]
163
+
155
164
  # Tower output
156
165
  tower_output = self.towers[task_idx](gated_output) # [B, 1]
157
166
  task_outputs.append(tower_output)
158
-
167
+
159
168
  # Stack outputs: [B, num_tasks]
160
169
  y = torch.cat(task_outputs, dim=1)
161
170
  return self.prediction_layer(y)