rwkv-ops 0.1.0__py3-none-any.whl → 0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -7,6 +7,7 @@ backward gradient computation, and integration with PyTorch's autograd system.
 This file implements the forward and backward passes of the chunked Delta Rule attention mechanism,
 optimized with Triton kernels for GPU acceleration. It includes the forward pass, the backward gradient functions,
 and integration with PyTorch's autograd system.
+
 """
 
 import warnings
@@ -15,11 +16,11 @@ from typing import Optional
 import torch
 import triton
 
-# Import kernel implementation modules
 from .torch_kernel.chunk_A_bwd import chunk_dplr_bwd_dqk_intra
 from .torch_kernel.chunk_A_fwd import chunk_dplr_fwd_intra
 from .torch_kernel.chunk_h_bwd import chunk_dplr_bwd_dhu
 from .torch_kernel.chunk_h_fwd import chunk_dplr_fwd_h
+
 from .torch_kernel.chunk_o_bwd import (
     chunk_dplr_bwd_dAu,
     chunk_dplr_bwd_dv,
@@ -37,11 +38,6 @@ from .get_torch_devices_info import (
 
 
 def cast(x, dtype):
-    """
-    Cast tensor x to specified dtype if not already in that format.
-
-    Converts x to the target dtype when it differs from it.
-    """
     if x is None or x.dtype == dtype:
         return x
     return x.to(dtype)
@@ -97,11 +93,11 @@ def chunk_dplr_fwd(
 
     del ge
 
-    # Compute WY representation
+    # A_ab, A_ak, gi, ge: torch.float32
+    # A_qk, A_qb, qg, kg, ag, bg: dtype=q.dtype, e.g. bf16
     w, u, _ = prepare_wy_repr_fwd(ag=ag, A_ab=A_ab, A_ak=A_ak, v=v, chunk_size=BT)
 
     del A_ab, A_ak
-
     h, v_new, final_state = chunk_dplr_fwd_h(
         kg=kg,
         bg=bg,
@@ -164,6 +160,7 @@ def chunk_dplr_bwd(
         dgk (torch.Tensor): Gradient of log decays
         dh0 (torch.Tensor): Gradient of initial state
     """
+    # ******* Start recomputing everything; otherwise GPU memory would be exhausted *******
     gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)
     A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
         q=q,
@@ -183,6 +180,9 @@ def chunk_dplr_bwd(
         kg=kg, bg=bg, v=v, w=w, u=u, gk=gi, initial_state=initial_state, chunk_size=BT
     )
     del u
+    # ******* End of recomputation *******
+    # A_ak, A_ab_inv, gi, ge: torch.float32
+    # A_qk, A_qb, qg, kg, ag, bg, v_new: dtype=q.dtype, e.g. bf16
 
     dv_new_intra, dA_qk, dA_qb = chunk_dplr_bwd_dAu(
         v=v, v_new=v_new, do=do, A_qb=A_qb, scale=scale, chunk_size=BT
@@ -211,7 +211,7 @@ def chunk_dplr_bwd(
         do=do,
         h=h,
         dh=dh,
-        dv=dv,
+        dv=dv_new,
         w=w,
         gk=gi,
         chunk_size=BT,
@@ -282,11 +282,6 @@ class ChunkDPLRDeltaRuleFunction(torch.autograd.Function):
         output_final_state: bool = True,
         cu_seqlens: Optional[torch.LongTensor] = None,
     ):
-        """
-        Forward function with autograd support.
-
-        Forward pass integrated with automatic differentiation.
-        """
         chunk_size = 16
         o, final_state = chunk_dplr_fwd(
             q=q,
@@ -310,11 +305,6 @@ class ChunkDPLRDeltaRuleFunction(torch.autograd.Function):
     @input_guard
     @autocast_custom_bwd
     def backward(ctx, do: torch.Tensor, dht: torch.Tensor):
-        """
-        Backward function with autograd support.
-
-        Backward pass integrated with automatic differentiation.
-        """
         q, k, v, a, b, gk, initial_state = ctx.saved_tensors
         BT = ctx.chunk_size
         cu_seqlens = ctx.cu_seqlens
@@ -439,6 +429,7 @@ def chunk_rwkv7(
 
     Interface function for the RWKV-7 attention mechanism.
     """
+
     if w is not None:
         log_w = -torch.exp(w)
     else:
@@ -458,11 +449,6 @@ def chunk_rwkv7(
 
 
 def transpose_head(x, head_first):
-    """
-    Transpose between head-first and time-first formats.
-
-    Switches tensors between head-first and time-first layouts.
-    """
     if head_first:
         x = torch.permute(x, dims=(0, 2, 1, 3))
         out = cast(x, torch.bfloat16).contiguous()
@@ -480,11 +466,6 @@ def generalized_delta_rule(
     output_final_state: bool = True,
     head_first: bool = False,
 ):
-    """
-    Generalized delta rule attention interface.
-
-    Interface for the generalized Delta Rule attention mechanism.
-    """
     dtype = r.dtype
     r = transpose_head(r, head_first)
     k = transpose_head(k, head_first)
@@ -492,30 +473,16 @@
     a = transpose_head(a, head_first)
     b = transpose_head(b, head_first)
     w = transpose_head(w, head_first)
-    if w.device.type == "cuda":
-        out, state = chunk_rwkv7(
-            r=r,
-            k=k,
-            v=v,
-            a=a,
-            b=b,
-            w=w,
-            initial_state=initial_state,
-            output_final_state=output_final_state,
-        )
-    else:
-        from .native_keras_op import generalized_delta_rule
-
-        out, state = generalized_delta_rule(
-            r=r,
-            k=k,
-            v=v,
-            a=a,
-            b=b,
-            w=w,
-            initial_state=initial_state,
-            output_final_state=output_final_state,
-        )
+    out, state = chunk_rwkv7(
+        r=r,
+        k=k,
+        v=v,
+        a=a,
+        b=b,
+        w=w,
+        initial_state=initial_state,
+        output_final_state=output_final_state,
+    )
     out = transpose_head(out, head_first)
     if output_final_state:
         return out, cast(state, dtype)
@@ -0,0 +1,258 @@
+Metadata-Version: 2.1
+Name: rwkv-ops
+Version: 0.2
+Home-page: https://github.com/pass-lin/rwkv_ops
+License: Apache 2.0
+Keywords: rwkv implement for multi backend
+Classifier: Development Status :: 3 - Alpha
+Classifier: Intended Audience :: Developers
+Classifier: Intended Audience :: Science/Research
+Classifier: License :: OSI Approved :: Apache Software License
+Classifier: Operating System :: OS Independent
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Description-Content-Type: text/markdown
+License-File: LICENSE.txt
+Requires-Dist: keras
+
+[English Document](ENREADME.md)
+
+# RWKV OPS Project
+
+> RWKV is iterating continuously, and its core operators evolve with it.
+> This repository maintains only the operators themselves, not layers or models, and aims to provide GPU operators for as many frameworks as possible.
+
+### Current support
+| Operator type | Framework support |
+|----------|----------|
+| GPU operators | PyTorch, JAX (TensorFlow will follow once Google adds Triton support) |
+| Native operators | PyTorch, JAX, TensorFlow, NumPy |
+
+> MLX and OpenVINO may be supported later if the Keras ecosystem expands.
+> Note: this library depends on `keras`.
+
+---
+
+## Installation
+
+```bash
+pip install rwkv_ops
+```
+
+---
+
+## Environment variables
+
+| Variable | Meaning | Values | Default | Priority |
+|---|---|---|---|---|
+| `KERAS_BACKEND` | Keras backend | `jax` / `torch` / `tensorflow` / `numpy` | — | low |
+| `KERNEL_BACKEND` | Operator backend | `jax` / `torch` / `tensorflow` / `numpy` | `torch` | **high** |
+| `KERNEL_TYPE` | Implementation type | `triton` / `cuda` / `native` | — | — |
+
+> If `KERNEL_BACKEND` is set it is used directly; if it is empty, `KERAS_BACKEND` is used; if both are empty, the default is `torch`.
+> `native` selects the native operators: no chunkwise kernel, slower, and heavier on GPU memory.
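+
+A minimal sketch of backend selection (an assumption here: the variables are read once at import time, so they must be set before the first `import rwkv_ops`):
+
+```python
+import os
+
+# KERNEL_BACKEND outranks KERAS_BACKEND, so setting it alone is enough.
+os.environ["KERNEL_BACKEND"] = "torch"
+os.environ["KERNEL_TYPE"] = "triton"  # or "cuda" / "native"
+
+import rwkv_ops  # picks up the variables above
+```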
+
+---
+
+## rwkv7op usage
+
+```python
+from rwkv_ops import generalized_delta_rule  # or: from rwkv_ops import rwkv7_op, fully equivalent
+
+def generalized_delta_rule(
+    r,
+    w,
+    k,
+    v,
+    a,
+    b,
+    initial_state=None,
+    output_final_state: bool = True,
+    head_first: bool = False,
+):
+    """
+    Chunked Delta Rule attention interface.
+
+    Args:
+        q: [B, T, H, K]  (passed as `r`)
+        k: [B, T, H, K]
+        v: [B, T, H, V]
+        a: [B, T, H, K]
+        b: [B, T, H, K]
+        gk: [B, T, H, K]  (passed as `w`)  # decay term in log space!
+        initial_state: initial state [N, H, K, V], where N is the number of sequences
+        output_final_state: whether to return the final state
+        head_first: whether tensors are head-first; incompatible with variable-length input
+
+    Returns:
+        o: output [B, T, H, V] or [B, H, T, V]
+        final_state: final state [N, H, K, V] or None
+    """
+```
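+
+A minimal call sketch (hedged: random data, shapes taken from the docstring above; the `b = -a` pairing is only an illustrative choice):
+
+```python
+import torch
+from rwkv_ops import generalized_delta_rule
+
+B, T, H, K = 2, 64, 4, 64  # batch, time, heads, head size (V == K here)
+r = torch.randn(B, T, H, K, device="cuda")
+k = torch.randn(B, T, H, K, device="cuda")
+v = torch.randn(B, T, H, K, device="cuda")
+w = torch.randn(B, T, H, K, device="cuda")  # raw decay input; mapped to log space internally
+a = -torch.nn.functional.normalize(torch.randn(B, T, H, K, device="cuda"), dim=-1)
+b = -a
+
+o, final_state = generalized_delta_rule(r=r, w=w, k=k, v=v, a=a, b=b)
+print(o.shape)            # (2, 64, 4, 64)
+print(final_state.shape)  # (2, 4, 64, 64)
+```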
+
+### torch-cuda specifics
+
+- Under torch-cuda, `head_size` is itself a kernel parameter and defaults to 64.
+- If `head_size ≠ 64`, use:
+
+```python
+from rwkv_ops import get_generalized_delta_rule
+
+generalized_delta_rule, RWKV7_USE_KERNEL = get_generalized_delta_rule(
+    your_head_size, KERNEL_TYPE="cuda"
+)
+```
+
+- `RWKV7_USE_KERNEL` is a constant that marks whether the chunkwise kernel is in use.
+- The two paths handle padding differently:
+
+```python
+if padding_mask is not None:
+    if RWKV7_USE_KERNEL:
+        w += (1 - padding_mask) * -1e9
+    else:
+        w = w * padding_mask + 1 - padding_mask
+```
+
+---
+
+### rwkv7op implementation status
+
+| Framework | cuda | triton | native |
+|-------------|------|--------|--------|
+| PyTorch | ✅ | ✅ | ✅ |
+| JAX | ❌ | ✅ | ✅ |
+| TensorFlow | ❌ | ❌ | ✅ |
+| NumPy | ❌ | ❌ | ✅ |
+
+---
+
+## rwkv6op usage
+
+### PyTorch notes
+
+- Install dependencies: `keras`, `ninja`, and a complete CUDA toolkit.
+- When debugging in VS Code with a virtual environment, activate the environment manually in the terminal before running; otherwise ninja may fail to work.
+- PyTorch still runs when the CUDA version inside the virtual environment differs from the global one, but keeping them identical is strongly recommended.
+- PyTorch limitation: only **one** `RWKV6_OP` object may be instantiated per program; the operator is stateless and thread-safe, so it can be called from multiple places.
+
+### JAX notes
+
+- Install dependencies: `keras`, `gcc`, `pybind11`, and a complete CUDA toolkit.
+- Even if CUDA is installed for JAX inside a virtual environment, a full system-level CUDA of the same version is required so that JAX can compile in parallel at full speed.
+- JAX compilation relies on the `/usr/local/cuda` symlink; create it manually if it does not exist:
+  ```shell
+  sudo ln -sf /usr/local/cuda-12.4 /usr/local/cuda
+  ```
+- Make sure `nvcc -V` produces normal output and `which nvcc` points to the correct version.
+- JAX limitation: only **one** `RWKV6_OP` object may be instantiated per program; the operator is stateless and thread-safe, so it can be called from multiple places.
+- JAX ≥ 0.6.0 no longer uses the CUDA operator and defaults to the native one; 0.4.34 is recommended.
+
+### TensorFlow notes
+
+- Only a native-API `RWKV6` operator is provided; it is inference-only and relatively slow.
+
+---
+
+### Usage
+Note that unlike rwkv7, which is exposed as a function, the RWKV6 op is a class and must be instantiated.
+```python
+from rwkv_ops import RWKV6_OP
+
+operator = RWKV6_OP(
+    head_size=64,              # head size; use 64 if unsure
+    max_sequence_length=4096,  # maximum training sequence length; inference is not limited by it
+    ops_loop=False             # optional: use the high-level API instead of CUDA when seq_len == 1
+)
+```
+
+#### Calling
+
+```python
+y, y_state = operator(
+    r, k, v, w, u,
+    with_state=False,  # whether to use a custom initial state / return the final state
+    init_state=None,   # initial state [n_state, num_heads, head_size, head_size]
+    state_map=None     # 1-D int32 array of length batch_size mapping samples to init_state entries
+)
+```
+
+| Parameter | Shape | Notes |
+|---|---|---|
+| r, k, v, w | (batch_size, seq_len, hidden_size) | — |
+| u | (num_heads, head_size) or (hidden_size,) | — |
+| init_state | (n_state, num_heads, head_size, head_size) | n_state=1: all samples share one state; n_state=batch_size: one state per sample |
+| state_map | (batch_size,) | index of the init_state entry used by each sample |
+
+| Return value | Shape | Notes |
+|---|---|---|
+| y | (batch_size, seq_len, hidden_size) | output |
+| y_state | (batch_size, num_heads, head_size, head_size) or None | final state |
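+
+A small sketch of the state plumbing, reusing the `operator` instantiated above (hypothetical sizes; `state_map` semantics as described in the tables):
+
+```python
+import numpy as np
+import keras
+
+num_heads, head_size = 32, 64
+hidden_size = num_heads * head_size  # 2048
+batch_size, seq_len = 4, 128
+
+shape = (batch_size, seq_len, hidden_size)
+r, k, v, w = (keras.random.normal(shape) for _ in range(4))
+u = keras.random.normal((num_heads, head_size))
+
+# Two stored states: samples 0 and 1 resume from state 0, samples 2 and 3 from state 1.
+init_state = keras.ops.zeros((2, num_heads, head_size, head_size))
+state_map = np.array([0, 0, 1, 1], dtype=np.int32)
+
+y, y_state = operator(r, k, v, w, u,
+                      with_state=True, init_state=init_state, state_map=state_map)
+print(y.shape)        # (4, 128, 2048)
+print(y_state.shape)  # (4, 32, 64, 64): one final state per sample
+```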
+
+---
+
+### Distributed tips
+
+- The operator itself has no distributed support; with PyTorch, ordinary multi-process distributed training works directly.
+- JAX needs a `shard_map` wrapper (example):
+
+```python
+import os
+os.environ['KERAS_BACKEND'] = 'jax'
+
+import jax, jax.numpy as jnp
+from jax.experimental.shard_map import shard_map
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
+from functools import partial
+from rwkv_ops import RWKV6_OP
+
+batch_size, seq_length = 24, 512
+head_size, num_heads = 64, 32
+hidden_size = head_size * num_heads
+
+mesh = Mesh(jax.devices('gpu'), axis_names=('device_axis',))
+device_ns = NamedSharding(mesh, P('device_axis'))
+
+operator = RWKV6_OP(head_size=head_size, max_sequence_length=seq_length)
+
+@partial(shard_map,
+         mesh=mesh,
+         in_specs=(P('device_axis'),) * 5,
+         out_specs=(P('device_axis'), P('device_axis')),
+         check_rep=False)
+def call_kernel(r, k, v, w, u):
+    # Drop the outermost device dimension
+    r, k, v, w, u = map(jnp.squeeze, (r, k, v, w, u))
+    y, ys = operator(r, k, v, w, u, with_state=True)
+    return jnp.expand_dims(y, 0), jnp.expand_dims(ys, 0)
+
+# Build the inputs and place them on the corresponding devices
+keys = jax.random.split(jax.random.PRNGKey(0), 5)
+inputs = [jax.random.normal(k, (mesh.size, batch_size, seq_length, hidden_size)) for k in keys]
+inputs_r, inputs_k, inputs_v, inputs_w, inputs_u = map(
+    lambda x: jax.device_put(x, device_ns), inputs)
+inputs_u = inputs_u[:, 0, 0]  # (devices, hidden_size)
+
+# Optional: wrap with jax.jit(call_kernel, ...) for speed
+outputs_y, y_state = call_kernel(inputs_r, inputs_k, inputs_v, inputs_w, inputs_u)
+
+print(outputs_y.shape, outputs_y.sharding)
+print(y_state.shape, y_state.sharding)
+```
+
+---
+
+### rwkv6op implementation status
+
+| Framework | cuda | triton | native |
+|-------------|------|--------|--------|
+| PyTorch | ✅ | ❌ | ✅ |
+| JAX | ⚠️ | ❌ | ✅ |
+| TensorFlow | ❌ | ❌ | ✅ |
+| NumPy | ❌ | ❌ | ✅ |
+
+⚠️ The JAX CUDA implementation only works on JAX < 0.6.0; 0.4.34 is recommended.
@@ -1,16 +1,20 @@
-rwkv_ops/__init__.py,sha256=Qp3EFYuuQB19ND6yvwxbczvGgH98hZcWr8_boIf_5I8,725
-rwkv_ops/rwkv7_kernel/__init__.py,sha256=nG_loc60Vp_ZQGLvwWFJYUIAsNCt-ldUYJnWdmVwyPc,5681
-rwkv_ops/rwkv7_kernel/get_jax_devices_info.py,sha256=fC_bISkSrKk1DvDNEaAdThmDZElz7T1BGWRl6EOXd6M,6005
-rwkv_ops/rwkv7_kernel/get_torch_devices_info.py,sha256=3Sgzw_-9VElBecmQ3TbW4y8GI9QZfwlP6BMb7CovlQ8,7380
-rwkv_ops/rwkv7_kernel/jax_op.py,sha256=xOg2YbPPHkXEJ59sOo1_stOrghJQPuj2nhXrqw4GSM4,9109
+rwkv_ops/__init__.py,sha256=Kfw_9iearbIplpuxx8sUl30TxKbTO-Ehqe0290Y4sZw,841
+rwkv_ops/rwkv6_kernel/__init__.py,sha256=_j6G_3fY8xPxrlZbgDT2ndX4IPiNJ4qjqIcdmNI_r9Q,4100
+rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py,sha256=WOzqfQQSHHMoWqm2kRz_BhtMzGYc5USJ26qaEwuARo4,30117
+rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py,sha256=otjfw5n6nf2YVpBIWIZjaCsxMyLXXwg-ma1ueXX-EdY,3274
+rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py,sha256=Q1uPMgaS21OEfQ8-sBDjaCUASMtkSOdN3OosEUsBp9U,12918
+rwkv_ops/rwkv7_kernel/__init__.py,sha256=GpwZ5dk7d5H6u-VSUAQ29KQVnCZdEwtKZ13_3kVJets,5888
+rwkv_ops/rwkv7_kernel/get_jax_devices_info.py,sha256=cMIaNED7d1PvYNSyq8wNI3G7wNvcgdUj9HWRBLuSVM8,6004
+rwkv_ops/rwkv7_kernel/get_torch_devices_info.py,sha256=ZL_rAM6lHB4nTOOU28Xm08qptfuIoijOMi_xwJG3KCo,7380
+rwkv_ops/rwkv7_kernel/jax_op.py,sha256=tyMxvk_EblDaGsePpxw3AhELvolp7LeE5NopUhKw1R0,9107
 rwkv_ops/rwkv7_kernel/native_keras_op.py,sha256=QPrXLbqw0chipQg_0jepRp2U19BYpBBFdKZWyaDNNoc,2488
-rwkv_ops/rwkv7_kernel/torch_op.py,sha256=elbv2SOHs9fzsO0FGxQ6ug7HVsjTZAE2oIM_ZOOY4JM,14451
+rwkv_ops/rwkv7_kernel/torch_op.py,sha256=yY5QP87iDow-T6a4ZzFShyIQ8gprTQoLYcjFqgOTW4Y,13675
 rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py,sha256=uHsf_1qrtRK62IvhLuzefHGPWpHXmw1p0tqmwlHcptk,346
 rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py,sha256=2Voq1Bdzn0DFloiLvwINBk7akmxRWIqXIQeyafrJJGg,2138
 rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py,sha256=rhmglqHIIww7yPzaSBEp9ISxhhxoUbMtV51AUDyhUd8,1425
 rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py,sha256=JDfVZsMb8yMlMN3sKT3i3l3y1YQiQkyUjnSNyan5Fqc,1888
 rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py,sha256=g8b_81rIIjxeknYiklRGnox24rAvEvfKRKT-5nI0Euo,1992
-rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py,sha256=W5gutV5WXA1bscfuDFCYNoRiG526CjHy4eaD2VtYuh0,3357
+rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py,sha256=gQnToi1e1GZCvjWsEdWx6WakUN4Lc0JfaBSsSXYdN84,3369
 rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py,sha256=4SjQ_zTZvFxsBMeWOx0JGFg9EQ4vllvEx30EcvSZJzI,853
 rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py,sha256=NoOh2_hA_rdH5bmaNNMAdCgVPfWvQpf-Q8BqF926jrw,667
 rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py,sha256=PAMtE6wCW2Hz39oiHLGqhxY77csQAMYdNP2najDO_Jg,1407
@@ -20,7 +24,7 @@ rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py,sha256=CWtotXkVvHz4-rkuOqWh6zK
 rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py,sha256=4RJbyUTO23OxwH1rGVxeBiBVZKNHpPL_tJ7MFoDCIts,1475
 rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py,sha256=zo6l0ZZUhXFu8wEFD76I0zSqFT9IXFKUKtyeaSwk380,1795
 rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py,sha256=0ucN1U0EDTDqcyTPLLcsAX6FLTf2E_3toOY9p81gWYE,1858
-rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py,sha256=CR6rK729Jfdt3JgZMpZxmALyTIYT6a7bmvllCm0GIxI,3281
+rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py,sha256=ioPrS0NYQhpFk1j8rAxqtbwpx1CwjJQnrJEBDqVy-As,3283
 rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py,sha256=54yoa3NpV64H-koURt-hUWpFHhUjwXpGvXPp2_ETCnw,825
 rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py,sha256=hQkpyaa0eUyB4V3UVks7l1_dHwOrbump0FZILityBKw,611
 rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py,sha256=gk6QdoT1oq5B8Hp8Ak-SGqHm8CEj3MErUeWcRsaaOQM,1470
@@ -36,8 +40,8 @@ rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py,sha256=pRp_z587PrnpgRVpi031IndyjVI
 rwkv_ops/rwkv7_kernel/triton_kernel/utils.py,sha256=TNGlkwGq4t-TOcdVBk_N_vHPLzMFTu_F0V-O1RprIO4,553
 rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py,sha256=szaG11q_WmpyhXi6aVWwzizvflCh5wND8wGA_V8afzA,5479
 rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py,sha256=jbb19DUTHENU2RIOv_T4m_W1eXMqdRqG0XevIkBOhI4,9438
-rwkv_ops-0.1.0.dist-info/LICENSE.txt,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
-rwkv_ops-0.1.0.dist-info/METADATA,sha256=KRIeem9Ulrk0bwVyelirLyWXcGVLoa1BXmwgjUPRikk,3572
-rwkv_ops-0.1.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-rwkv_ops-0.1.0.dist-info/top_level.txt,sha256=cVqoKE-WR_e2gHL87-6O4K1kG6-yTJGB2huyr6FmD2I,9
-rwkv_ops-0.1.0.dist-info/RECORD,,
+rwkv_ops-0.2.dist-info/LICENSE.txt,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+rwkv_ops-0.2.dist-info/METADATA,sha256=cUhC6EYLULLgNLVtLOg4qMSoAZjIDcrojax6egCTU04,8409
+rwkv_ops-0.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+rwkv_ops-0.2.dist-info/top_level.txt,sha256=cVqoKE-WR_e2gHL87-6O4K1kG6-yTJGB2huyr6FmD2I,9
+rwkv_ops-0.2.dist-info/RECORD,,
@@ -1,118 +0,0 @@
-Metadata-Version: 2.1
-Name: rwkv-ops
-Version: 0.1.0
-Home-page: https://github.com/your-org/rwkv_ops
-License: Apache 2.0
-Keywords: rwkv attention cuda triton pytorch jax
-Classifier: Development Status :: 3 - Alpha
-Classifier: Intended Audience :: Developers
-Classifier: Intended Audience :: Science/Research
-Classifier: License :: OSI Approved :: Apache Software License
-Classifier: Operating System :: OS Independent
-Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.8
-Classifier: Programming Language :: Python :: 3.9
-Classifier: Programming Language :: Python :: 3.10
-Classifier: Programming Language :: Python :: 3.11
-Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
-Description-Content-Type: text/markdown
-License-File: LICENSE.txt
-Requires-Dist: keras
-
-[English Document](ENREADME.md)
-# RWKV OPS Project
-> RWKV is iterating continuously, and its core operators evolve with it.
-> This repository maintains only the operators themselves, not layers or models, and aims to provide GPU operators for as many frameworks as possible.
-> Currently:
-> • GPU operators: PyTorch, JAX (TensorFlow will follow once Google adds Triton support)
-> • Native operators: PyTorch, JAX, TensorFlow, NumPy
-> MLX and OpenVINO may be supported later if the Keras ecosystem expands.
-> Note: this library depends on `keras`.
-
----
-
-## Environment variables
-
-| Variable | Meaning | Values | Default | Priority |
-|---|---|---|---|---|
-| `KERAS_BACKEND` | Keras backend | jax / torch / tensorflow / numpy | — | low |
-| `KERNEL_BACKEND` | Operator backend | jax / torch / tensorflow / numpy | torch | **high** |
-| `KERNEL_TYPE` | Implementation type | triton / cuda / native | — | — |
-
-> If `KERNEL_BACKEND` is set it is used directly; if it is empty, `KERAS_BACKEND` is used; if both are empty, the default is torch.
-> `native` selects the native operators: no chunkwise kernel, slower, and heavier on GPU memory.
-
----
-
-## rwkv7op usage
-
-```python
-from rwkv_ops import generalized_delta_rule  # or: from rwkv_ops import rwkv7_op, fully equivalent
-
-def generalized_delta_rule(
-    r,
-    w,
-    k,
-    v,
-    a,
-    b,
-    initial_state=None,
-    output_final_state: bool = True,
-    head_first: bool = False,
-):
-    """
-    Chunked Delta Rule attention interface.
-
-    Args:
-        q: [B, T, H, K]
-        k: [B, T, H, K]
-        v: [B, T, H, V]
-        a: [B, T, H, K]
-        b: [B, T, H, K]
-        gk: [B, T, H, K]  # decay term in log space!
-        initial_state: initial state [N, H, K, V], where N is the number of sequences
-        output_final_state: whether to return the final state
-        head_first: whether tensors are head-first; incompatible with variable-length input
-
-    Returns:
-        o: output [B, T, H, V] or [B, H, T, V]
-        final_state: final state [N, H, K, V] or None
-    """
-```
-
----
-
-
-Under torch-cuda, head-size is also a kernel parameter; it defaults to 64.
-If head-size ≠ 64, use:
-
-```python
-from rwkv_ops import get_generalized_delta_rule
-
-generalized_delta_rule, RWKV7_USE_KERNEL = get_generalized_delta_rule(
-    your_head_size, KERNEL_TYPE="cuda"
-)
-```
-
-`RWKV7_USE_KERNEL` is a constant that marks whether the chunkwise kernel is in use;
-the two paths handle padding differently, as follows:
-
-```python
-if padding_mask is not None:
-    if RWKV7_USE_KERNEL:
-        w += (1 - padding_mask) * -1e9
-    else:
-        w = w * padding_mask + 1 - padding_mask
-```
-
----
-
-### rwkv7op implementation status
-
-
-| Framework | cuda | triton | native |
-|-------------|------|--------|--------|
-| PyTorch | ✅ | ✅ | ✅ |
-| JAX | ❌ | ✅ | ✅ |
-| TensorFlow | ❌ | ❌ | ✅ |
-| NumPy | ❌ | ❌ | ✅ |