torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +54 -54
- torch_rechub/basic/callback.py +33 -33
- torch_rechub/basic/features.py +87 -94
- torch_rechub/basic/initializers.py +92 -92
- torch_rechub/basic/layers.py +994 -720
- torch_rechub/basic/loss_func.py +223 -34
- torch_rechub/basic/metaoptimizer.py +76 -72
- torch_rechub/basic/metric.py +251 -250
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -11
- torch_rechub/models/matching/comirec.py +193 -188
- torch_rechub/models/matching/dssm.py +72 -66
- torch_rechub/models/matching/dssm_facebook.py +77 -79
- torch_rechub/models/matching/dssm_senet.py +28 -16
- torch_rechub/models/matching/gru4rec.py +85 -87
- torch_rechub/models/matching/mind.py +103 -101
- torch_rechub/models/matching/narm.py +82 -76
- torch_rechub/models/matching/sasrec.py +143 -140
- torch_rechub/models/matching/sine.py +148 -151
- torch_rechub/models/matching/stamp.py +81 -83
- torch_rechub/models/matching/youtube_dnn.py +75 -71
- torch_rechub/models/matching/youtube_sbc.py +98 -98
- torch_rechub/models/multi_task/__init__.py +7 -5
- torch_rechub/models/multi_task/aitm.py +83 -84
- torch_rechub/models/multi_task/esmm.py +56 -55
- torch_rechub/models/multi_task/mmoe.py +58 -58
- torch_rechub/models/multi_task/ple.py +116 -130
- torch_rechub/models/multi_task/shared_bottom.py +45 -45
- torch_rechub/models/ranking/__init__.py +14 -11
- torch_rechub/models/ranking/afm.py +65 -63
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -63
- torch_rechub/models/ranking/dcn.py +38 -38
- torch_rechub/models/ranking/dcn_v2.py +59 -69
- torch_rechub/models/ranking/deepffm.py +131 -123
- torch_rechub/models/ranking/deepfm.py +43 -42
- torch_rechub/models/ranking/dien.py +191 -191
- torch_rechub/models/ranking/din.py +93 -91
- torch_rechub/models/ranking/edcn.py +101 -117
- torch_rechub/models/ranking/fibinet.py +42 -50
- torch_rechub/models/ranking/widedeep.py +41 -41
- torch_rechub/trainers/__init__.py +4 -3
- torch_rechub/trainers/ctr_trainer.py +288 -128
- torch_rechub/trainers/match_trainer.py +336 -170
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +356 -207
- torch_rechub/trainers/seq_trainer.py +427 -0
- torch_rechub/utils/data.py +492 -360
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -274
- torch_rechub/utils/model_utils.py +233 -0
- torch_rechub/utils/mtl.py +136 -126
- torch_rechub/utils/onnx_export.py +220 -0
- torch_rechub/utils/visualization.py +271 -0
- torch_rechub-0.0.5.dist-info/METADATA +402 -0
- torch_rechub-0.0.5.dist-info/RECORD +64 -0
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
- torch_rechub-0.0.3.dist-info/METADATA +0 -177
- torch_rechub-0.0.3.dist-info/RECORD +0 -55
- torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
torch_rechub/utils/hstu_utils.py (new file)
@@ -0,0 +1,198 @@
+"""Utility classes and functions for the HSTU model."""
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+class RelPosBias(nn.Module):
+    """Relative position bias module.
+
+    This module is used in HSTU self-attention layers to provide a learnable
+    bias that depends on the relative distance between sequence positions. It
+    can be combined with time-based bucketing when needed.
+
+    Args:
+        n_heads (int): Number of attention heads.
+        max_seq_len (int): Maximum supported sequence length.
+        num_buckets (int): Number of relative position buckets. Default: 32.
+
+    Shape:
+        - Output: ``(1, n_heads, seq_len, seq_len)``
+
+    Example:
+        >>> rel_pos_bias = RelPosBias(n_heads=8, max_seq_len=256)
+        >>> bias = rel_pos_bias(256)
+        >>> bias.shape
+        torch.Size([1, 8, 256, 256])
+    """
+
+    def __init__(self, n_heads, max_seq_len, num_buckets=32):
+        super().__init__()
+        self.n_heads = n_heads
+        self.max_seq_len = max_seq_len
+        self.num_buckets = num_buckets
+
+        # Relative position bias table: (num_buckets, n_heads)
+        self.rel_pos_bias_table = nn.Parameter(torch.randn(num_buckets, n_heads))
+
+    def _relative_position_bucket(self, relative_position):
+        """Map relative positions to bucket indices.
+
+        Args:
+            relative_position (Tensor): Relative position tensor ``(L, L)``.
+
+        Returns:
+            Tensor: Integer bucket indices with the same ``(L, L)`` shape.
+        """
+        num_buckets = self.num_buckets
+        max_distance = self.max_seq_len
+
+        # Use absolute distance and linearly map it to bucket indices
+        relative_position = torch.abs(relative_position)
+
+        bucket = torch.clamp(
+            relative_position * (num_buckets - 1) // max_distance,
+            0,
+            num_buckets - 1,
+        )
+
+        return bucket.long()
+
+    def forward(self, seq_len):
+        """Compute relative position bias for a given sequence length.
+
+        Args:
+            seq_len (int): Sequence length ``L``.
+
+        Returns:
+            Tensor: Relative position bias of shape ``(1, n_heads, L, L)``.
+        """
+        # Create position indices
+        positions = torch.arange(seq_len, dtype=torch.long, device=self.rel_pos_bias_table.device)
+
+        # Compute relative positions: (seq_len, seq_len)
+        relative_positions = positions.unsqueeze(0) - positions.unsqueeze(1)
+
+        # Map relative positions to buckets
+        buckets = self._relative_position_bucket(relative_positions)
+
+        # Look up the bias table: (seq_len, seq_len, n_heads)
+        bias = self.rel_pos_bias_table[buckets]
+
+        # Permute to (1, n_heads, seq_len, seq_len)
+        bias = bias.permute(2, 0, 1).unsqueeze(0)
+
+        return bias
+
+
+class VocabMask(nn.Module):
+    """Vocabulary mask used to constrain generation during inference.
+
+    At inference time this module can be used to mask out invalid item IDs
+    so that the model never generates them.
+
+    Args:
+        vocab_size (int): Vocabulary size.
+        invalid_items (list, optional): List of invalid item IDs to be masked.
+
+    Methods:
+        apply_mask: Apply the mask to logits.
+
+    Example:
+        >>> mask = VocabMask(vocab_size=1000, invalid_items=[0, 1, 2])
+        >>> logits = torch.randn(32, 1000)
+        >>> masked_logits = mask.apply_mask(logits)
+    """
+
+    def __init__(self, vocab_size, invalid_items=None):
+        super().__init__()
+        self.vocab_size = vocab_size
+
+        # Create a boolean mask over the vocabulary
+        self.register_buffer(
+            'mask',
+            torch.ones(vocab_size,
+                       dtype=torch.bool),
+        )
+
+        # Mark invalid items
+        if invalid_items is not None:
+            for item_id in invalid_items:
+                if 0 <= item_id < vocab_size:
+                    self.mask[item_id] = False
+
+    def apply_mask(self, logits):
+        """Apply the mask to logits.
+
+        Args:
+            logits (Tensor): Model output logits, shape: ``(..., vocab_size)``.
+
+        Returns:
+            Tensor: Masked logits.
+        """
+        # Set logits of invalid items to a very small value
+        masked_logits = logits.clone()
+        masked_logits[..., ~self.mask] = -1e9
+
+        return masked_logits
+
+
+class VocabMapper(object):
+    """Simple mapper between ``item_id`` and ``token_id``.
+
+    In sequence generation tasks we often treat item IDs as tokens. This
+    helper keeps a trivial identity mapping but makes the intent explicit and
+    allows future extensions (e.g., reserved IDs, remapping, etc.).
+
+    Args:
+        vocab_size (int): Size of the vocabulary.
+        pad_id (int): ID used for the PAD token. Default: 0.
+        unk_id (int): ID used for unknown tokens. Default: 1.
+
+    Methods:
+        encode: Map ``item_id`` to ``token_id``.
+        decode: Map ``token_id`` back to ``item_id``.
+
+    Example:
+        >>> mapper = VocabMapper(vocab_size=1000)
+        >>> item_ids = np.array([10, 20, 30])
+        >>> token_ids = mapper.encode(item_ids)
+        >>> decoded_ids = mapper.decode(token_ids)
+    """
+
+    def __init__(self, vocab_size, pad_id=0, unk_id=1):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.pad_id = pad_id
+        self.unk_id = unk_id
+
+        # Build the mapping tables (a simple identity mapping)
+        self.item2token = np.arange(vocab_size)
+        self.token2item = np.arange(vocab_size)
+
+    def encode(self, item_ids):
+        """Convert item IDs to token IDs.
+
+        Args:
+            item_ids (np.ndarray): Array of item IDs.
+
+        Returns:
+            np.ndarray: Array of token IDs.
+        """
+        # Map out-of-range item IDs to the UNK token
+        token_ids = np.where((item_ids >= 0) & (item_ids < self.vocab_size), item_ids, self.unk_id)
+        return token_ids
+
+    def decode(self, token_ids):
+        """Convert token IDs back to item IDs.
+
+        Args:
+            token_ids (np.ndarray): Array of token IDs.
+
+        Returns:
+            np.ndarray: Array of item IDs.
+        """
+        # Map out-of-range token IDs to the UNK token
+        item_ids = np.where((token_ids >= 0) & (token_ids < self.vocab_size), token_ids, self.unk_id)
+        return item_ids
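
For readers skimming the diff, here is a minimal sketch of how a bias of shape ``(1, n_heads, L, L)``, as produced by `RelPosBias`, is typically folded into self-attention scores. The `q`/`k` tensors, batch size, and head dimension are illustrative stand-ins, not code from `hstu.py`:

```python
import torch

from torch_rechub.utils.hstu_utils import RelPosBias

n_heads, head_dim, seq_len = 8, 32, 256
rel_pos_bias = RelPosBias(n_heads=n_heads, max_seq_len=seq_len)

# Illustrative query/key tensors: (batch, n_heads, seq_len, head_dim)
q = torch.randn(4, n_heads, seq_len, head_dim)
k = torch.randn(4, n_heads, seq_len, head_dim)

# Scaled dot-product scores: (batch, n_heads, seq_len, seq_len)
scores = q @ k.transpose(-2, -1) / head_dim**0.5

# The leading 1 in the bias shape broadcasts over the batch dimension
scores = scores + rel_pos_bias(seq_len)
attn_weights = torch.softmax(scores, dim=-1)
```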
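
`VocabMask.apply_mask` is meant to sit between the model's output logits and the sampling or argmax step. A sketch of one greedy decoding step, assuming (this is an assumption, not something the diff states) that IDs 0 and 1 are reserved for PAD and UNK:

```python
import torch

from torch_rechub.utils.hstu_utils import VocabMask

vocab_size = 1000
# Assumed convention: reserve PAD (0) and UNK (1) so they are never generated
vocab_mask = VocabMask(vocab_size=vocab_size, invalid_items=[0, 1])

logits = torch.randn(32, vocab_size)           # stand-in for model output
masked_logits = vocab_mask.apply_mask(logits)  # invalid IDs pushed to -1e9
next_items = masked_logits.argmax(dim=-1)      # greedy pick skips 0 and 1
assert not ((next_items == 0) | (next_items == 1)).any()
```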
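
And a quick round trip through `VocabMapper` illustrating the out-of-range handling; note that, as written, both `encode` and `decode` collapse out-of-range IDs to `unk_id`:

```python
import numpy as np

from torch_rechub.utils.hstu_utils import VocabMapper

mapper = VocabMapper(vocab_size=1000, pad_id=0, unk_id=1)

# In-range IDs pass through the identity mapping; 5000 falls back to unk_id
item_ids = np.array([10, 20, 30, 5000])
token_ids = mapper.encode(item_ids)   # array([10, 20, 30, 1])
decoded = mapper.decode(token_ids)    # array([10, 20, 30, 1])
```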