rwkv-ops 0.1.0__tar.gz → 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of rwkv-ops might be problematic.

Files changed (48)
  1. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/PKG-INFO +3 -2
  2. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/README.md +2 -1
  3. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/__init__.py +1 -0
  4. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/__init__.py +8 -1
  5. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +2 -3
  6. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +3 -3
  7. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/jax_op.py +0 -2
  8. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +1 -1
  9. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/torch_op.py +48 -119
  10. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops.egg-info/PKG-INFO +3 -2
  11. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/setup.py +2 -2
  12. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/LICENSE.txt +0 -0
  13. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +2 -2
  14. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +0 -0
  15. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +0 -0
  16. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +0 -0
  17. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +0 -0
  18. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +0 -0
  19. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +0 -0
  20. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +0 -0
  21. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +0 -0
  22. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +0 -0
  23. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/native_keras_op.py +0 -0
  24. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +0 -0
  25. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +0 -0
  26. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +0 -0
  27. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +0 -0
  28. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +0 -0
  29. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +0 -0
  30. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +0 -0
  31. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +0 -0
  32. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +0 -0
  33. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +0 -0
  34. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +0 -0
  35. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +0 -0
  36. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +0 -0
  37. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +0 -0
  38. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +0 -0
  39. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +0 -0
  40. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +0 -0
  41. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +0 -0
  42. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +0 -0
  43. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +0 -0
  44. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops.egg-info/SOURCES.txt +0 -0
  45. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops.egg-info/dependency_links.txt +0 -0
  46. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops.egg-info/requires.txt +0 -0
  47. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/rwkv_ops.egg-info/top_level.txt +0 -0
  48. {rwkv_ops-0.1.0 → rwkv_ops-0.1.1}/setup.cfg +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: rwkv_ops
- Version: 0.1.0
+ Version: 0.1.1
  Home-page: https://github.com/your-org/rwkv_ops
  License: Apache 2.0
  Keywords: rwkv attention cuda triton pytorch jax
@@ -29,7 +29,8 @@ License-File: LICENSE.txt
  > Note: this library depends on `keras`.

  ---
-
+ ## Installation
+ pip install rwkv_ops
  ## Environment Variables

  | Variable | Meaning | Values | Default | Priority |
README.md
@@ -9,7 +9,8 @@
  > Note: this library depends on `keras`.

  ---
-
+ ## Installation
+ pip install rwkv_ops
  ## Environment Variables

  | Variable | Meaning | Values | Default | Priority |
rwkv_ops/__init__.py
@@ -1,3 +1,4 @@
+ __version__ = "0.1.1"
  import os

  KERNEL_TYPE = os.environ.get("KERNEL_TYPE", "triton")
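Because `KERNEL_TYPE` is read once from the environment at import time, it must be set before `rwkv_ops` is first imported. A minimal usage sketch (the values "cuda" and "native" appear in the kernel-selection code later in this diff; the exact set of accepted values is an assumption):

import os

# Select the kernel backend before the first import; rwkv_ops reads
# KERNEL_TYPE at module import time ("triton" is the default).
os.environ["KERNEL_TYPE"] = "triton"  # other values seen in this diff: "cuda", "native"

import rwkv_ops

print(rwkv_ops.__version__)  # "0.1.1" as of this release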
rwkv_ops/rwkv7_kernel/__init__.py
@@ -1,5 +1,6 @@
  import keras
  from distutils.util import strtobool
+ import os
  from keras import ops


@@ -19,11 +20,11 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
  from .torch_op import generalized_delta_rule

  USE_KERNEL = True
+
  elif KERNEL_TYPE.lower() == "cuda":
  CHUNK_LEN = 16
  USE_KERNEL = True
  from torch.utils.cpp_extension import load
- import os

  flags = [
  "-res-usage",
@@ -137,11 +138,17 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
  from jax.lib import xla_bridge
  import jax
  import os
+ import logging

+ logging.basicConfig(level=logging.ERROR)
+ os.environ["TRITON_LOG_LEVEL"] = "ERROR"  # show error-level logs only
+ os.environ["TRITON_DISABLE_AUTOTUNE"] = "1"  # disable autotuning logs
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # suppress TensorFlow C++ logs
  if (
  xla_bridge.get_backend().platform == "gpu"
  and KERNEL_TYPE.lower() == "triton"
  ):
+ os.environ["JAX_LOG_COMPUTATION"] = "0"
  from .jax_op import generalized_delta_rule

  USE_KERNEL = True
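Note the ordering here: these variables only affect Triton, TensorFlow, and JAX if they are set before those libraries initialize their logging, which is why the hunk exports them at module import time rather than inside the GPU branch. A standalone sketch of the same pattern (the variable names are the ones used above; whether every downstream library honors them is an assumption inherited from the original code):

import logging
import os

# Raise the Python logging threshold so only errors are emitted.
logging.basicConfig(level=logging.ERROR)

# Export before importing triton / jax / tensorflow, otherwise their
# loggers may already be configured by the time these are read.
os.environ["TRITON_LOG_LEVEL"] = "ERROR"
os.environ["TRITON_DISABLE_AUTOTUNE"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["JAX_LOG_COMPUTATION"] = "0"

import jax  # imported only after the environment is prepared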
rwkv_ops/rwkv7_kernel/get_jax_devices_info.py
@@ -5,6 +5,8 @@ import functools
  import triton
  import jax
  import jax.numpy as jnp
+ from enum import Enum
+ import contextlib


  @lru_cache(maxsize=None)
@@ -82,9 +84,6 @@ def is_triton_shared_mem_enough(

  device_capacity = is_triton_shared_mem_enough()

- from enum import Enum
- import contextlib
-


  def _cpu_device_warning():
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py
@@ -1,5 +1,5 @@
  # -*- coding: utf-8 -*-
- # Copyright (c) 2023-2025,Qingwen Lin
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

  from typing import Tuple

@@ -7,7 +7,7 @@ import jax_triton as jt
  import jax
  import triton

- from ..get_torch_devices_info import check_shared_mem
+ from ..get_jax_devices_info import check_shared_mem
  from ..triton_kernel.chunk_o_bwd import *


@@ -104,7 +104,7 @@ def chunk_dplr_bwd_o(
  out_shape=out_shapes,
  grid=grid,
  )
- return dq, dk, dw, db, dgk_last
+ return (dq, dk, dw, db, dgk_last)


  def chunk_dplr_bwd_dAu(
rwkv_ops/rwkv7_kernel/jax_op.py
@@ -223,7 +223,6 @@ def chunk_dplr_bwd(

  dv = chunk_dplr_bwd_dv(A_qk=A_qk, kg=kg, do=do, dh=dh, chunk_size=BT)
  del A_qk
-
  dqg, dkg, dw, dbg, dgk_last = chunk_dplr_bwd_o(
  k=kg,
  b=bg,
@@ -239,7 +238,6 @@
  scale=scale,
  )
  del v_new
-
  dA_ab, dA_ak, dv, dag = chunk_dplr_bwd_wy(
  A_ab_inv=A_ab_inv,
  A_ak=A_ak,
rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py
@@ -104,7 +104,7 @@ def chunk_dplr_bwd_o(
  BK=BK,
  BV=BV,
  )
- return dq, dk, dw, db, dgk_last
+ return (dq, dk, dw, db, dgk_last)


  def chunk_dplr_bwd_dAu(
rwkv_ops/rwkv7_kernel/torch_op.py
@@ -1,25 +1,14 @@
- # -*- coding: utf-8 -*-
- """
- This file implements the forward and backward pass of a chunked delta rule attention mechanism,
- optimized with Triton kernels for GPU acceleration. It includes functions for forward propagation,
- backward gradient computation, and integration with PyTorch's autograd system.
-
- This file implements the forward and backward passes of the chunked Delta Rule attention mechanism,
- optimized with Triton kernels for GPU acceleration, including forward and backward propagation
- and integration with PyTorch's autograd system.
- """
-
  import warnings
  from typing import Optional

  import torch
  import triton

- # Import kernel implementation modules
  from .torch_kernel.chunk_A_bwd import chunk_dplr_bwd_dqk_intra
  from .torch_kernel.chunk_A_fwd import chunk_dplr_fwd_intra
  from .torch_kernel.chunk_h_bwd import chunk_dplr_bwd_dhu
  from .torch_kernel.chunk_h_fwd import chunk_dplr_fwd_h
+
  from .torch_kernel.chunk_o_bwd import (
  chunk_dplr_bwd_dAu,
  chunk_dplr_bwd_dv,
@@ -37,11 +26,6 @@ from .get_torch_devices_info import (


  def cast(x, dtype):
- """
- Cast tensor x to specified dtype if not already in that format.
-
- If tensor x is not already the target dtype, convert it to the target dtype.
- """
  if x is None or x.dtype == dtype:
  return x
  return x.to(dtype)
@@ -59,27 +43,6 @@ def chunk_dplr_fwd(
  output_final_state: bool = True,
  chunk_size: int = 16,
  ):
- """
- Forward pass of chunked delta rule attention.
-
- Forward pass of the chunked Delta Rule attention mechanism.
-
- Args:
- q (torch.Tensor): Queries tensor [B, T, H, K]
- k (torch.Tensor): Keys tensor [B, T, H, K]
- v (torch.Tensor): Values tensor [B, T, H, V]
- a (torch.Tensor): Activations tensor [B, T, H, K]
- b (torch.Tensor): Betas tensor [B, T, H, K]
- gk (torch.Tensor): Log decay tensor [B, T, H, K]
- scale (float): Scale factor for attention scores
- initial_state (Optional[torch.Tensor]): Initial state for recurrent processing
- output_final_state (bool): Whether to return final state
- chunk_size (int): Chunk size for processing
-
- Returns:
- o (torch.Tensor): Output tensor [B, T, H, V]
- final_state (Optional[torch.Tensor]): Final state if requested
- """
  T = q.shape[1]
  BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
  gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)
@@ -97,11 +60,11 @@ def chunk_dplr_fwd(

  del ge

- # Compute WY representation
+ # A_ab, A_ak, gi, ge torch.float32
+ # A_qk, A_qb, qg, kg, ag, bg, dtype=q.dtype, e.g. bf16
  w, u, _ = prepare_wy_repr_fwd(ag=ag, A_ab=A_ab, A_ak=A_ak, v=v, chunk_size=BT)

  del A_ab, A_ak
-
  h, v_new, final_state = chunk_dplr_fwd_h(
  kg=kg,
  bg=bg,
@@ -137,33 +100,7 @@ def chunk_dplr_bwd(
  dht,
  BT: int = 16,
  ):
- """
- Backward pass of chunked delta rule attention.
-
- Backward pass of the chunked Delta Rule attention mechanism.
-
- Args:
- q (torch.Tensor): Queries tensor [B, T, H, K]
- k (torch.Tensor): Keys tensor [B, T, H, K]
- v (torch.Tensor): Values tensor [B, T, H, V]
- a (torch.Tensor): Activations tensor [B, T, H, K]
- b (torch.Tensor): Betas tensor [B, T, H, K]
- gk (torch.Tensor): Log decay tensor [B, T, H, K]
- initial_state (torch.Tensor): Initial state for recurrent processing
- scale (float): Scale factor for attention scores
- do (torch.Tensor): Gradient of outputs
- dht (torch.Tensor): Gradient of final hidden state
- BT (int): Chunk size for processing
-
- Returns:
- dq (torch.Tensor): Gradient of queries
- dk (torch.Tensor): Gradient of keys
- dv (torch.Tensor): Gradient of values
- da (torch.Tensor): Gradient of activations
- db (torch.Tensor): Gradient of betas
- dgk (torch.Tensor): Gradient of log decays
- dh0 (torch.Tensor): Gradient of initial state
- """
+ # ******* start recomputing everything, otherwise the GPU memory would be exhausted *******
  gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)
  A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
  q=q,
@@ -183,6 +120,9 @@ def chunk_dplr_bwd(
  kg=kg, bg=bg, v=v, w=w, u=u, gk=gi, initial_state=initial_state, chunk_size=BT
  )
  del u
+ # ******* end of recomputation *******
+ # A_ak, A_ab_inv, gi, ge torch.float32
+ # A_qk, A_qb, qg, kg, ag, bg, v_new dtype=q.dtype, e.g. bf16

  dv_new_intra, dA_qk, dA_qb = chunk_dplr_bwd_dAu(
  v=v, v_new=v_new, do=do, A_qb=A_qb, scale=scale, chunk_size=BT
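The new markers describe a recompute-in-backward (activation checkpointing) pattern: the forward pass saves only the raw inputs, and the backward pass first rebuilds the intermediates (`A_ab`, `A_qk`, `qg`, `kg`, and so on) before any gradient is computed, trading extra FLOPs for lower peak GPU memory. A toy sketch of that pattern using a hypothetical op, not the library's actual kernels:

import torch

class RecomputeInBackward(torch.autograd.Function):
    # Save only raw inputs in forward; rebuild the intermediate in
    # backward, as chunk_dplr_bwd does above.
    @staticmethod
    def forward(ctx, x, w):
        h = torch.tanh(x @ w)        # intermediate we deliberately do NOT save
        ctx.save_for_backward(x, w)  # keep raw inputs only
        return h

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        h = torch.tanh(x @ w)        # recomputed, never stored
        gh = grad_out * (1 - h * h)  # tanh'(z) = 1 - tanh(z)^2
        return gh @ w.T, x.T @ gh    # gradients w.r.t. x and w

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 3, requires_grad=True)
RecomputeInBackward.apply(x, w).sum().backward()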
@@ -211,7 +151,7 @@ def chunk_dplr_bwd(
  do=do,
  h=h,
  dh=dh,
- dv=dv,
+ dv=dv_new,
  w=w,
  gk=gi,
  chunk_size=BT,
@@ -282,11 +222,6 @@ class ChunkDPLRDeltaRuleFunction(torch.autograd.Function):
  output_final_state: bool = True,
  cu_seqlens: Optional[torch.LongTensor] = None,
  ):
- """
- Forward function with autograd support.
-
- Forward function with autograd support.
- """
  chunk_size = 16
  o, final_state = chunk_dplr_fwd(
  q=q,
@@ -310,11 +245,6 @@ class ChunkDPLRDeltaRuleFunction(torch.autograd.Function):
  @input_guard
  @autocast_custom_bwd
  def backward(ctx, do: torch.Tensor, dht: torch.Tensor):
- """
- Backward function with autograd support.
-
- Backward function with autograd support.
- """
  q, k, v, a, b, gk, initial_state = ctx.saved_tensors
  BT = ctx.chunk_size
  cu_seqlens = ctx.cu_seqlens
@@ -349,10 +279,6 @@ def chunk_dplr_delta_rule(
  cu_seqlens: Optional[torch.LongTensor] = None,
  ):
  r"""
- Main interface function for chunked delta rule attention.
-
- Main interface function for the chunked Delta Rule attention mechanism.
-
  Args:
  q (torch.Tensor):
  queries of shape `[B, T, H, K]`
@@ -435,10 +361,37 @@ def chunk_rwkv7(
  output_final_state: bool = True,
  ):
  """
- Interface function for RWKV-7 attention.
-
- Interface function for the RWKV-7 attention mechanism.
+ Args:
+ r (torch.Tensor):
+ r of shape `[B, H, T, K]`.
+ k (torch.Tensor):
+ k of shape `[B, H, T, K]`.
+ v (torch.Tensor):
+ v of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
+ a (torch.Tensor):
+ a of shape `[B, H, T, K]`.
+ b (torch.Tensor):
+ b of shape `[B, H, T, K]`.
+ w (torch.Tensor):
+ decay of shape `[B, H, T, K]`; the kernel
+ will apply log_w = -torch.exp(w)
+ log_w (torch.Tensor):
+ log decay of shape `[B, H, T, K]`.
+ scale (float):
+ scale of the attention.
+ initial_state (Optional[torch.Tensor]):
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
+ For equal-length input sequences, `N` equals the batch size `B`.
+ Default: `None`.
+ output_final_state (Optional[bool]):
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+ cu_seqlens (torch.LongTensor):
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+ consistent with the FlashAttention API.
+ head_first (bool):
+ Whether to use the head-first layout. Recommended to be False to avoid extra transposes.
  """
+
  if w is not None:
  log_w = -torch.exp(w)
  else:
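Per this docstring, only one of `w` and `log_w` is needed; when `w` is given the kernel derives `log_w = -torch.exp(w)` itself. A hedged shape sketch built from the docstring alone (the import path is an assumption, `scale=1.0` is an arbitrary illustrative value, and a CUDA device is assumed since the kernels are Triton/CUDA):

import torch
from rwkv_ops.rwkv7_kernel.torch_op import chunk_rwkv7  # assumed import path

B, H, T, K, V = 2, 4, 64, 64, 64

def mk(*shape):
    return torch.randn(*shape, device="cuda", dtype=torch.bfloat16)

r, k, a, b, w = (mk(B, H, T, K) for _ in range(5))  # head-first [B, H, T, K]
v = mk(B, H, T, V)

out, final_state = chunk_rwkv7(
    r=r, k=k, v=v, a=a, b=b,
    w=w,                    # kernel applies log_w = -torch.exp(w)
    scale=1.0,
    initial_state=None,     # or an [N, H, K, V] tensor
    output_final_state=True,
)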
@@ -458,11 +411,6 @@


  def transpose_head(x, head_first):
- """
- Transpose between head-first and time-first formats.
-
- Transpose between head-first and time-first formats.
- """
  if head_first:
  x = torch.permute(x, dims=(0, 2, 1, 3))
  out = cast(x, torch.bfloat16).contiguous()
@@ -480,11 +428,6 @@ def generalized_delta_rule(
  output_final_state: bool = True,
  head_first: bool = False,
  ):
- """
- Generalized delta rule attention interface.
-
- Generalized Delta Rule attention interface.
- """
  dtype = r.dtype
  r = transpose_head(r, head_first)
  k = transpose_head(k, head_first)
@@ -492,30 +435,16 @@ def generalized_delta_rule(
  a = transpose_head(a, head_first)
  b = transpose_head(b, head_first)
  w = transpose_head(w, head_first)
- if w.device.type == "cuda":
- out, state = chunk_rwkv7(
- r=r,
- k=k,
- v=v,
- a=a,
- b=b,
- w=w,
- initial_state=initial_state,
- output_final_state=output_final_state,
- )
- else:
- from .native_keras_op import generalized_delta_rule
-
- out, state = generalized_delta_rule(
- r=r,
- k=k,
- v=v,
- a=a,
- b=b,
- w=w,
- initial_state=initial_state,
- output_final_state=output_final_state,
- )
+ out, state = chunk_rwkv7(
+ r=r,
+ k=k,
+ v=v,
+ a=a,
+ b=b,
+ w=w,
+ initial_state=initial_state,
+ output_final_state=output_final_state,
+ )
  out = transpose_head(out, head_first)
  if output_final_state:
  return out, cast(state, dtype)
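After this simplification `generalized_delta_rule` always dispatches to `chunk_rwkv7`; the pure-Keras fallback in `native_keras_op.py` is reached only through the `KERNEL_TYPE` selection shown earlier in this diff. A hedged calling sketch (time-first `[B, T, H, ...]` layout for the default `head_first=False`, following the `chunk_dplr_delta_rule` docstring; the import path and keyword names are assumptions based on the identifiers visible above):

import torch
from rwkv_ops.rwkv7_kernel.torch_op import generalized_delta_rule  # assumed path

B, T, H, K, V = 2, 64, 4, 64, 64

def mk(*shape):
    return torch.randn(*shape, device="cuda", dtype=torch.bfloat16)

r, k, a, b, w = (mk(B, T, H, K) for _ in range(5))
v = mk(B, T, H, V)

out, state = generalized_delta_rule(
    r=r, k=k, v=v, a=a, b=b, w=w,
    initial_state=None,
    output_final_state=True,
    head_first=False,  # inputs are [B, T, H, ...]
)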
rwkv_ops.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: rwkv-ops
- Version: 0.1.0
+ Version: 0.1.1
  Home-page: https://github.com/your-org/rwkv_ops
  License: Apache 2.0
  Keywords: rwkv attention cuda triton pytorch jax
@@ -29,7 +29,8 @@ License-File: LICENSE.txt
  > Note: this library depends on `keras`.

  ---
-
+ ## Installation
+ pip install rwkv_ops
  ## Environment Variables

  | Variable | Meaning | Values | Default | Priority |
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages

  setup(
  name="rwkv_ops",
- version="0.1.0",
+ version="0.1.1",
  packages=find_packages(),
  install_requires=["keras"],  # add the dependency
  license="Apache 2.0",  # specify the license type
@@ -24,4 +24,4 @@ setup(
  "Programming Language :: Python :: 3.11",
  "Topic :: Scientific/Engineering :: Artificial Intelligence",
  ],
- )
+ )
rwkv_ops/rwkv7_kernel/get_torch_devices_info.py
@@ -6,6 +6,8 @@ from typing import Literal
  import triton
  from packaging import version
  import torch
+ from enum import Enum
+ import contextlib


  @lru_cache(maxsize=None)
  @lru_cache(maxsize=None)
@@ -105,8 +107,6 @@ def is_triton_shared_mem_enough(
105
107
 
106
108
 
107
109
  device_capacity = is_triton_shared_mem_enough()
108
- from enum import Enum
109
- import contextlib
110
110
 
111
111
 
112
112
  def _cpu_device_warning():
File without changes