torch-rechub 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -846,7 +846,7 @@ class HSTULayer(nn.Module):
846
846
  self.dropout = nn.Dropout(dropout)
847
847
 
848
848
  # Scaling factor for attention scores
849
- self.scale = 1.0 / (dqk**0.5)
849
+ # self.scale = 1.0 / (dqk**0.5) # Removed in favor of L2 norm + SiLU
850
850
 
851
851
  def forward(self, x, rel_pos_bias=None):
852
852
  """Forward pass of a single HSTU layer.
@@ -878,6 +878,10 @@ class HSTULayer(nn.Module):
878
878
  u = proj_out[..., 2 * self.n_heads * self.dqk:2 * self.n_heads * self.dqk + self.n_heads * self.dv].reshape(batch_size, seq_len, self.n_heads, self.dv)
879
879
  v = proj_out[..., 2 * self.n_heads * self.dqk + self.n_heads * self.dv:].reshape(batch_size, seq_len, self.n_heads, self.dv)
880
880
 
881
+ # Apply L2 normalization to Q and K (HSTU specific)
882
+ q = F.normalize(q, p=2, dim=-1)
883
+ k = F.normalize(k, p=2, dim=-1)
884
+
881
885
  # Transpose to (B, H, L, dqk/dv)
882
886
  q = q.transpose(1, 2) # (B, H, L, dqk)
883
887
  k = k.transpose(1, 2) # (B, H, L, dqk)
@@ -885,20 +889,22 @@ class HSTULayer(nn.Module):
885
889
  v = v.transpose(1, 2) # (B, H, L, dv)
886
890
 
887
891
  # Compute attention scores: (B, H, L, L)
888
- scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
892
+ # Note: No scaling factor here as we use L2 norm + SiLU
893
+ scores = torch.matmul(q, k.transpose(-2, -1))
894
+
895
+ # Add relative position bias if provided (before masking/activation)
896
+ if rel_pos_bias is not None:
897
+ scores = scores + rel_pos_bias
889
898
 
890
899
  # Add causal mask (prevent attending to future positions)
891
900
  # For generative models this is required so that position i only attends
892
901
  # to positions <= i.
893
902
  causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool))
894
- scores = scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
895
-
896
- # Add relative position bias if provided
897
- if rel_pos_bias is not None:
898
- scores = scores + rel_pos_bias
903
+ # Use a large negative number for masking compatible with SiLU
904
+ scores = scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), -1e4)
899
905
 
900
- # Softmax over attention scores
901
- attn_weights = F.softmax(scores, dim=-1)
906
+ # SiLU activation over attention scores (HSTU specific)
907
+ attn_weights = F.silu(scores)
902
908
  attn_weights = self.dropout(attn_weights)
903
909
 
904
910
  # Attention output: (B, H, L, dv)
@@ -482,41 +482,57 @@ class SequenceDataGenerator(object):
482
482
  # Underlying dataset
483
483
  self.dataset = SeqDataset(seq_tokens, seq_positions, targets, seq_time_diffs)
484
484
 
485
- def generate_dataloader(self, batch_size=32, num_workers=0, split_ratio=None):
486
- """Generate train/val/test dataloaders.
485
+ def generate_dataloader(self, batch_size=32, num_workers=0, split_ratio=None, shuffle=True):
486
+ """Generate dataloader(s) from the dataset.
487
487
 
488
488
  Parameters
489
489
  ----------
490
490
  batch_size : int, default=32
491
+ Batch size for DataLoader.
491
492
  num_workers : int, default=0
492
- split_ratio : tuple, default (0.7, 0.1, 0.2)
493
- Train/val/test split.
493
+ Number of workers for DataLoader.
494
+ split_ratio : tuple or None, default=None
495
+ If None, returns a single DataLoader without splitting the data.
496
+ If tuple (e.g., (0.7, 0.1, 0.2)), splits dataset and returns
497
+ (train_loader, val_loader, test_loader).
498
+ shuffle : bool, default=True
499
+ Whether to shuffle data. Only applies when split_ratio is None.
500
+ When split_ratio is provided, train data is always shuffled.
494
501
 
495
502
  Returns
496
503
  -------
497
504
  tuple
498
- (train_loader, val_loader, test_loader)
505
+ If split_ratio is None: returns (dataloader,)
506
+ If split_ratio is provided: returns (train_loader, val_loader, test_loader)
507
+
508
+ Examples
509
+ --------
510
+ # Case 1: Data already split, just create loader
511
+ >>> train_gen = SequenceDataGenerator(train_data['seq_tokens'], ...)
512
+ >>> train_loader = train_gen.generate_dataloader(batch_size=32)[0]
513
+
514
+ # Case 2: Auto-split data into train/val/test
515
+ >>> all_gen = SequenceDataGenerator(all_data['seq_tokens'], ...)
516
+ >>> train_loader, val_loader, test_loader = all_gen.generate_dataloader(
517
+ ... batch_size=32, split_ratio=(0.7, 0.1, 0.2))
499
518
  """
500
519
  if split_ratio is None:
501
- split_ratio = (0.7, 0.1, 0.2)
520
+ # No split - data is already divided, just create a single DataLoader
521
+ dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
522
+ return (dataloader,)
502
523
 
503
- # 验证分割比例
524
+ # Split data into train/val/test
504
525
  assert abs(sum(split_ratio) - 1.0) < 1e-6, "split_ratio must sum to 1.0"
505
526
 
506
- # 计算分割大小
507
527
  total_size = len(self.dataset)
508
528
  train_size = int(total_size * split_ratio[0])
509
529
  val_size = int(total_size * split_ratio[1])
510
530
  test_size = total_size - train_size - val_size
511
531
 
512
- # 分割数据集
513
532
  train_dataset, val_dataset, test_dataset = random_split(self.dataset, [train_size, val_size, test_size])
514
533
 
515
- # 创建数据加载器
516
534
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
517
-
518
535
  val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
519
-
520
536
  test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
521
537
 
522
538
  return train_loader, val_loader, test_loader
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torch-rechub
3
- Version: 0.1.0
3
+ Version: 0.2.0
4
4
  Summary: A Pytorch Toolbox for Recommendation Models, Easy-to-use and Easy-to-extend.
5
5
  Project-URL: Homepage, https://github.com/datawhalechina/torch-rechub
6
6
  Project-URL: Documentation, https://www.torch-rechub.com
@@ -31,7 +31,7 @@ Requires-Dist: transformers>=4.46.3
31
31
  Provides-Extra: annoy
32
32
  Requires-Dist: annoy>=1.17.2; extra == 'annoy'
33
33
  Provides-Extra: bigdata
34
- Requires-Dist: pyarrow~=21.0; extra == 'bigdata'
34
+ Requires-Dist: pyarrow<23,>=21; extra == 'bigdata'
35
35
  Provides-Extra: dev
36
36
  Requires-Dist: bandit>=1.7.0; extra == 'dev'
37
37
  Requires-Dist: flake8>=3.8.0; extra == 'dev'
@@ -60,9 +60,11 @@ Requires-Dist: graphviz>=0.20; extra == 'visualization'
60
60
  Requires-Dist: torchview>=0.2.6; extra == 'visualization'
61
61
  Description-Content-Type: text/markdown
62
62
 
63
- # 🔥 Torch-RecHub - 轻量、高效、易用的 PyTorch 推荐系统框架
63
+ <div align="center">
64
64
 
65
- > 🚀 **30+ 主流推荐模型** | 🎯 **开箱即用** | 📦 **一键部署 ONNX** | 🤖 **支持生成式推荐 (HSTU/HLLM)**
65
+ ![Torch-RecHub 横幅](docs/public/img/banner.png)
66
+
67
+ # Torch-RecHub: 轻量、高效、易用的 PyTorch 推荐系统框架
66
68
 
67
69
  [![许可证](https://img.shields.io/badge/license-MIT-blue?style=for-the-badge)](LICENSE)
68
70
  ![GitHub Repo stars](https://img.shields.io/github/stars/datawhalechina/torch-rechub?style=for-the-badge)
@@ -78,21 +80,13 @@ Description-Content-Type: text/markdown
78
80
 
79
81
  [English](README_en.md) | 简体中文
80
82
 
81
- **在线文档:** https://datawhalechina.github.io/torch-rechub/ (英文)| https://datawhalechina.github.io/torch-rechub/zh/ (简体中文)
83
+ ![架构图](docs/public/img/project_framework.png)
82
84
 
83
- **Torch-RecHub** —— **10 行代码实现工业级推荐系统**。30+ 主流模型开箱即用,支持一键 ONNX 部署,让你专注于业务而非工程。
85
+ </div>
84
86
 
85
- ![Torch-RecHub 横幅](docs/public/img/banner.png)
87
+ **在线文档:** https://datawhalechina.github.io/torch-rechub/zh/
86
88
 
87
- ## 🎯 为什么选择 Torch-RecHub?
88
-
89
- | 特性 | Torch-RecHub | 其他框架 |
90
- | ------------- | --------------------------- | ---------- |
91
- | 代码行数 | **10行** 完成训练+评估+部署 | 100+ 行 |
92
- | 模型覆盖 | **30+** 主流模型 | 有限 |
93
- | 生成式推荐 | ✅ HSTU/HLLM (Meta 2024) | ❌ |
94
- | ONNX 一键导出 | ✅ 内置支持 | 需手动适配 |
95
- | 学习曲线 | 极低 | 陡峭 |
89
+ **Torch-RecHub** —— **10 行代码实现工业级推荐系统**。30+ 主流模型开箱即用,支持一键 ONNX 部署,让你专注于业务而非工程。
96
90
 
97
91
  ## ✨ 特性
98
92
 
@@ -109,7 +103,6 @@ Description-Content-Type: text/markdown
109
103
  ## 📖 目录
110
104
 
111
105
  - [🔥 Torch-RecHub - 轻量、高效、易用的 PyTorch 推荐系统框架](#-torch-rechub---轻量高效易用的-pytorch-推荐系统框架)
112
- - [🎯 为什么选择 Torch-RecHub?](#-为什么选择-torch-rechub)
113
106
  - [✨ 特性](#-特性)
114
107
  - [📖 目录](#-目录)
115
108
  - [🔧 安装](#-安装)
@@ -221,6 +214,8 @@ torch-rechub/ # 根目录
221
214
 
222
215
  本框架目前支持 **30+** 主流推荐模型:
223
216
 
217
+ <details>
218
+
224
219
  ### 排序模型 (Ranking Models) - 13个
225
220
 
226
221
  | 模型 | 论文 | 简介 |
@@ -236,7 +231,11 @@ torch-rechub/ # 根目录
236
231
  | **AutoInt** | [CIKM 2019](https://arxiv.org/abs/1810.11921) | 自动特征交互学习 |
237
232
  | **FiBiNET** | [RecSys 2019](https://arxiv.org/abs/1905.09433) | 特征重要性 + 双线性交互 |
238
233
  | **DeepFFM** | [RecSys 2019](https://arxiv.org/abs/1611.00144) | 场感知因子分解机 |
239
- | **EDCN** | [KDD 2021](https://arxiv.org/abs/2106.03032) | 增强型交叉网络 |
234
+ | **EDCN** | [KDD 2021](https://arxiv.org/abs/2106.03032) | 增强型交叉网络 |
235
236
+ </details>
237
+
238
+ <details>
240
239
 
241
240
  ### 召回模型 (Matching Models) - 12个
242
241
 
@@ -253,6 +252,10 @@ torch-rechub/ # 根目录
253
252
  | **STAMP** | [KDD 2018](https://dl.acm.org/doi/10.1145/3219819.3219895) | 短期注意力记忆优先 |
254
253
  | **ComiRec** | [KDD 2020](https://arxiv.org/abs/2005.09347) | 可控多兴趣推荐 |
255
254
 
255
+ </details>
256
+
257
+ <details>
258
+
256
259
  ### 多任务模型 (Multi-Task Models) - 5个
257
260
 
258
261
  | 模型 | 论文 | 简介 |
@@ -263,6 +266,10 @@ torch-rechub/ # 根目录
263
266
  | **AITM** | [KDD 2021](https://arxiv.org/abs/2105.08489) | 自适应信息迁移 |
264
267
  | **SharedBottom** | - | 经典多任务共享底层 |
265
268
 
269
+ </details>
270
+
271
+ <details>
272
+
266
273
  ### 生成式推荐 (Generative Recommendation) - 2个
267
274
 
268
275
  | 模型 | 论文 | 简介 |
@@ -270,6 +277,8 @@ torch-rechub/ # 根目录
270
277
  | **HSTU** | [Meta 2024](https://arxiv.org/abs/2402.17152) | 层级序列转换单元,支撑 Meta 万亿参数推荐系统 |
271
278
  | **HLLM** | [2024](https://arxiv.org/abs/2409.12740) | 层级大语言模型推荐,融合 LLM 语义理解能力 |
272
279
 
280
+ </details>
281
+
273
282
  ## 📊 支持的数据集
274
283
 
275
284
  框架内置了对以下常见数据集格式的支持或提供了处理脚本:
@@ -5,7 +5,7 @@ torch_rechub/basic/activation.py,sha256=hIZDCe7cAgV3bX2UnvUrkO8pQs4iXxkQGD0J4Gej
5
5
  torch_rechub/basic/callback.py,sha256=ZeiDSDQAZUKmyK1AyGJCnqEJ66vwfwlX5lOyu6-h2G0,946
6
6
  torch_rechub/basic/features.py,sha256=TLHR5EaNvIbKyKd730Qt8OlLpV0Km91nv2TMnq0HObk,3562
7
7
  torch_rechub/basic/initializers.py,sha256=V6hprXvRexcw3vrYsf8Qp-F52fp8uzPMpa1CvkHofy8,3196
8
- torch_rechub/basic/layers.py,sha256=sLntNogvBu0QHm7riwyuJp_FbpbmPG26XeOyLs83Yu0,38813
8
+ torch_rechub/basic/layers.py,sha256=0qNeoIzgcSfmlVoQkyjT6yEnLklcKmQG44wBypAn2rY,39148
9
9
  torch_rechub/basic/loss_func.py,sha256=a-j1gan4eYUk5zstWwKeaPZ99eJkZPGWS82LNhT6Jbc,7756
10
10
  torch_rechub/basic/metaoptimizer.py,sha256=y-oT4MV3vXnSQ5Zd_ZEHP1KClITEi3kbZa6RKjlkYw8,3093
11
11
  torch_rechub/basic/metric.py,sha256=9JsaJJGvT6VRvsLoM2Y171CZxESsjYTofD3qnMI-bPM,8443
@@ -60,7 +60,7 @@ torch_rechub/trainers/match_trainer.py,sha256=oASggXTvFd-93ltvt2uhB1TFPSYP_H-EGd
60
60
  torch_rechub/trainers/mtl_trainer.py,sha256=J8ztmZN-4f2ELruN2lAGLlC1quo9Y-yH9Yu30MXBqJE,18562
61
61
  torch_rechub/trainers/seq_trainer.py,sha256=48s8YfY0PN5HETm0Dj09xDKrCT9S8wqykK4q1OtMTRo,20358
62
62
  torch_rechub/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
63
- torch_rechub/utils/data.py,sha256=TALy-nP9tqfz0DG2nMjBae5UZyBRvZIDX7zjGMnRqZ8,18542
63
+ torch_rechub/utils/data.py,sha256=Qt_HpwiU6n4wikJizRflAS5acr33YJN-t1Ar86U8UIQ,19715
64
64
  torch_rechub/utils/hstu_utils.py,sha256=QKX2V6dmbK6kwNEETSE0oEpbHz-FbIhB4PvbQC9Lx5w,5656
65
65
  torch_rechub/utils/match.py,sha256=l9qDwJGHPP9gOQTMYoqGVdWrlhDx1F1-8UnQwDWrEyk,18143
66
66
  torch_rechub/utils/model_utils.py,sha256=f8dx9uVCN8kfwYSJm_Mg5jZ2_gNMItPzTyccpVf_zA4,8219
@@ -68,7 +68,7 @@ torch_rechub/utils/mtl.py,sha256=AxU05ezizCuLdbPuCg1ZXE0WAStzuxaS5Sc3nwMCBpI,573
68
68
  torch_rechub/utils/onnx_export.py,sha256=02-UI4C0ACccP4nP5moVn6tPr4SSFaKdym0aczJs_jI,10739
69
69
  torch_rechub/utils/quantization.py,sha256=ett0VpmQz6c14-zvRuoOwctQurmQFLfF7Dj565L7iqE,4847
70
70
  torch_rechub/utils/visualization.py,sha256=cfaq3_ZYcqxb4R7V_be-RebPAqKDedAJSwjYoUm55AU,9201
71
- torch_rechub-0.1.0.dist-info/METADATA,sha256=r7xaaxaN7MYx2BJu96WGU72nHvOpwFE9CQmZSKBnRrk,18746
72
- torch_rechub-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
73
- torch_rechub-0.1.0.dist-info/licenses/LICENSE,sha256=V7ietiX9G_84HtgEbxDgxClniqXGm2t5q8WM4AHGTu0,1066
74
- torch_rechub-0.1.0.dist-info/RECORD,,
71
+ torch_rechub-0.2.0.dist-info/METADATA,sha256=FGmR2swqnS6uViykJd4BFHyQ2d9itA42r4t0XXkPgq8,18098
72
+ torch_rechub-0.2.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
73
+ torch_rechub-0.2.0.dist-info/licenses/LICENSE,sha256=V7ietiX9G_84HtgEbxDgxClniqXGm2t5q8WM4AHGTu0,1066
74
+ torch_rechub-0.2.0.dist-info/RECORD,,