torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (64)
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +54 -54
  3. torch_rechub/basic/callback.py +33 -33
  4. torch_rechub/basic/features.py +87 -94
  5. torch_rechub/basic/initializers.py +92 -92
  6. torch_rechub/basic/layers.py +994 -720
  7. torch_rechub/basic/loss_func.py +223 -34
  8. torch_rechub/basic/metaoptimizer.py +76 -72
  9. torch_rechub/basic/metric.py +251 -250
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -11
  14. torch_rechub/models/matching/comirec.py +193 -188
  15. torch_rechub/models/matching/dssm.py +72 -66
  16. torch_rechub/models/matching/dssm_facebook.py +77 -79
  17. torch_rechub/models/matching/dssm_senet.py +28 -16
  18. torch_rechub/models/matching/gru4rec.py +85 -87
  19. torch_rechub/models/matching/mind.py +103 -101
  20. torch_rechub/models/matching/narm.py +82 -76
  21. torch_rechub/models/matching/sasrec.py +143 -140
  22. torch_rechub/models/matching/sine.py +148 -151
  23. torch_rechub/models/matching/stamp.py +81 -83
  24. torch_rechub/models/matching/youtube_dnn.py +75 -71
  25. torch_rechub/models/matching/youtube_sbc.py +98 -98
  26. torch_rechub/models/multi_task/__init__.py +7 -5
  27. torch_rechub/models/multi_task/aitm.py +83 -84
  28. torch_rechub/models/multi_task/esmm.py +56 -55
  29. torch_rechub/models/multi_task/mmoe.py +58 -58
  30. torch_rechub/models/multi_task/ple.py +116 -130
  31. torch_rechub/models/multi_task/shared_bottom.py +45 -45
  32. torch_rechub/models/ranking/__init__.py +14 -11
  33. torch_rechub/models/ranking/afm.py +65 -63
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -63
  36. torch_rechub/models/ranking/dcn.py +38 -38
  37. torch_rechub/models/ranking/dcn_v2.py +59 -69
  38. torch_rechub/models/ranking/deepffm.py +131 -123
  39. torch_rechub/models/ranking/deepfm.py +43 -42
  40. torch_rechub/models/ranking/dien.py +191 -191
  41. torch_rechub/models/ranking/din.py +93 -91
  42. torch_rechub/models/ranking/edcn.py +101 -117
  43. torch_rechub/models/ranking/fibinet.py +42 -50
  44. torch_rechub/models/ranking/widedeep.py +41 -41
  45. torch_rechub/trainers/__init__.py +4 -3
  46. torch_rechub/trainers/ctr_trainer.py +288 -128
  47. torch_rechub/trainers/match_trainer.py +336 -170
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +356 -207
  50. torch_rechub/trainers/seq_trainer.py +427 -0
  51. torch_rechub/utils/data.py +492 -360
  52. torch_rechub/utils/hstu_utils.py +198 -0
  53. torch_rechub/utils/match.py +457 -274
  54. torch_rechub/utils/model_utils.py +233 -0
  55. torch_rechub/utils/mtl.py +136 -126
  56. torch_rechub/utils/onnx_export.py +220 -0
  57. torch_rechub/utils/visualization.py +271 -0
  58. torch_rechub-0.0.5.dist-info/METADATA +402 -0
  59. torch_rechub-0.0.5.dist-info/RECORD +64 -0
  60. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
  61. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
  62. torch_rechub-0.0.3.dist-info/METADATA +0 -177
  63. torch_rechub-0.0.3.dist-info/RECORD +0 -55
  64. torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
torch_rechub/models/generative/hllm.py
@@ -0,0 +1,249 @@
+ """HLLM: Hierarchical Large Language Model for Recommendation."""
+
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from torch_rechub.utils.hstu_utils import RelPosBias
+
+
+ class HLLMTransformerBlock(nn.Module):
+     """Single HLLM Transformer block with self-attention and FFN.
+
+     This block is similar to HSTULayer but designed for HLLM, which uses
+     pre-computed item embeddings as input instead of learnable token embeddings.
+
+     Args:
+         d_model (int): Hidden dimension.
+         n_heads (int): Number of attention heads.
+         dropout (float): Dropout rate.
+     """
+
+     def __init__(self, d_model=512, n_heads=8, dropout=0.1):
+         super().__init__()
+         self.d_model = d_model
+         self.n_heads = n_heads
+
+         assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
+         self.head_dim = d_model // n_heads
+         self.scale = self.head_dim**-0.5
+
+         # Multi-head self-attention
+         self.W_Q = nn.Linear(d_model, d_model)
+         self.W_K = nn.Linear(d_model, d_model)
+         self.W_V = nn.Linear(d_model, d_model)
+         self.W_O = nn.Linear(d_model, d_model)
+
+         # Feed-forward network
+         ffn_hidden = 4 * d_model
+         self.ffn = nn.Sequential(nn.Linear(d_model, ffn_hidden), nn.ReLU(), nn.Dropout(dropout), nn.Linear(ffn_hidden, d_model), nn.Dropout(dropout))
+
+         # Layer normalization and dropout
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, rel_pos_bias=None):
+         """Forward pass.
+
+         Args:
+             x (Tensor): Input of shape (B, L, D).
+             rel_pos_bias (Tensor, optional): Relative position bias.
+
+         Returns:
+             Tensor: Output of shape (B, L, D).
+         """
+         batch_size, seq_len, _ = x.shape
+
+         # Self-attention with residual
+         residual = x
+         x = self.norm1(x)
+
+         Q = self.W_Q(x)  # (B, L, D)
+         K = self.W_K(x)
+         V = self.W_V(x)
+
+         # Reshape for multi-head attention
+         Q = Q.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+         K = K.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+         V = V.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+
+         # Attention scores
+         scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
+
+         # Causal mask
+         causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool))
+         scores = scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
+
+         # Add relative position bias if provided
+         if rel_pos_bias is not None:
+             scores = scores + rel_pos_bias
+
+         attn_weights = F.softmax(scores, dim=-1)
+         attn_weights = self.dropout(attn_weights)
+
+         attn_output = torch.matmul(attn_weights, V)
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.view(batch_size, seq_len, self.d_model)
+         attn_output = self.W_O(attn_output)
+         attn_output = self.dropout(attn_output)
+
+         x = residual + attn_output
+
+         # FFN with residual
+         residual = x
+         x = self.norm2(x)
+         x = self.ffn(x)
+         x = residual + x
+
+         return x
+
+
+ class HLLMModel(nn.Module):
+     """HLLM: Hierarchical Large Language Model for Recommendation.
+
+     This is a lightweight implementation of HLLM that uses pre-computed item
+     embeddings as input. The original ByteDance HLLM uses end-to-end training
+     with both Item LLM and User LLM, but this implementation focuses on the
+     User LLM component for resource efficiency.
+
+     Architecture:
+         - Item Embeddings: pre-computed with an LLM (offline, frozen).
+           Format: "{item_prompt}title: {title}description: {description}"
+           where item_prompt = "Compress the following sentence into embedding: "
+         - User LLM: Transformer blocks that model user sequences (trainable).
+         - Scoring Head: dot product between user representation and item embeddings.
+
+     Reference:
+         ByteDance HLLM: https://github.com/bytedance/HLLM
+
+     Args:
+         item_embeddings (Tensor or str): Pre-computed item embeddings of shape
+             (vocab_size, d_model), or path to a .pt file containing embeddings.
+             Generated using the last token's hidden state from an LLM.
+         vocab_size (int): Vocabulary size (number of items).
+         d_model (int): Hidden dimension. Should match item embedding dimension.
+             Default: 512. TinyLlama uses 2048, Baichuan2 uses 4096.
+         n_heads (int): Number of attention heads. Default: 8.
+         n_layers (int): Number of transformer blocks. Default: 4.
+         max_seq_len (int): Maximum sequence length. Default: 256.
+             Official uses MAX_ITEM_LIST_LENGTH=50.
+         dropout (float): Dropout rate. Default: 0.1.
+         use_rel_pos_bias (bool): Whether to use relative position bias. Default: True.
+         use_time_embedding (bool): Whether to use time embeddings. Default: True.
+         num_time_buckets (int): Number of time buckets. Default: 2048.
+         time_bucket_fn (str): Time bucketization function ('sqrt' or 'log'). Default: 'sqrt'.
+         temperature (float): Temperature for NCE scoring. Default: 1.0.
+             Official uses logit_scale = log(1/0.07) ≈ 2.66.
+     """
+
+     def __init__(self, item_embeddings, vocab_size, d_model=512, n_heads=8, n_layers=4, max_seq_len=256, dropout=0.1, use_rel_pos_bias=True, use_time_embedding=True, num_time_buckets=2048, time_bucket_fn='sqrt', temperature=1.0):
+         super().__init__()
+         self.vocab_size = vocab_size
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.max_seq_len = max_seq_len
+         self.use_time_embedding = use_time_embedding
+         self.num_time_buckets = num_time_buckets
+         self.time_bucket_fn = time_bucket_fn
+         self.temperature = temperature
+
+         # Load item embeddings
+         if isinstance(item_embeddings, str):
+             item_embeddings = torch.load(item_embeddings)
+
+         # Register as buffer (not trainable)
+         self.register_buffer('item_embeddings', item_embeddings.float())
+
+         # Positional embedding
+         self.position_embedding = nn.Embedding(max_seq_len, d_model)
+
+         # Time embedding
+         if use_time_embedding:
+             self.time_embedding = nn.Embedding(num_time_buckets + 1, d_model, padding_idx=0)
+
+         # Transformer blocks
+         self.transformer_blocks = nn.ModuleList([HLLMTransformerBlock(d_model=d_model, n_heads=n_heads, dropout=dropout) for _ in range(n_layers)])
+
+         # Relative position bias
+         self.use_rel_pos_bias = use_rel_pos_bias
+         if use_rel_pos_bias:
+             self.rel_pos_bias = RelPosBias(n_heads, max_seq_len)
+
+         # Dropout
+         self.dropout = nn.Dropout(dropout)
+
+         self._init_weights()
+
+     def _init_weights(self):
+         """Initialize model parameters."""
+         for name, param in self.named_parameters():
+             if 'weight' in name and len(param.shape) > 1:
+                 nn.init.xavier_uniform_(param)
+             elif 'bias' in name:
+                 nn.init.constant_(param, 0)
+
+     def _time_diff_to_bucket(self, time_diffs):
+         """Map time differences to bucket indices."""
+         time_diffs = time_diffs.float() / 60.0  # seconds to minutes
+         time_diffs = torch.clamp(time_diffs, min=1e-6)
+
+         if self.time_bucket_fn == 'sqrt':
+             buckets = torch.sqrt(time_diffs).long()
+         elif self.time_bucket_fn == 'log':
+             buckets = torch.log(time_diffs).long()
+         else:
+             raise ValueError(f"Unsupported time_bucket_fn: {self.time_bucket_fn}")
+
+         buckets = torch.clamp(buckets, min=0, max=self.num_time_buckets - 1)
+         return buckets
+
+     def forward(self, seq_tokens, time_diffs=None):
+         """Forward pass.
+
+         Args:
+             seq_tokens (Tensor): Item token IDs of shape (B, L).
+             time_diffs (Tensor, optional): Time differences in seconds of shape (B, L).
+
+         Returns:
+             Tensor: Logits of shape (B, L, vocab_size).
+         """
+         batch_size, seq_len = seq_tokens.shape
+
+         # Look up item embeddings
+         item_emb = self.item_embeddings[seq_tokens]  # (B, L, D)
+
+         # Add positional embedding
+         positions = torch.arange(seq_len, dtype=torch.long, device=seq_tokens.device)
+         pos_emb = self.position_embedding(positions)  # (L, D)
+         embeddings = item_emb + pos_emb.unsqueeze(0)  # (B, L, D)
+
+         # Add time embedding if provided
+         if self.use_time_embedding:
+             if time_diffs is None:
+                 time_diffs = torch.zeros(batch_size, seq_len, dtype=torch.long, device=seq_tokens.device)
+
+             time_buckets = self._time_diff_to_bucket(time_diffs)
+             time_emb = self.time_embedding(time_buckets)  # (B, L, D)
+             embeddings = embeddings + time_emb
+
+         embeddings = self.dropout(embeddings)
+
+         # Get relative position bias
+         rel_pos_bias = None
+         if self.use_rel_pos_bias:
+             rel_pos_bias = self.rel_pos_bias(seq_len)
+
+         # Pass through transformer blocks
+         x = embeddings
+         for block in self.transformer_blocks:
+             x = block(x, rel_pos_bias=rel_pos_bias)
+
+         # Scoring head: compute dot product with item embeddings
+         # x: (B, L, D), item_embeddings: (V, D)
+         logits = torch.matmul(x, self.item_embeddings.t()) / self.temperature  # (B, L, V)
+
+         return logits
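
For orientation, here is a minimal, hypothetical usage sketch of the new HLLMModel. The item embeddings below are random stand-ins; in practice they would be pre-computed with an LLM as described in the docstring, and the shapes and hyperparameters are illustrative only.

import torch
from torch_rechub.models.generative.hllm import HLLMModel

# Hypothetical setup: 1000 items with 512-dim pre-computed embeddings (random stand-ins here).
item_embeddings = torch.randn(1000, 512)
model = HLLMModel(item_embeddings=item_embeddings, vocab_size=1000, d_model=512, n_heads=8, n_layers=2, max_seq_len=64)

seq_tokens = torch.randint(0, 1000, (8, 64))   # (B, L) item ids
time_diffs = torch.randint(0, 86400, (8, 64))  # (B, L) seconds between interactions
logits = model(seq_tokens, time_diffs)         # (B, L, vocab_size)
next_item_scores = logits[:, -1, :]            # scores for the item at the next step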
torch_rechub/models/generative/hstu.py
@@ -0,0 +1,189 @@
+ """HSTU: Hierarchical Sequential Transduction Units Model."""
+
+ import math
+
+ import torch
+ import torch.nn as nn
+
+ from torch_rechub.basic.layers import HSTUBlock
+ from torch_rechub.utils.hstu_utils import RelPosBias
+
+
+ class HSTUModel(nn.Module):
+     """HSTU: Hierarchical Sequential Transduction Units model.
+
+     Autoregressive generative recommendation model for sequential data.
+     This module stacks multiple ``HSTUBlock`` layers to capture long-range
+     dependencies in user interaction sequences and predicts the next item.
+
+     Args:
+         vocab_size (int): Vocabulary size (number of distinct items, including PAD).
+         d_model (int): Hidden dimension of the model. Default: 512.
+         n_heads (int): Number of attention heads. Default: 8.
+         n_layers (int): Number of stacked HSTU layers. Default: 4.
+         dqk (int): Dimension of query/key vectors per head. Default: 64.
+         dv (int): Dimension of value vectors per head. Default: 64.
+         max_seq_len (int): Maximum sequence length. Default: 256.
+         dropout (float): Dropout rate applied in the model. Default: 0.1.
+         use_rel_pos_bias (bool): Whether to use relative position bias. Default: True.
+         use_time_embedding (bool): Whether to use time-difference embeddings. Default: True.
+         num_time_buckets (int): Number of time buckets for time embeddings. Default: 2048.
+         time_bucket_fn (str): Function used to bucketize time differences, ``"sqrt"``
+             or ``"log"``. Default: ``"sqrt"``.
+
+     Shape:
+         - Input: ``x`` of shape ``(batch_size, seq_len)``; optional ``time_diffs``
+           of shape ``(batch_size, seq_len)`` representing time differences in seconds.
+         - Output: Logits of shape ``(batch_size, seq_len, vocab_size)``.
+
+     Example:
+         >>> model = HSTUModel(vocab_size=100000, d_model=512)
+         >>> x = torch.randint(0, 100000, (32, 256))
+         >>> time_diffs = torch.randint(0, 86400, (32, 256))
+         >>> logits = model(x, time_diffs)
+         >>> logits.shape
+         torch.Size([32, 256, 100000])
+     """
+
+     def __init__(self, vocab_size, d_model=512, n_heads=8, n_layers=4, dqk=64, dv=64, max_seq_len=256, dropout=0.1, use_rel_pos_bias=True, use_time_embedding=True, num_time_buckets=2048, time_bucket_fn='sqrt'):
+         super().__init__()
+         self.vocab_size = vocab_size
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.max_seq_len = max_seq_len
+         self.use_time_embedding = use_time_embedding
+         self.num_time_buckets = num_time_buckets
+         self.time_bucket_fn = time_bucket_fn
+
+         # Alpha scaling factor (following the Meta reference implementation)
+         self.alpha = math.sqrt(d_model)
+
+         # Token embedding
+         self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
+
+         # Absolute positional embedding
+         self.position_embedding = nn.Embedding(max_seq_len, d_model)
+
+         # Time embedding
+         if use_time_embedding:
+             # Embedding table for time-difference buckets
+             # num_time_buckets + 1: extra bucket reserved for padding
+             self.time_embedding = nn.Embedding(num_time_buckets + 1, d_model, padding_idx=0)
+
+         # HSTU block
+         self.hstu_block = HSTUBlock(d_model=d_model, n_heads=n_heads, n_layers=n_layers, dqk=dqk, dv=dv, dropout=dropout, use_rel_pos_bias=use_rel_pos_bias)
+
+         # Relative position bias
+         self.use_rel_pos_bias = use_rel_pos_bias
+         if use_rel_pos_bias:
+             self.rel_pos_bias = RelPosBias(n_heads, max_seq_len)
+
+         # Output projection layer
+         self.output_projection = nn.Linear(d_model, vocab_size)
+
+         # Dropout
+         self.dropout = nn.Dropout(dropout)
+
+         # Initialize model parameters
+         self._init_weights()
+
+     def _init_weights(self):
+         """Initialize model parameters."""
+         for name, param in self.named_parameters():
+             if 'weight' in name and len(param.shape) > 1:
+                 nn.init.xavier_uniform_(param)
+             elif 'bias' in name:
+                 nn.init.constant_(param, 0)
+
+     def _time_diff_to_bucket(self, time_diffs):
+         """Map raw time differences (in seconds) to discrete bucket indices.
+
+         Following the Meta HSTU implementation, continuous time differences are
+         first converted to minutes and then bucketized using either a square-root
+         or logarithmic transform.
+
+         Args:
+             time_diffs (Tensor): Time differences in seconds,
+                 shape ``(batch_size, seq_len)``.
+
+         Returns:
+             Tensor: Integer bucket indices of shape ``(batch_size, seq_len)``.
+         """
+         # Convert seconds to minutes (as in the Meta reference implementation)
+         time_bucket_increments = 60.0
+         time_diffs = time_diffs.float() / time_bucket_increments
+
+         # Ensure non-negative values and avoid log(0)
+         time_diffs = torch.clamp(time_diffs, min=1e-6)
+
+         if self.time_bucket_fn == 'sqrt':
+             # Use the square-root transform: suitable when time differences
+             # are relatively evenly distributed.
+             buckets = torch.sqrt(time_diffs).long()
+         elif self.time_bucket_fn == 'log':
+             # Use the logarithmic transform: suitable when time differences
+             # span several orders of magnitude.
+             buckets = torch.log(time_diffs).long()
+         else:
+             raise ValueError(f"Unsupported time_bucket_fn: {self.time_bucket_fn}")
+
+         # Clamp bucket indices to the valid range [0, num_time_buckets - 1]
+         buckets = torch.clamp(buckets, min=0, max=self.num_time_buckets - 1)
+
+         return buckets
+
+     def forward(self, x, time_diffs=None):
+         """Forward pass.
+
+         Args:
+             x (Tensor): Input token ids of shape ``(batch_size, seq_len)``.
+             time_diffs (Tensor, optional): Time differences in seconds,
+                 shape ``(batch_size, seq_len)``. If ``None`` and
+                 ``use_time_embedding=True``, all-zero time differences are used.
+
+         Returns:
+             Tensor: Logits over the vocabulary of shape
+                 ``(batch_size, seq_len, vocab_size)``.
+         """
+         batch_size, seq_len = x.shape
+
+         # Token embedding with alpha scaling (as in the Meta implementation)
+         token_emb = self.token_embedding(x) * self.alpha  # (B, L, D)
+
+         # Absolute positional embedding
+         positions = torch.arange(seq_len, dtype=torch.long, device=x.device)
+         pos_emb = self.position_embedding(positions)  # (L, D)
+
+         # Combine token and position embeddings
+         embeddings = token_emb + pos_emb.unsqueeze(0)  # (B, L, D)
+
+         # Optional time-difference embedding
+         if self.use_time_embedding:
+             if time_diffs is None:
+                 # Fallback: use all-zero time differences when none are provided
+                 time_diffs = torch.zeros(batch_size, seq_len, dtype=torch.long, device=x.device)
+
+             # Map raw time differences to bucket indices
+             time_buckets = self._time_diff_to_bucket(time_diffs)  # (B, L)
+
+             # Look up time embeddings and add to the sequence representation
+             time_emb = self.time_embedding(time_buckets)  # (B, L, D)
+
+             # embeddings = token_emb + pos_emb + time_emb
+             embeddings = embeddings + time_emb
+
+         embeddings = self.dropout(embeddings)
+
+         # Relative position bias for self-attention
+         rel_pos_bias = None
+         if self.use_rel_pos_bias:
+             rel_pos_bias = self.rel_pos_bias(seq_len)  # (1, H, L, L)
+
+         # HSTU block
+         hstu_output = self.hstu_block(embeddings, rel_pos_bias=rel_pos_bias)  # (B, L, D)
+
+         # Final projection to vocabulary logits
+         logits = self.output_projection(hstu_output)  # (B, L, V)
+
+         return logits
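
HSTUModel is trained autoregressively; the actual training logic ships in the new torch_rechub/trainers/seq_trainer.py (not shown here). Purely as an illustrative sketch under assumed shapes and hyperparameters, a next-item-prediction step could look like this:

import torch
import torch.nn.functional as F
from torch_rechub.models.generative.hstu import HSTUModel

model = HSTUModel(vocab_size=10000, d_model=128, n_heads=8, n_layers=2, max_seq_len=32)
seq = torch.randint(1, 10000, (4, 32))  # (B, L) item ids; 0 is reserved for padding

# Positions 0..L-2 predict the items at positions 1..L-1.
logits = model(seq)  # (B, L, vocab_size)
loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, 10000), seq[:, 1:].reshape(-1))
loss.backward()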
torch_rechub/models/matching/__init__.py
@@ -1,11 +1,13 @@
- from .dssm import DSSM
- from .youtube_dnn import YoutubeDNN
- from .youtube_sbc import YoutubeSBC
- from .dssm_facebook import FaceBookDSSM
- from .gru4rec import GRU4Rec
- from .comirec import ComirecSA, ComirecDR
- from .mind import MIND
- from .narm import NARM
- from .stamp import STAMP
- from .sasrec import SASRec
- from .sine import SINE
+ __all__ = ['DSSM', 'FaceBookDSSM', 'YoutubeDNN', 'YoutubeSBC', 'MIND', 'GRU4Rec', 'NARM', 'SASRec', 'SINE', 'STAMP', 'ComirecDR', 'ComirecSA']
+
+ from .comirec import ComirecDR, ComirecSA
+ from .dssm import DSSM
+ from .dssm_facebook import FaceBookDSSM
+ from .gru4rec import GRU4Rec
+ from .mind import MIND
+ from .narm import NARM
+ from .sasrec import SASRec
+ from .sine import SINE
+ from .stamp import STAMP
+ from .youtube_dnn import YoutubeDNN
+ from .youtube_sbc import YoutubeSBC
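
This change only sorts the imports and adds an explicit __all__; the set of exported matching models is unchanged, so existing imports keep working, for example:

from torch_rechub.models.matching import DSSM, SASRec, YoutubeDNN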