nextrec 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl

@@ -1,11 +1,50 @@
  """
  Date: create on 09/11/2025
- Author:
- Yang Zhou,zyaztec@gmail.com
+ Checkpoint: edit on 09/12/2025
+ Author: Yang Zhou, zyaztec@gmail.com
  Reference:
- [1] Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]
- //Proceedings of the 24th ACM SIGKDD international conference on knowledge discovery & data mining. 2018: 1059-1068.
- (https://arxiv.org/abs/1706.06978)
+ [1] Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate
+ prediction[C] //Proceedings of the 24th ACM SIGKDD international conference on
+ knowledge discovery & data mining. 2018: 1059-1068.
+ (https://arxiv.org/abs/1706.06978)
+
+ Deep Interest Network (DIN) is a CTR model that builds a target-aware user
+ representation by attending over the historical behavior sequence. Instead of
+ compressing all behaviors into one static vector, DIN highlights the behaviors
+ most relevant to the current candidate item, enabling adaptive interest
+ modeling for each request.
+
+ Pipeline:
+ (1) Embed candidate item, user behavior sequence, and other sparse/dense fields
+ (2) Use a small attention MLP to score each historical behavior against the
+ candidate embedding
+ (3) Apply masked weighted pooling to obtain a target-specific interest vector
+ (4) Concatenate candidate, interest vector, other sparse embeddings, and dense
+ features
+ (5) Feed the combined representation into an MLP for final prediction
+
+ Key Advantages:
+ - Target-aware attention captures fine-grained interests per candidate item
+ - Adaptive pooling handles diverse behavior patterns without heavy feature crafting
+ - Masked weighting reduces noise from padded sequence positions
+ - Easily incorporates additional sparse/dense context features alongside behavior
+
+ DIN 是一个 CTR 预估模型,通过对用户历史行为序列进行目标感知的注意力加权,
+ 构建针对当前候选物品的兴趣表示。它不是将全部行为压缩为固定向量,而是突出
+ 与候选物品最相关的行为,实现请求级的自适应兴趣建模。
+
+ 处理流程:
+ (1) 对候选物品、用户行为序列及其他稀疏/稠密特征做 embedding
+ (2) 使用小型注意力 MLP 计算每个历史行为与候选 embedding 的相关性
+ (3) 通过掩码加权池化得到目标特定的兴趣向量
+ (4) 拼接候选、兴趣向量、其他稀疏 embedding 与稠密特征
+ (5) 输入 MLP 完成最终点击率预测
+
+ 主要优点:
+ - 目标感知注意力捕捉候选级的细粒度兴趣
+ - 自适应池化应对多样化行为模式,减少手工特征工程
+ - 掩码加权降低序列填充位置的噪声
+ - 便捷融合行为与额外稀疏/稠密上下文信息
  """

  import torch
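
Note: to make steps (2)-(3) of the docstring concrete, the sketch below shows one way to implement masked, target-aware pooling over a behavior sequence. It is an editorial illustration that scores behaviors with a plain dot product instead of DIN's attention MLP; the function name and tensor shapes are assumptions, not code from the package.

    import torch

    def masked_attention_pooling(query, keys, mask):
        # query: [B, D] candidate embedding; keys: [B, T, D] behavior embeddings
        # mask: [B, T, 1], 1 for real behaviors, 0 for padding (assumes >= 1 real behavior)
        scores = (keys * query.unsqueeze(1)).sum(dim=-1, keepdim=True)  # [B, T, 1]
        scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = torch.softmax(scores, dim=1)                          # [B, T, 1]
        return (weights * keys).sum(dim=1)                              # [B, D] interest vector

    pooled = masked_attention_pooling(torch.randn(2, 8), torch.randn(2, 5, 8), torch.ones(2, 5, 1))
    print(pooled.shape)  # torch.Size([2, 8])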
@@ -32,17 +71,19 @@ class DIN(BaseModel):

  def __init__(
  self,
- dense_features: list[DenseFeature],
- sparse_features: list[SparseFeature],
- sequence_features: list[SequenceFeature],
- mlp_params: dict,
- attention_hidden_units: list[int] = [80, 40],
- attention_activation: str = "sigmoid",
+ dense_features: list[DenseFeature] | None = None,
+ sparse_features: list[SparseFeature] | None = None,
+ sequence_features: list[SequenceFeature] | None = None,
+ behavior_feature_name: str | None = None,
+ candidate_feature_name: str | None = None,
+ mlp_params: dict | None = None,
+ attention_hidden_units: list[int] | None = None,
+ attention_activation: str = "dice",
  attention_use_softmax: bool = True,
- target: list[str] = [],
+ target: list[str] | str | None = None,
  task: str | list[str] | None = None,
  optimizer: str = "adam",
- optimizer_params: dict = {},
+ optimizer_params: dict | None = None,
  loss: str | nn.Module | None = "bce",
  loss_params: dict | list[dict] | None = None,
  device: str = "cpu",
@@ -53,6 +94,15 @@ class DIN(BaseModel):
  **kwargs,
  ):

+ dense_features = dense_features or []
+ sparse_features = sparse_features or []
+ sequence_features = sequence_features or []
+ mlp_params = mlp_params or {}
+ attention_hidden_units = attention_hidden_units or [80, 40]
+ optimizer_params = optimizer_params or {}
+ if loss is None:
+ loss = "bce"
+
  super(DIN, self).__init__(
  dense_features=dense_features,
  sparse_features=sparse_features,
@@ -67,30 +117,28 @@ class DIN(BaseModel):
  **kwargs,
  )

- self.loss = loss
- if self.loss is None:
- self.loss = "bce"
-
- # Features classification
- # DIN requires: candidate item + user behavior sequence + other features
+ # DIN requires: user behavior sequence + candidate item + other features
  if len(sequence_features) == 0:
  raise ValueError(
  "DIN requires at least one sequence feature for user behavior history"
  )
+ if behavior_feature_name is None:
+ raise ValueError("DIN requires an explicit behavior_feature_name")

- self.behavior_feature = sequence_features[0] # User behavior sequence
- self.candidate_feature = (
- sparse_features[-1] if sparse_features else None
- ) # Candidate item
+ if candidate_feature_name is None:
+ raise ValueError("DIN requires an explicit candidate_feature_name")

- # Other features (excluding behavior sequence in final concatenation)
- self.other_sparse_features = (
- sparse_features[:-1] if self.candidate_feature else sparse_features
- )
- self.dense_features_list = dense_features
+ self.behavior_feature = [
+ f for f in sequence_features if f.name == behavior_feature_name
+ ][0]
+ self.candidate_feature = [
+ f for f in sparse_features if f.name == candidate_feature_name
+ ][0]

- # All features for embedding
- self.all_features = dense_features + sparse_features + sequence_features
+ # Other sparse features
+ self.other_sparse_features = [
+ f for f in sparse_features if f.name != self.candidate_feature.name
+ ]

  # Embedding layer
  self.embedding = EmbeddingLayer(features=self.all_features)
@@ -142,14 +190,13 @@ class DIN(BaseModel):

  def forward(self, x):
  # Get candidate item embedding
- if self.candidate_feature:
- candidate_emb = self.embedding.embed_dict[
- self.candidate_feature.embedding_name
- ](
- x[self.candidate_feature.name].long()
- ) # [B, emb_dim]
- else:
- candidate_emb = None
+ if self.candidate_feature is None:
+ raise ValueError("DIN requires a candidate item feature")
+ candidate_emb = self.embedding.embed_dict[
+ self.candidate_feature.embedding_name
+ ](
+ x[self.candidate_feature.name].long()
+ ) # [B, emb_dim]

  # Get behavior sequence embedding
  behavior_seq = x[self.behavior_feature.name].long() # [B, seq_len]
@@ -168,24 +215,17 @@ class DIN(BaseModel):
  mask = (behavior_seq != 0).unsqueeze(-1).float()

  # Apply attention pooling
- if candidate_emb is not None:
- candidate_query = candidate_emb
- if self.candidate_attention_proj is not None:
- candidate_query = self.candidate_attention_proj(candidate_query)
- pooled_behavior = self.attention(
- query=candidate_query, keys=behavior_emb, mask=mask
- ) # [B, emb_dim]
- else:
- # If no candidate, use mean pooling
- pooled_behavior = torch.sum(behavior_emb * mask, dim=1) / (
- mask.sum(dim=1) + 1e-9
- )
+ candidate_query = candidate_emb
+ if self.candidate_attention_proj is not None:
+ candidate_query = self.candidate_attention_proj(candidate_query)
+ pooled_behavior = self.attention(
+ query=candidate_query, keys=behavior_emb, mask=mask
+ ) # [B, emb_dim]

  # Get other features
  other_embeddings = []

- if candidate_emb is not None:
- other_embeddings.append(candidate_emb)
+ other_embeddings.append(candidate_emb)

  other_embeddings.append(pooled_behavior)

@@ -197,11 +237,9 @@ class DIN(BaseModel):
  other_embeddings.append(feat_emb)

  # Dense features
- for feat in self.dense_features_list:
- val = x[feat.name].float()
- if val.dim() == 1:
- val = val.unsqueeze(1)
- other_embeddings.append(val)
+ for feat in self.dense_features:
+ dense_val = self.embedding.project_dense(feat, x)
+ other_embeddings.append(dense_val)

  # Concatenate all features
  concat_input = torch.cat(other_embeddings, dim=-1) # [B, total_dim]
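
Note: the constructor now selects the behavior and candidate features by name instead of by position (previously sequence_features[0] and sparse_features[-1]). The snippet below is an editorial illustration of that by-name lookup using a plain dataclass rather than nextrec's feature classes; the field names are made up.

    from dataclasses import dataclass

    @dataclass
    class Feat:
        name: str

    sparse_features = [Feat("user_id"), Feat("item_id"), Feat("category_id")]
    candidate_feature_name = "item_id"

    candidate = [f for f in sparse_features if f.name == candidate_feature_name][0]
    others = [f for f in sparse_features if f.name != candidate.name]
    print(candidate.name, [f.name for f in others])  # item_id ['user_id', 'category_id']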
@@ -1,10 +1,43 @@
  """
  Date: create on 09/11/2025
- Author:
- Yang Zhou,zyaztec@gmail.com
+ Checkpoint: edit on 09/12/2025
+ Author: Yang Zhou, zyaztec@gmail.com
  Reference:
- [1] Huang T, Zhang Z, Zhang B, et al. FiBiNET: Combining feature importance and bilinear feature interaction
- for click-through rate prediction[C]//RecSys. 2019: 169-177.
+ [1] Huang T, Zhang Z, Zhang B, et al. FiBiNET: Combining feature importance and bilinear
+ feature interaction for click-through rate prediction[C]//RecSys. 2019: 169-177.
+ (https://arxiv.org/abs/1905.09433)
+
+ FiBiNET (Feature Importance and Bilinear Interaction Network) is a CTR model that
+ jointly learns which fields matter most and how they interact. It first uses SENET
+ to produce field-wise importance weights and recalibrate embeddings, then applies a
+ bilinear interaction layer to capture pairwise feature relationships with enhanced
+ expressiveness.
+
+ Pipeline:
+ (1) Embed sparse and sequence features that share a common embedding dimension
+ (2) SENET squeezes and excites across fields to generate importance scores
+ (3) Reweight embeddings with SENET scores to highlight informative fields
+ (4) Compute bilinear interactions on both the original and SENET-reweighted
+ embeddings to model pairwise relations
+ (5) Concatenate interaction outputs and feed them into an MLP alongside a linear
+ term for final prediction
+
+ Key Advantages:
+ - SENET recalibration emphasizes the most informative feature fields
+ - Bilinear interactions explicitly model pairwise relationships beyond simple dot
+ products
+ - Dual-path (standard + SENET-reweighted) interactions enrich representation power
+ - Combines linear and deep components for both memorization and generalization
+
+ FiBiNET 是一个 CTR 预估模型,通过 SENET 重新分配特征字段的重要性,再用双线性
+ 交互层捕捉成对特征关系。模型先对稀疏/序列特征做 embedding,SENET 生成字段权重并
+ 重标定 embedding,随后在原始和重标定的 embedding 上分别计算双线性交互,最后将
+ 交互结果与线性部分一起输入 MLP 得到预测。
+ 主要优点:
+ - SENET 让模型聚焦最重要的特征字段
+ - 双线性交互显式建模特征对关系,表达力强于简单点积
+ - 标准与重标定两路交互结合,丰富特征表示
+ - 线性与深层结构并行,兼顾记忆与泛化
  """

  import torch
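
Note: the package's SENETLayer is not shown in this diff. The sketch below illustrates the squeeze-and-excitation recalibration described in steps (2)-(3), with an assumed mean-pooling squeeze and a two-layer excitation MLP; it is an editorial stand-in, not the package's implementation.

    import torch
    import torch.nn as nn

    class SimpleSENET(nn.Module):
        # Field-wise squeeze-and-excitation: [B, F, D] -> [B, F, D]
        def __init__(self, num_fields, reduction_ratio=3):
            super().__init__()
            hidden = max(1, num_fields // reduction_ratio)
            self.excite = nn.Sequential(
                nn.Linear(num_fields, hidden), nn.ReLU(),
                nn.Linear(hidden, num_fields), nn.ReLU(),
            )

        def forward(self, field_emb):
            z = field_emb.mean(dim=-1)          # squeeze each field to a scalar: [B, F]
            a = self.excite(z)                  # per-field importance weights: [B, F]
            return field_emb * a.unsqueeze(-1)  # recalibrated embeddings

    print(SimpleSENET(num_fields=4)(torch.randn(2, 4, 8)).shape)  # torch.Size([2, 4, 8])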
@@ -13,6 +46,7 @@ import torch.nn as nn
  from nextrec.basic.model import BaseModel
  from nextrec.basic.layers import (
  BiLinearInteractionLayer,
+ HadamardInteractionLayer,
  EmbeddingLayer,
  LR,
  MLP,
@@ -33,16 +67,17 @@ class FiBiNET(BaseModel):

  def __init__(
  self,
- dense_features: list[DenseFeature] | list = [],
- sparse_features: list[SparseFeature] | list = [],
- sequence_features: list[SequenceFeature] | list = [],
- mlp_params: dict = {},
+ dense_features: list[DenseFeature] | None = None,
+ sparse_features: list[SparseFeature] | None = None,
+ sequence_features: list[SequenceFeature] | None = None,
+ mlp_params: dict | None = None,
+ interaction_combo: str = "11", # "0": Hadamard, "1": Bilinear
  bilinear_type: str = "field_interaction",
  senet_reduction: int = 3,
- target: list[str] | list = [],
+ target: list[str] | str | None = None,
  task: str | list[str] | None = None,
  optimizer: str = "adam",
- optimizer_params: dict = {},
+ optimizer_params: dict | None = None,
  loss: str | nn.Module | None = "bce",
  loss_params: dict | list[dict] | None = None,
  device: str = "cpu",
@@ -53,6 +88,14 @@ class FiBiNET(BaseModel):
  **kwargs,
  ):

+ dense_features = dense_features or []
+ sparse_features = sparse_features or []
+ sequence_features = sequence_features or []
+ mlp_params = mlp_params or {}
+ optimizer_params = optimizer_params or {}
+ if loss is None:
+ loss = "bce"
+
  super(FiBiNET, self).__init__(
  dense_features=dense_features,
  sparse_features=sparse_features,
@@ -68,11 +111,7 @@ class FiBiNET(BaseModel):
  )

  self.loss = loss
- if self.loss is None:
- self.loss = "bce"
-
  self.linear_features = sparse_features + sequence_features
- self.deep_features = dense_features + sparse_features + sequence_features
  self.interaction_features = sparse_features + sequence_features

  if len(self.interaction_features) < 2:
@@ -80,7 +119,7 @@ class FiBiNET(BaseModel):
  "FiBiNET requires at least two sparse/sequence features for interactions."
  )

- self.embedding = EmbeddingLayer(features=self.deep_features)
+ self.embedding = EmbeddingLayer(features=self.all_features)

  self.num_fields = len(self.interaction_features)
  self.embedding_dim = self.interaction_features[0].embedding_dim
@@ -94,16 +133,34 @@ class FiBiNET(BaseModel):
  self.senet = SENETLayer(
  num_fields=self.num_fields, reduction_ratio=senet_reduction
  )
- self.bilinear_standard = BiLinearInteractionLayer(
- input_dim=self.embedding_dim,
- num_fields=self.num_fields,
- bilinear_type=bilinear_type,
- )
- self.bilinear_senet = BiLinearInteractionLayer(
- input_dim=self.embedding_dim,
- num_fields=self.num_fields,
- bilinear_type=bilinear_type,
- )
+
+ self.interaction_combo = interaction_combo
+
+ # E interaction layers: original embeddings
+ if interaction_combo[0] == "0": # Hadamard
+ self.interaction_E = HadamardInteractionLayer(
+ num_fields=self.num_fields
+ ) # [B, num_pairs, D]
+ elif interaction_combo[0] == "1": # Bilinear
+ self.interaction_E = BiLinearInteractionLayer(
+ input_dim=self.embedding_dim,
+ num_fields=self.num_fields,
+ bilinear_type=bilinear_type,
+ ) # [B, num_pairs, D]
+ else:
+ raise ValueError("interaction_combo must be '01' or '11'")
+
+ # V interaction layers: SENET reweighted embeddings
+ if interaction_combo[1] == "0":
+ self.interaction_V = HadamardInteractionLayer(num_fields=self.num_fields)
+ elif interaction_combo[1] == "1":
+ self.interaction_V = BiLinearInteractionLayer(
+ input_dim=self.embedding_dim,
+ num_fields=self.num_fields,
+ bilinear_type=bilinear_type,
+ )
+ else:
+ raise ValueError("Deep-FiBiNET SENET side must be '01' or '11'")

  linear_dim = sum([f.embedding_dim for f in self.linear_features])
  self.linear = LR(linear_dim)
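
Note: HadamardInteractionLayer itself is not shown in this diff. As a rough illustration of the "0" (Hadamard) option, the sketch below computes an element-wise product for every field pair, producing the [B, num_pairs, D] shape noted in the comments above; the real layer may differ in details.

    import itertools
    import torch

    def hadamard_pairwise(field_emb):
        # [B, F, D] -> [B, F*(F-1)/2, D], one element-wise product per field pair
        num_fields = field_emb.size(1)
        pairs = [
            field_emb[:, i] * field_emb[:, j]
            for i, j in itertools.combinations(range(num_fields), 2)
        ]
        return torch.stack(pairs, dim=1)

    print(hadamard_pairwise(torch.randn(2, 4, 8)).shape)  # torch.Size([2, 6, 8])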
@@ -119,9 +176,9 @@ class FiBiNET(BaseModel):
  include_modules=[
  "linear",
  "senet",
- "bilinear_standard",
- "bilinear_senet",
  "mlp",
+ "interaction_E",
+ "interaction_V",
  ],
  )

@@ -143,9 +200,14 @@ class FiBiNET(BaseModel):
  )
  senet_emb = self.senet(field_emb)

- bilinear_standard = self.bilinear_standard(field_emb).flatten(start_dim=1)
- bilinear_senet = self.bilinear_senet(senet_emb).flatten(start_dim=1)
- deep_input = torch.cat([bilinear_standard, bilinear_senet], dim=1)
+ out_E = self.interaction_E(field_emb) # [B, num_pairs, D]
+
+ out_V = self.interaction_V(senet_emb) # [B, num_pairs, D]
+
+ deep_input = torch.cat(
+ [out_E.flatten(start_dim=1), out_V.flatten(start_dim=1)], dim=1
+ )
+
  y_deep = self.mlp(deep_input)

  y = y_linear + y_deep
@@ -1,9 +1,41 @@
  """
  Date: create on 09/11/2025
- Author:
- Yang Zhou,zyaztec@gmail.com
+ Checkpoint: edit on 09/12/2025
+ Author: Yang Zhou, zyaztec@gmail.com
  Reference:
- [1] Rendle S. Factorization machines[C]//ICDM. 2010: 995-1000.
+ [1] Rendle S. Factorization machines[C]//ICDM. 2010: 995-1000.
+
+ Factorization Machines (FM) capture second-order feature interactions with
+ linear complexity by factorizing the pairwise interaction matrix. Each field
+ is embedded into a latent vector; FM models the dot product of every pair of
+ embeddings and sums them along with a linear term, enabling strong performance
+ with sparse high-dimensional data and minimal feature engineering.
+
+ Pipeline:
+ (1) Embed sparse and sequence fields into low-dimensional vectors
+ (2) Compute linear logit over concatenated embeddings
+ (3) Compute pairwise interaction logit via factorized dot products
+ (4) Sum linear + interaction terms and apply prediction layer
+
+ Key Advantages:
+ - Models pairwise interactions efficiently (O(nk) vs. O(n^2))
+ - Works well on sparse inputs without handcrafted crosses
+ - Simple architecture with strong baseline performance
+
+ FM 是一种通过分解二阶特征交互矩阵、以线性复杂度建模特征对的 CTR 模型。
+ 每个特征映射为低维向量,FM 对任意特征对进行内积求和并叠加线性项,
+ 无需复杂特征工程即可在稀疏高维场景取得稳健效果。
+
+ 处理流程:
+ (1) 对稀疏/序列特征做 embedding
+ (2) 计算线性部分的 logit
+ (3) 计算嵌入对之间的二阶交互 logit
+ (4) 线性项与交互项求和,再通过预测层输出
+
+ 主要优点:
+ - 线性复杂度建模二阶交互,效率高
+ - 对稀疏特征友好,减少人工特征交叉
+ - 结构简单、表现强健,常作 CTR 基线
  """

  import torch.nn as nn
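
Note: step (3) of the docstring relies on the standard FM identity sum_{i<j} <v_i, v_j> = 0.5 * ((sum_i v_i)^2 - sum_i v_i^2), applied per embedding dimension, which is what keeps the interaction term linear in the number of fields. Below is a minimal editorial sketch over already-embedded fields; shapes are assumed and this is not the package's FM layer.

    import torch

    def fm_second_order(field_emb):
        # field_emb: [B, F, D]; returns the second-order interaction logit [B, 1]
        sum_then_square = field_emb.sum(dim=1).pow(2)   # (sum_f v_f)^2 -> [B, D]
        square_then_sum = field_emb.pow(2).sum(dim=1)   # sum_f v_f^2   -> [B, D]
        return 0.5 * (sum_then_square - square_then_sum).sum(dim=1, keepdim=True)

    print(fm_second_order(torch.randn(2, 5, 8)).shape)  # torch.Size([2, 1])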
@@ -29,13 +61,13 @@ class FM(BaseModel):

  def __init__(
  self,
- dense_features: list[DenseFeature] | list = [],
- sparse_features: list[SparseFeature] | list = [],
- sequence_features: list[SequenceFeature] | list = [],
- target: list[str] | list = [],
+ dense_features: list[DenseFeature] | None = None,
+ sparse_features: list[SparseFeature] | None = None,
+ sequence_features: list[SequenceFeature] | None = None,
+ target: list[str] | str | None = None,
  task: str | list[str] | None = None,
  optimizer: str = "adam",
- optimizer_params: dict = {},
+ optimizer_params: dict | None = None,
  loss: str | nn.Module | None = "bce",
  loss_params: dict | list[dict] | None = None,
  device: str = "cpu",
@@ -46,6 +78,10 @@ class FM(BaseModel):
  **kwargs,
  ):

+ dense_features = dense_features or []
+ sparse_features = sparse_features or []
+ sequence_features = sequence_features or []
+
  super(FM, self).__init__(
  dense_features=dense_features,
  sparse_features=sparse_features,
@@ -166,7 +166,7 @@ class MaskNet(BaseModel):
  dense_features: list[DenseFeature] | None = None,
  sparse_features: list[SparseFeature] | None = None,
  sequence_features: list[SequenceFeature] | None = None,
- model_type: str = "parallel", # "serial" or "parallel"
+ architecture: str = "parallel", # "serial" or "parallel"
  num_blocks: int = 3,
  mask_hidden_dim: int = 64,
  block_hidden_dim: int = 256,
@@ -232,11 +232,11 @@ class MaskNet(BaseModel):
  )

  self.v_emb_dim = self.num_fields * self.embedding_dim
- self.model_type = model_type.lower()
- assert self.model_type in (
+ self.architecture = architecture.lower()
+ assert self.architecture in (
  "serial",
  "parallel",
- ), "model_type must be either 'serial' or 'parallel'."
+ ), "architecture must be either 'serial' or 'parallel'."

  self.num_blocks = max(1, num_blocks)
  self.block_hidden_dim = block_hidden_dim
@@ -244,7 +244,7 @@ class MaskNet(BaseModel):
  nn.Dropout(block_dropout) if block_dropout > 0 else nn.Identity()
  )

- if self.model_type == "serial":
+ if self.architecture == "serial":
  self.first_block = MaskBlockOnEmbedding(
  num_fields=self.num_fields,
  embedding_dim=self.embedding_dim,
@@ -284,7 +284,7 @@ class MaskNet(BaseModel):
  self.output_layer = None
  self.prediction_layer = PredictionLayer(task_type=self.task)

- if self.model_type == "serial":
+ if self.architecture == "serial":
  self.register_regularization_weights(
  embedding_attr="embedding",
  include_modules=["mask_blocks", "output_layer"],
@@ -306,7 +306,7 @@ class MaskNet(BaseModel):
  B = field_emb.size(0)
  v_emb_flat = field_emb.view(B, -1) # flattened embeddings

- if self.model_type == "parallel":
+ if self.architecture == "parallel":
  block_outputs = []
  for block in self.mask_blocks:
  h = block(field_emb, v_emb_flat) # [B, block_hidden_dim]