torch-rechub 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
- torch_rechub/basic/layers.py +15 -9
- torch_rechub/utils/data.py +28 -12
- {torch_rechub-0.1.0.dist-info → torch_rechub-0.2.0.dist-info}/METADATA +27 -18
- {torch_rechub-0.1.0.dist-info → torch_rechub-0.2.0.dist-info}/RECORD +6 -6
- {torch_rechub-0.1.0.dist-info → torch_rechub-0.2.0.dist-info}/WHEEL +0 -0
- {torch_rechub-0.1.0.dist-info → torch_rechub-0.2.0.dist-info}/licenses/LICENSE +0 -0
torch_rechub/basic/layers.py
CHANGED

@@ -846,7 +846,7 @@ class HSTULayer(nn.Module):
         self.dropout = nn.Dropout(dropout)
 
         # Scaling factor for attention scores
-        self.scale = 1.0 / (dqk**0.5)
+        # self.scale = 1.0 / (dqk**0.5)  # Removed in favor of L2 norm + SiLU
 
     def forward(self, x, rel_pos_bias=None):
         """Forward pass of a single HSTU layer.
@@ -878,6 +878,10 @@ class HSTULayer(nn.Module):
         u = proj_out[..., 2 * self.n_heads * self.dqk:2 * self.n_heads * self.dqk + self.n_heads * self.dv].reshape(batch_size, seq_len, self.n_heads, self.dv)
         v = proj_out[..., 2 * self.n_heads * self.dqk + self.n_heads * self.dv:].reshape(batch_size, seq_len, self.n_heads, self.dv)
 
+        # Apply L2 normalization to Q and K (HSTU specific)
+        q = F.normalize(q, p=2, dim=-1)
+        k = F.normalize(k, p=2, dim=-1)
+
         # Transpose to (B, H, L, dqk/dv)
         q = q.transpose(1, 2)  # (B, H, L, dqk)
         k = k.transpose(1, 2)  # (B, H, L, dqk)
@@ -885,20 +889,22 @@ class HSTULayer(nn.Module):
         v = v.transpose(1, 2)  # (B, H, L, dv)
 
         # Compute attention scores: (B, H, L, L)
-        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+        # Note: No scaling factor here as we use L2 norm + SiLU
+        scores = torch.matmul(q, k.transpose(-2, -1))
+
+        # Add relative position bias if provided (before masking/activation)
+        if rel_pos_bias is not None:
+            scores = scores + rel_pos_bias
 
         # Add causal mask (prevent attending to future positions)
        # For generative models this is required so that position i only attends
         # to positions <= i.
         causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool))
-        scores = scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
-
-        # Add relative position bias if provided
-        if rel_pos_bias is not None:
-            scores = scores + rel_pos_bias
+        # Use a large negative number for masking compatible with SiLU
+        scores = scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), -1e4)
 
-        # Softmax over attention scores
-        attn_weights = F.softmax(scores, dim=-1)
+        # SiLU activation over attention scores (HSTU specific)
+        attn_weights = F.silu(scores)
         attn_weights = self.dropout(attn_weights)
 
         # Attention output: (B, H, L, dv)
torch_rechub/utils/data.py
CHANGED

@@ -482,41 +482,57 @@ class SequenceDataGenerator(object):
         # Underlying dataset
         self.dataset = SeqDataset(seq_tokens, seq_positions, targets, seq_time_diffs)
 
-    def generate_dataloader(self, batch_size=32, num_workers=0, split_ratio=None):
-        """Generate
+    def generate_dataloader(self, batch_size=32, num_workers=0, split_ratio=None, shuffle=True):
+        """Generate dataloader(s) from the dataset.
 
         Parameters
         ----------
         batch_size : int, default=32
+            Batch size for DataLoader.
         num_workers : int, default=0
-
-
+            Number of workers for DataLoader.
+        split_ratio : tuple or None, default=None
+            If None, returns a single DataLoader without splitting the data.
+            If tuple (e.g., (0.7, 0.1, 0.2)), splits dataset and returns
+            (train_loader, val_loader, test_loader).
+        shuffle : bool, default=True
+            Whether to shuffle data. Only applies when split_ratio is None.
+            When split_ratio is provided, train data is always shuffled.
 
         Returns
        -------
         tuple
-            (
+            If split_ratio is None: returns (dataloader,)
+            If split_ratio is provided: returns (train_loader, val_loader, test_loader)
+
+        Examples
+        --------
+        # Case 1: Data already split, just create loader
+        >>> train_gen = SequenceDataGenerator(train_data['seq_tokens'], ...)
+        >>> train_loader = train_gen.generate_dataloader(batch_size=32)[0]
+
+        # Case 2: Auto-split data into train/val/test
+        >>> all_gen = SequenceDataGenerator(all_data['seq_tokens'], ...)
+        >>> train_loader, val_loader, test_loader = all_gen.generate_dataloader(
+        ...     batch_size=32, split_ratio=(0.7, 0.1, 0.2))
         """
         if split_ratio is None:
-
+            # No split - data is already divided, just create a single DataLoader
+            dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
+            return (dataloader,)
 
-        #
+        # Split data into train/val/test
         assert abs(sum(split_ratio) - 1.0) < 1e-6, "split_ratio must sum to 1.0"
 
-        # Compute split sizes
         total_size = len(self.dataset)
         train_size = int(total_size * split_ratio[0])
         val_size = int(total_size * split_ratio[1])
         test_size = total_size - train_size - val_size
 
-        # Split the dataset
         train_dataset, val_dataset, test_dataset = random_split(self.dataset, [train_size, val_size, test_size])
 
-        # Create the data loaders
         train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
-
         val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
-
         test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
 
         return train_loader, val_loader, test_loader
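Two details of the new split path are easy to miss: the no-split branch returns a 1-tuple (hence the `[0]` indexing in the docstring example), and the train/val sizes are floored by `int()` with the remainder folded into `test_size`, so the integer lengths always sum to `len(dataset)` as `random_split` requires. A self-contained sketch of that arithmetic with a toy dataset (names and sizes here are illustrative, not from the package):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Toy stand-in for self.dataset; 103 rows makes the rounding visible.
dataset = TensorDataset(torch.arange(103))
split_ratio = (0.7, 0.1, 0.2)

total_size = len(dataset)
train_size = int(total_size * split_ratio[0])   # floor(72.1) -> 72
val_size = int(total_size * split_ratio[1])     # floor(10.3) -> 10
test_size = total_size - train_size - val_size  # remainder   -> 21

# Lengths sum to len(dataset) exactly, as random_split requires.
train_ds, val_ds, test_ds = random_split(dataset, [train_size, val_size, test_size])
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
print(len(train_ds), len(val_ds), len(test_ds))  # 72 10 21
```

Folding the remainder into the test split avoids the common `random_split` pitfall where floored sizes sum to less than the dataset length and raise a ValueError.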
{torch_rechub-0.1.0.dist-info → torch_rechub-0.2.0.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torch-rechub
-Version: 0.1.0
+Version: 0.2.0
 Summary: A Pytorch Toolbox for Recommendation Models, Easy-to-use and Easy-to-extend.
 Project-URL: Homepage, https://github.com/datawhalechina/torch-rechub
 Project-URL: Documentation, https://www.torch-rechub.com
@@ -31,7 +31,7 @@ Requires-Dist: transformers>=4.46.3
 Provides-Extra: annoy
 Requires-Dist: annoy>=1.17.2; extra == 'annoy'
 Provides-Extra: bigdata
-Requires-Dist: pyarrow
+Requires-Dist: pyarrow<23,>=21; extra == 'bigdata'
 Provides-Extra: dev
 Requires-Dist: bandit>=1.7.0; extra == 'dev'
 Requires-Dist: flake8>=3.8.0; extra == 'dev'
@@ -60,9 +60,11 @@ Requires-Dist: graphviz>=0.20; extra == 'visualization'
 Requires-Dist: torchview>=0.2.6; extra == 'visualization'
 Description-Content-Type: text/markdown
 
-
+<div align="center">
 
-
+
+
+# Torch-RecHub: A Lightweight, Efficient, and Easy-to-Use PyTorch Recommendation Framework
 
 [](LICENSE)
 
@@ -78,21 +80,13 @@ Description-Content-Type: text/markdown
 
 [English](README_en.md) | Simplified Chinese
 
-
+
 
-
+</div>
 
-
+**Online docs:** https://datawhalechina.github.io/torch-rechub/zh/
 
-
-
-| Feature | Torch-RecHub | Other frameworks |
-| ------------- | --------------------------- | ---------- |
-| Lines of code | **10 lines** for training + evaluation + deployment | 100+ lines |
-| Model coverage | **30+** mainstream models | Limited |
-| Generative recommendation | ✅ HSTU/HLLM (Meta 2024) | ❌ |
-| One-click ONNX export | ✅ Built-in support | Manual adaptation needed |
-| Learning curve | Very low | Steep |
+**Torch-RecHub**: an industrial-grade recommender system in **10 lines of code**. 30+ mainstream models work out of the box, with one-click ONNX deployment, so you can focus on the business rather than the engineering.
 
 ## ✨ Features
 
@@ -109,7 +103,6 @@ Description-Content-Type: text/markdown
 ## 📖 Table of Contents
 
 - [🔥 Torch-RecHub - A Lightweight, Efficient, and Easy-to-Use PyTorch Recommendation Framework](#-torch-rechub---轻量高效易用的-pytorch-推荐系统框架)
-- [🎯 Why Choose Torch-RecHub?](#-为什么选择-torch-rechub)
 - [✨ Features](#-特性)
 - [📖 Table of Contents](#-目录)
 - [🔧 Installation](#-安装)
@@ -221,6 +214,8 @@ torch-rechub/ # root directory
 
 The framework currently supports **30+** mainstream recommendation models:
 
+<details>
+
 ### Ranking Models - 13 models
 
 | Model | Paper | Description |
@@ -236,7 +231,11 @@ torch-rechub/ # root directory
 | **AutoInt** | [CIKM 2019](https://arxiv.org/abs/1810.11921) | Automatic feature interaction learning |
 | **FiBiNET** | [RecSys 2019](https://arxiv.org/abs/1905.09433) | Feature importance + bilinear interaction |
 | **DeepFFM** | [RecSys 2019](https://arxiv.org/abs/1611.00144) | Field-aware factorization machines |
-| **EDCN** | [KDD 2021](https://arxiv.org/abs/2106.03032) | Enhanced cross network |
+| **EDCN** | [KDD 2021](https://arxiv.org/abs/2106.03032) | Enhanced cross network |
+
+</details>
+
+<details>
 
 ### Matching Models - 12 models
 
@@ -253,6 +252,10 @@ torch-rechub/ # root directory
 | **STAMP** | [KDD 2018](https://dl.acm.org/doi/10.1145/3219819.3219895) | Short-term attention/memory priority |
 | **ComiRec** | [KDD 2020](https://arxiv.org/abs/2005.09347) | Controllable multi-interest recommendation |
 
+</details>
+
+<details>
+
 ### Multi-Task Models - 5 models
 
 | Model | Paper | Description |
@@ -263,6 +266,10 @@ torch-rechub/ # root directory
 | **AITM** | [KDD 2021](https://arxiv.org/abs/2105.08489) | Adaptive information transfer |
 | **SharedBottom** | - | Classic multi-task shared bottom |
 
+</details>
+
+<details>
+
 ### Generative Recommendation - 2 models
 
 | Model | Paper | Description |
@@ -270,6 +277,8 @@ torch-rechub/ # root directory
 | **HSTU** | [Meta 2024](https://arxiv.org/abs/2402.17152) | Hierarchical sequential transduction units, powering Meta's trillion-parameter recommender system |
 | **HLLM** | [2024](https://arxiv.org/abs/2409.12740) | Hierarchical LLM recommendation, integrating LLM semantic understanding |
 
+</details>
+
 ## 📊 Supported Datasets
 
 The framework ships with built-in support or processing scripts for the following common dataset formats:
{torch_rechub-0.1.0.dist-info → torch_rechub-0.2.0.dist-info}/RECORD
CHANGED

@@ -5,7 +5,7 @@ torch_rechub/basic/activation.py,sha256=hIZDCe7cAgV3bX2UnvUrkO8pQs4iXxkQGD0J4Gej
 torch_rechub/basic/callback.py,sha256=ZeiDSDQAZUKmyK1AyGJCnqEJ66vwfwlX5lOyu6-h2G0,946
 torch_rechub/basic/features.py,sha256=TLHR5EaNvIbKyKd730Qt8OlLpV0Km91nv2TMnq0HObk,3562
 torch_rechub/basic/initializers.py,sha256=V6hprXvRexcw3vrYsf8Qp-F52fp8uzPMpa1CvkHofy8,3196
-torch_rechub/basic/layers.py,sha256=
+torch_rechub/basic/layers.py,sha256=0qNeoIzgcSfmlVoQkyjT6yEnLklcKmQG44wBypAn2rY,39148
 torch_rechub/basic/loss_func.py,sha256=a-j1gan4eYUk5zstWwKeaPZ99eJkZPGWS82LNhT6Jbc,7756
 torch_rechub/basic/metaoptimizer.py,sha256=y-oT4MV3vXnSQ5Zd_ZEHP1KClITEi3kbZa6RKjlkYw8,3093
 torch_rechub/basic/metric.py,sha256=9JsaJJGvT6VRvsLoM2Y171CZxESsjYTofD3qnMI-bPM,8443
@@ -60,7 +60,7 @@ torch_rechub/trainers/match_trainer.py,sha256=oASggXTvFd-93ltvt2uhB1TFPSYP_H-EGd
 torch_rechub/trainers/mtl_trainer.py,sha256=J8ztmZN-4f2ELruN2lAGLlC1quo9Y-yH9Yu30MXBqJE,18562
 torch_rechub/trainers/seq_trainer.py,sha256=48s8YfY0PN5HETm0Dj09xDKrCT9S8wqykK4q1OtMTRo,20358
 torch_rechub/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-torch_rechub/utils/data.py,sha256=
+torch_rechub/utils/data.py,sha256=Qt_HpwiU6n4wikJizRflAS5acr33YJN-t1Ar86U8UIQ,19715
 torch_rechub/utils/hstu_utils.py,sha256=QKX2V6dmbK6kwNEETSE0oEpbHz-FbIhB4PvbQC9Lx5w,5656
 torch_rechub/utils/match.py,sha256=l9qDwJGHPP9gOQTMYoqGVdWrlhDx1F1-8UnQwDWrEyk,18143
 torch_rechub/utils/model_utils.py,sha256=f8dx9uVCN8kfwYSJm_Mg5jZ2_gNMItPzTyccpVf_zA4,8219
@@ -68,7 +68,7 @@ torch_rechub/utils/mtl.py,sha256=AxU05ezizCuLdbPuCg1ZXE0WAStzuxaS5Sc3nwMCBpI,573
 torch_rechub/utils/onnx_export.py,sha256=02-UI4C0ACccP4nP5moVn6tPr4SSFaKdym0aczJs_jI,10739
 torch_rechub/utils/quantization.py,sha256=ett0VpmQz6c14-zvRuoOwctQurmQFLfF7Dj565L7iqE,4847
 torch_rechub/utils/visualization.py,sha256=cfaq3_ZYcqxb4R7V_be-RebPAqKDedAJSwjYoUm55AU,9201
-torch_rechub-0.1.0.dist-info/METADATA,sha256=
-torch_rechub-0.1.0.dist-info/WHEEL,sha256=
-torch_rechub-0.1.0.dist-info/licenses/LICENSE,sha256=
-torch_rechub-0.1.0.dist-info/RECORD,,
+torch_rechub-0.2.0.dist-info/METADATA,sha256=FGmR2swqnS6uViykJd4BFHyQ2d9itA42r4t0XXkPgq8,18098
+torch_rechub-0.2.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+torch_rechub-0.2.0.dist-info/licenses/LICENSE,sha256=V7ietiX9G_84HtgEbxDgxClniqXGm2t5q8WM4AHGTu0,1066
+torch_rechub-0.2.0.dist-info/RECORD,,

{torch_rechub-0.1.0.dist-info → torch_rechub-0.2.0.dist-info}/WHEEL
File without changes

{torch_rechub-0.1.0.dist-info → torch_rechub-0.2.0.dist-info}/licenses/LICENSE
File without changes