rwkv-ops 0.1.1__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rwkv_ops/__init__.py +3 -1
- rwkv_ops/rwkv6_kernel/__init__.py +126 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +724 -0
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +86 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
- rwkv_ops/rwkv7_kernel/__init__.py +3 -7
- rwkv_ops/rwkv7_kernel/torch_op.py +91 -39
- rwkv_ops-0.2.1.dist-info/METADATA +258 -0
- {rwkv_ops-0.1.1.dist-info → rwkv_ops-0.2.1.dist-info}/RECORD +12 -8
- rwkv_ops-0.1.1.dist-info/METADATA +0 -119
- {rwkv_ops-0.1.1.dist-info → rwkv_ops-0.2.1.dist-info}/LICENSE.txt +0 -0
- {rwkv_ops-0.1.1.dist-info → rwkv_ops-0.2.1.dist-info}/WHEEL +0 -0
- {rwkv_ops-0.1.1.dist-info → rwkv_ops-0.2.1.dist-info}/top_level.txt +0 -0
rwkv_ops/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.1.1"
+__version__ = "0.2.1"
 import os
 
 KERNEL_TYPE = os.environ.get("KERNEL_TYPE", "triton")
@@ -20,8 +20,10 @@ else:
 assert KERNEL_TYPE in ["triton", "cuda", "native"]
 assert BACKEND in ["torch", "jax", "numpy", "tensorflow"]
 from .rwkv7_kernel import get_generalized_delta_rule
+from .rwkv6_kernel import get_rwkv6_kernel
 
 generalized_delta_rule, RWKV7_USE_KERNEL = get_generalized_delta_rule(
     KERNEL_TYPE=KERNEL_TYPE
 )
 rwkv7_op = generalized_delta_rule
+RWKV6_OP = get_rwkv6_kernel(KERNEL_TYPE=KERNEL_TYPE)
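Based on the export added above, the sketch below shows how the new RWKV6 entry point might be obtained. It only uses names visible in this diff (`KERNEL_TYPE`, `RWKV6_OP`, and the `RWKVKernelOperator` constructor signature from the new kernel module); the head size and sequence length are illustrative values, not values from the package.

# Minimal sketch: KERNEL_TYPE must be set before rwkv_ops is imported,
# since the module reads it from the environment at import time.
import os

os.environ["KERNEL_TYPE"] = "cuda"  # one of "triton", "cuda", "native"

from rwkv_ops import RWKV6_OP

# RWKV6_OP is the RWKVKernelOperator class returned by get_rwkv6_kernel;
# head_size and max_sequence_length here are illustrative values.
operator = RWKV6_OP(head_size=64, max_sequence_length=4096)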
rwkv_ops/rwkv6_kernel/__init__.py
ADDED
@@ -0,0 +1,126 @@
+# copied from https://github.com/infiy-quine/RWKV6_Keras_Operator
+import os
+import keras
+from keras import ops
+from distutils.util import strtobool
+from packaging import version
+
+
+def get_rwkv6_kernel(KERNEL_TYPE="native"):
+    ops_kernel = True
+    if KERNEL_TYPE == "cuda":
+        if keras.config.backend() == "jax":
+            import jax
+
+            if version.parse(jax.__version__) < version.parse("0.6.0"):
+                from .jax_rwkv_kernel import RWKVKernelOperator as CudaOperator
+
+                ops_kernel = False
+            else:
+                print(
+                    "The RWKV6 CUDA kernel cannot be used with JAX >= 0.6. "
+                    "If you need the RWKV6 CUDA kernel, please downgrade JAX; "
+                    "0.4.34 is recommended."
+                )
+                CudaOperator = None
+        elif keras.config.backend() == "torch":
+            from .torch_rwkv_kernel import RWKVKernelOperator as CudaOperator
+
+            ops_kernel = False
+        else:
+            CudaOperator = None
+    else:
+        CudaOperator = None
+    from .ops_rwkv_kernel import RWKVKernelOperator as OpsOperator
+
+    class RWKVKernelOperator:
+        def __init__(self, head_size, max_sequence_length, ops_loop=False):
+            self.enable_cuda = CudaOperator is not None
+
+            if self.enable_cuda:
+                self.cuda_operator = CudaOperator(head_size, max_sequence_length)
+
+            self.ops_operator = OpsOperator(head_size, max_sequence_length)
+
+            self.ops_loop = ops_loop
+
+        def __call__(
+            self, r, k, v, w, u, with_state=False, init_state=None, state_map=None
+        ):
+            seq_len = r.shape[1]
+
+            def call_parallel():
+                if self.enable_cuda:
+                    return self.cuda_operator(
+                        r=r,
+                        k=k,
+                        v=v,
+                        w=w,
+                        u=u,
+                        with_state=with_state,
+                        init_state=init_state,
+                        state_map=state_map,
+                    )
+                else:
+                    return self.ops_operator(
+                        r=r,
+                        k=k,
+                        v=v,
+                        w=w,
+                        u=u,
+                        with_state=with_state,
+                        init_state=init_state,
+                        state_map=state_map,
+                    )
+
+            def call_one_step():
+                return self.ops_operator(
+                    r=r,
+                    k=k,
+                    v=v,
+                    w=w,
+                    u=u,
+                    with_state=with_state,
+                    init_state=init_state,
+                    state_map=state_map,
+                )
+
+            if not self.ops_loop:
+                return ops.cond(
+                    seq_len != 1 and not ops_kernel, call_parallel, call_one_step
+                )
+            else:
+                return call_parallel()
+
+    return RWKVKernelOperator
+
+
+# from .ops_rwkv_kernel import RWKVKernelOperator as OPSKernelOperator
+
+
+"""
+Three new parameters:
+
+return_state: bool. Whether to return the final state; enable this flag as well
+    if you want to supply a custom init_state.
+
+init_state:
+    If omitted, the state is zero-initialized along the batch dimension.
+    Shape: (state_kinds, num_heads, head_size, head_size), where state_kinds is
+    a positive integer <= Batch_Size.
+    Precision: fp32 when r is fp16; otherwise the same dtype as r.
+
+state_map:
+    Shape: (Batch_Size,)
+    Dtype: int64 or list[int]
+    This array maps each init_state slice to the batch-dimension slices of r.
+    Value range: [0, state_kinds)
+
+Returns:
+    output, output_state
+
+def __call__(self, r, k, v, w, u, return_state=False, init_state=None, state_map=None):
+"""