rwkv-ops 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of rwkv-ops might be problematic.
- rwkv_ops/__init__.py +1 -0
- rwkv_ops/rwkv7_kernel/__init__.py +8 -1
- rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +2 -3
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +2 -2
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +3 -3
- rwkv_ops/rwkv7_kernel/jax_op.py +0 -2
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +1 -1
- rwkv_ops/rwkv7_kernel/torch_op.py +48 -119
- {rwkv_ops-0.1.0.dist-info → rwkv_ops-0.1.1.dist-info}/METADATA +3 -2
- {rwkv_ops-0.1.0.dist-info → rwkv_ops-0.1.1.dist-info}/RECORD +13 -13
- {rwkv_ops-0.1.0.dist-info → rwkv_ops-0.1.1.dist-info}/LICENSE.txt +0 -0
- {rwkv_ops-0.1.0.dist-info → rwkv_ops-0.1.1.dist-info}/WHEEL +0 -0
- {rwkv_ops-0.1.0.dist-info → rwkv_ops-0.1.1.dist-info}/top_level.txt +0 -0
rwkv_ops/__init__.py CHANGED
@@ -1,5 +1,6 @@
import keras
from distutils.util import strtobool
+import os
from keras import ops

rwkv_ops/rwkv7_kernel/__init__.py CHANGED
@@ -19,11 +20,11 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
        from .torch_op import generalized_delta_rule

        USE_KERNEL = True
+
    elif KERNEL_TYPE.lower() == "cuda":
        CHUNK_LEN = 16
        USE_KERNEL = True
        from torch.utils.cpp_extension import load
-        import os

        flags = [
            "-res-usage",
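For context on the CUDA branch above: torch.utils.cpp_extension.load JIT-compiles and imports a CUDA extension at runtime, with flags forwarded to nvcc. A minimal sketch of that pattern follows; the module name, source paths, and flag values are hypothetical (the actual sources and full flag list used by rwkv_ops are not shown in this diff):

import os
from torch.utils.cpp_extension import load

HEAD_SIZE = 64
flags = [
    "-res-usage",            # ask nvcc to report register/shared-memory usage
    f"-D_N_={HEAD_SIZE}",    # hypothetical compile-time constant for the head size
    "--use_fast_math",
]

# JIT-compile on first use; later imports reuse the cached build directory.
wkv7_ext = load(
    name="wkv7_demo",                                     # hypothetical extension name
    sources=["cuda/wkv7_op.cpp", "cuda/wkv7_kernel.cu"],  # hypothetical paths
    extra_cuda_cflags=flags,
    verbose=False,
)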
@@ -137,11 +138,17 @@
        from jax.lib import xla_bridge
        import jax
        import os
+        import logging

+        logging.basicConfig(level=logging.ERROR)
+        os.environ["TRITON_LOG_LEVEL"] = "ERROR"  # only show error-level logs
+        os.environ["TRITON_DISABLE_AUTOTUNE"] = "1"  # disable autotune logging
+        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # silence TensorFlow C++ info/warning logs
        if (
            xla_bridge.get_backend().platform == "gpu"
            and KERNEL_TYPE.lower() == "triton"
        ):
+            os.environ["JAX_LOG_COMPUTATION"] = "0"
            from .jax_op import generalized_delta_rule

            USE_KERNEL = True
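The additions above quiet framework logging before the Triton path is taken. A standalone sketch of the same quieting setup for a user's own entry point; note the assumption that TF_CPP_MIN_LOG_LEVEL generally only takes full effect when set before TensorFlow is first imported:

import logging
import os

# Python-side logging: keep only errors.
logging.basicConfig(level=logging.ERROR)

# Triton: error-level logs only, and no autotune progress output.
os.environ["TRITON_LOG_LEVEL"] = "ERROR"
os.environ["TRITON_DISABLE_AUTOTUNE"] = "1"

# TensorFlow C++ backend: level 2 hides INFO and WARNING, keeps ERROR.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# JAX: turn off computation logging (set on the GPU + Triton path above).
os.environ["JAX_LOG_COMPUTATION"] = "0"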
rwkv_ops/rwkv7_kernel/get_jax_devices_info.py CHANGED
@@ -5,6 +5,8 @@ import functools
import triton
import jax
import jax.numpy as jnp
+from enum import Enum
+import contextlib


@lru_cache(maxsize=None)
@@ -82,9 +84,6 @@ def is_triton_shared_mem_enough(

device_capacity = is_triton_shared_mem_enough()

-from enum import Enum
-import contextlib
-

def _cpu_device_warning():
    import warnings
rwkv_ops/rwkv7_kernel/get_torch_devices_info.py CHANGED
@@ -6,6 +6,8 @@ from typing import Literal
import triton
from packaging import version
import torch
+from enum import Enum
+import contextlib


@lru_cache(maxsize=None)
@@ -105,8 +107,6 @@ def is_triton_shared_mem_enough(


device_capacity = is_triton_shared_mem_enough()
-from enum import Enum
-import contextlib


def _cpu_device_warning():
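Both device-info modules wrap is_triton_shared_mem_enough in functools.lru_cache and evaluate it once at import time into device_capacity. A minimal sketch of that caching pattern on the PyTorch side, with a hypothetical probe; the real shared-memory check in rwkv_ops is not shown in this diff:

from functools import lru_cache

import torch


@lru_cache(maxsize=None)
def is_shared_mem_enough_demo(device_index: int = 0) -> bool:
    # Hypothetical capability probe: lru_cache means repeated kernel
    # dispatches do not re-query the CUDA driver.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability(device_index)
    # Rough stand-in: treat Ampere (SM 8.x) and newer as having enough
    # shared memory for the large-tile kernels; the real check may differ.
    return major >= 8


# Evaluated once at import time; later calls hit the cache, as in the
# device_capacity assignment shown in the hunks above.
device_capacity_demo = is_shared_mem_enough_demo()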
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py CHANGED
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright (c) 2023-2025,
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Tuple

@@ -7,7 +7,7 @@ import jax_triton as jt
import jax
import triton

-from ..
+from ..get_jax_devices_info import check_shared_mem
from ..triton_kernel.chunk_o_bwd import *


@@ -104,7 +104,7 @@ def chunk_dplr_bwd_o(
        out_shape=out_shapes,
        grid=grid,
    )
-    return dq, dk, dw, db, dgk_last
+    return (dq, dk, dw, db, dgk_last)


def chunk_dplr_bwd_dAu(
rwkv_ops/rwkv7_kernel/jax_op.py CHANGED
@@ -223,7 +223,6 @@ def chunk_dplr_bwd(

    dv = chunk_dplr_bwd_dv(A_qk=A_qk, kg=kg, do=do, dh=dh, chunk_size=BT)
    del A_qk
-
    dqg, dkg, dw, dbg, dgk_last = chunk_dplr_bwd_o(
        k=kg,
        b=bg,
@@ -239,7 +238,6 @@ def chunk_dplr_bwd(
        scale=scale,
    )
    del v_new
-
    dA_ab, dA_ak, dv, dag = chunk_dplr_bwd_wy(
        A_ab_inv=A_ab_inv,
        A_ak=A_ak,
rwkv_ops/rwkv7_kernel/torch_op.py CHANGED
@@ -1,25 +1,14 @@
-# -*- coding: utf-8 -*-
-"""
-This file implements the forward and backward pass of a chunked delta rule attention mechanism,
-optimized with Triton kernels for GPU acceleration. It includes functions for forward propagation,
-backward gradient computation, and integration with PyTorch's autograd system.
-
-This file implements the forward and backward pass of the chunked Delta Rule attention mechanism,
-optimized with Triton kernels for GPU acceleration; it includes forward and backward gradient functions
-and integrates with PyTorch's autograd system.
-"""
-
import warnings
from typing import Optional

import torch
import triton

-# Import kernel implementation modules
from .torch_kernel.chunk_A_bwd import chunk_dplr_bwd_dqk_intra
from .torch_kernel.chunk_A_fwd import chunk_dplr_fwd_intra
from .torch_kernel.chunk_h_bwd import chunk_dplr_bwd_dhu
from .torch_kernel.chunk_h_fwd import chunk_dplr_fwd_h
+
from .torch_kernel.chunk_o_bwd import (
    chunk_dplr_bwd_dAu,
    chunk_dplr_bwd_dv,
@@ -37,11 +26,6 @@ from .get_torch_devices_info import (


def cast(x, dtype):
-    """
-    Cast tensor x to specified dtype if not already in that format.
-
-    Convert tensor x to the target dtype if it is not already of that type.
-    """
    if x is None or x.dtype == dtype:
        return x
    return x.to(dtype)
@@ -59,27 +43,6 @@ def chunk_dplr_fwd(
    output_final_state: bool = True,
    chunk_size: int = 16,
):
-    """
-    Forward pass of chunked delta rule attention.
-
-    Forward pass of the chunked Delta Rule attention mechanism.
-
-    Args:
-        q (torch.Tensor): Queries tensor [B, T, H, K]
-        k (torch.Tensor): Keys tensor [B, T, H, K]
-        v (torch.Tensor): Values tensor [B, T, H, V]
-        a (torch.Tensor): Activations tensor [B, T, H, K]
-        b (torch.Tensor): Betas tensor [B, T, H, K]
-        gk (torch.Tensor): Log decay tensor [B, T, H, K]
-        scale (float): Scale factor for attention scores
-        initial_state (Optional[torch.Tensor]): Initial state for recurrent processing
-        output_final_state (bool): Whether to return final state
-        chunk_size (int): Chunk size for processing
-
-    Returns:
-        o (torch.Tensor): Output tensor [B, T, H, V]
-        final_state (Optional[torch.Tensor]): Final state if requested
-    """
    T = q.shape[1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)
@@ -97,11 +60,11 @@ def chunk_dplr_fwd(

    del ge

-    #
+    # A_ab, A_ak, gi, ge: torch.float32
+    # A_qk, A_qb, qg, kg, ag, bg: dtype=q.dtype, e.g. bf16
    w, u, _ = prepare_wy_repr_fwd(ag=ag, A_ab=A_ab, A_ak=A_ak, v=v, chunk_size=BT)

    del A_ab, A_ak
-
    h, v_new, final_state = chunk_dplr_fwd_h(
        kg=kg,
        bg=bg,
@@ -137,33 +100,7 @@ def chunk_dplr_bwd(
    dht,
    BT: int = 16,
):
-    """
-    Backward pass of chunked delta rule attention.
-
-    Backward pass of the chunked Delta Rule attention mechanism.
-
-    Args:
-        q (torch.Tensor): Queries tensor [B, T, H, K]
-        k (torch.Tensor): Keys tensor [B, T, H, K]
-        v (torch.Tensor): Values tensor [B, T, H, V]
-        a (torch.Tensor): Activations tensor [B, T, H, K]
-        b (torch.Tensor): Betas tensor [B, T, H, K]
-        gk (torch.Tensor): Log decay tensor [B, T, H, K]
-        initial_state (torch.Tensor): Initial state for recurrent processing
-        scale (float): Scale factor for attention scores
-        do (torch.Tensor): Gradient of outputs
-        dht (torch.Tensor): Gradient of final hidden state
-        BT (int): Chunk size for processing
-
-    Returns:
-        dq (torch.Tensor): Gradient of queries
-        dk (torch.Tensor): Gradient of keys
-        dv (torch.Tensor): Gradient of values
-        da (torch.Tensor): Gradient of activations
-        db (torch.Tensor): Gradient of betas
-        dgk (torch.Tensor): Gradient of log decays
-        dh0 (torch.Tensor): Gradient of initial state
-    """
+    # ******* start recomputing everything, otherwise I believe the GPU memory will be exhausted *******
    gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)
    A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
        q=q,
@@ -183,6 +120,9 @@ def chunk_dplr_bwd(
        kg=kg, bg=bg, v=v, w=w, u=u, gk=gi, initial_state=initial_state, chunk_size=BT
    )
    del u
+    # ******* end of recomputation *******
+    # A_ak, A_ab_inv, gi, ge: torch.float32
+    # A_qk, A_qb, qg, kg, ag, bg, v_new: dtype=q.dtype, e.g. bf16

    dv_new_intra, dA_qk, dA_qb = chunk_dplr_bwd_dAu(
        v=v, v_new=v_new, do=do, A_qb=A_qb, scale=scale, chunk_size=BT
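The comments added above mark a recomputation block: rather than saving every chunk-level intermediate from the forward pass, the backward pass rebuilds them from the original inputs to keep GPU memory bounded. A minimal sketch of that general pattern with torch.autograd.Function, using a hypothetical elementwise op rather than the DPLR kernels:

import torch


class RecomputeInBackward(torch.autograd.Function):
    """Save only the raw inputs; rebuild the intermediate in backward."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # Hypothetical op: y = x * sigmoid(w). The intermediate sigmoid(w)
        # is not stashed; only the inputs are saved.
        ctx.save_for_backward(x, w)
        return x * torch.sigmoid(w)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        x, w = ctx.saved_tensors
        # ******* start recomputation: rebuild sigmoid(w) from the saved input *******
        s = torch.sigmoid(w)
        # ******* end of recomputation *******
        grad_x = grad_out * s
        grad_w = grad_out * x * s * (1.0 - s)
        return grad_x, grad_w


x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(4, 8, requires_grad=True)
RecomputeInBackward.apply(x, w).sum().backward()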
@@ -211,7 +151,7 @@ def chunk_dplr_bwd(
        do=do,
        h=h,
        dh=dh,
-        dv=
+        dv=dv_new,
        w=w,
        gk=gi,
        chunk_size=BT,
@@ -282,11 +222,6 @@ class ChunkDPLRDeltaRuleFunction(torch.autograd.Function):
        output_final_state: bool = True,
        cu_seqlens: Optional[torch.LongTensor] = None,
    ):
-        """
-        Forward function with autograd support.
-
-        Forward function with autograd support.
-        """
        chunk_size = 16
        o, final_state = chunk_dplr_fwd(
            q=q,
@@ -310,11 +245,6 @@ class ChunkDPLRDeltaRuleFunction(torch.autograd.Function):
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do: torch.Tensor, dht: torch.Tensor):
-        """
-        Backward function with autograd support.
-
-        Backward function with autograd support.
-        """
        q, k, v, a, b, gk, initial_state = ctx.saved_tensors
        BT = ctx.chunk_size
        cu_seqlens = ctx.cu_seqlens
@@ -349,10 +279,6 @@ def chunk_dplr_delta_rule(
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    r"""
-    Main interface function for chunked delta rule attention.
-
-    Main interface function for the chunked Delta Rule attention mechanism.
-
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`
@@ -435,10 +361,37 @@ def chunk_rwkv7(
    output_final_state: bool = True,
):
    """
-
-
-
+    Args:
+        r (torch.Tensor):
+            r of shape `[B, H, T, K]`.
+        k (torch.Tensor):
+            k of shape `[B, H, T, K]`.
+        v (torch.Tensor):
+            v of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
+        a (torch.Tensor):
+            a of shape `[B, H, T, K]`.
+        b (torch.Tensor):
+            b of shape `[B, H, T, K]`.
+        w (torch.Tensor):
+            decay of shape `[B, H, T, K]`; the kernel
+            will apply log_w = -torch.exp(w).
+        log_w (torch.Tensor):
+            log decay of shape `[B, H, T, K]`.
+        scale (float):
+            scale of the attention.
+        initial_state (Optional[torch.Tensor]):
+            Initial state of shape `[N, H, K, V]` for `N` input sequences.
+            For equal-length input sequences, `N` equals the batch size `B`.
+            Default: `None`.
+        output_final_state (Optional[bool]):
+            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+        cu_seqlens (torch.LongTensor):
+            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+            consistent with the FlashAttention API.
+        head_first (bool):
+            Whether to use the head-first layout. Recommended to be False to avoid extra transposes.
    """
+
    if w is not None:
        log_w = -torch.exp(w)
    else:
@@ -458,11 +411,6 @@ def chunk_rwkv7(


def transpose_head(x, head_first):
-    """
-    Transpose between head-first and time-first formats.
-
-    Transpose between the head-first and time-first formats.
-    """
    if head_first:
        x = torch.permute(x, dims=(0, 2, 1, 3))
    out = cast(x, torch.bfloat16).contiguous()
@@ -480,11 +428,6 @@ def generalized_delta_rule(
    output_final_state: bool = True,
    head_first: bool = False,
):
-    """
-    Generalized delta rule attention interface.
-
-    Interface to the generalized Delta Rule attention mechanism.
-    """
    dtype = r.dtype
    r = transpose_head(r, head_first)
    k = transpose_head(k, head_first)
@@ -492,30 +435,16 @@ def generalized_delta_rule(
    a = transpose_head(a, head_first)
    b = transpose_head(b, head_first)
    w = transpose_head(w, head_first)
-
-
-
-
-
-
-
-
-
-
-        )
-    else:
-        from .native_keras_op import generalized_delta_rule
-
-        out, state = generalized_delta_rule(
-            r=r,
-            k=k,
-            v=v,
-            a=a,
-            b=b,
-            w=w,
-            initial_state=initial_state,
-            output_final_state=output_final_state,
-        )
+    out, state = chunk_rwkv7(
+        r=r,
+        k=k,
+        v=v,
+        a=a,
+        b=b,
+        w=w,
+        initial_state=initial_state,
+        output_final_state=output_final_state,
+    )
    out = transpose_head(out, head_first)
    if output_final_state:
        return out, cast(state, dtype)
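With this change, generalized_delta_rule always routes through chunk_rwkv7. A hedged usage sketch follows; it assumes the function can be imported from rwkv_ops.rwkv7_kernel.torch_op (the kernel selector imports it that way), that the keyword names match those visible in the hunks, and that a CUDA device with Triton is available. Shapes follow the head_first=False convention from the docstrings; all sizes and values are illustrative.

import torch

from rwkv_ops.rwkv7_kernel.torch_op import generalized_delta_rule  # assumed import path

B, T, H, K, V = 1, 64, 4, 64, 64   # illustrative sizes; K is the head size
dev, dt = "cuda", torch.bfloat16

r = torch.randn(B, T, H, K, device=dev, dtype=dt)
k = torch.randn(B, T, H, K, device=dev, dtype=dt)
v = torch.randn(B, T, H, V, device=dev, dtype=dt)
a = torch.nn.functional.normalize(torch.randn(B, T, H, K, device=dev, dtype=dt), dim=-1)
b = -a                              # illustrative values only
w = torch.randn(B, T, H, K, device=dev, dtype=dt)  # the kernel applies log_w = -torch.exp(w)

out, state = generalized_delta_rule(
    r=r, k=k, v=v, a=a, b=b, w=w,
    initial_state=None,
    output_final_state=True,
    head_first=False,  # [B, T, H, ...] layout, recommended per the docstring
)
print(out.shape)    # expected [B, T, H, V]
print(state.shape)  # expected [N, H, K, V] with N == B for equal-length sequences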
{rwkv_ops-0.1.0.dist-info → rwkv_ops-0.1.1.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: rwkv-ops
-Version: 0.1.
+Version: 0.1.1
Home-page: https://github.com/your-org/rwkv_ops
License: Apache 2.0
Keywords: rwkv attention cuda triton pytorch jax
@@ -30,7 +30,8 @@ Requires-Dist: keras
> Note: this library depends on `keras`.

---
-
+## Installation
+pip install rwkv_ops
## Environment Variables

| Variable | Meaning | Values | Default | Priority |
{rwkv_ops-0.1.0.dist-info → rwkv_ops-0.1.1.dist-info}/RECORD CHANGED
@@ -1,16 +1,16 @@
-rwkv_ops/__init__.py,sha256=
-rwkv_ops/rwkv7_kernel/__init__.py,sha256=
-rwkv_ops/rwkv7_kernel/get_jax_devices_info.py,sha256=
-rwkv_ops/rwkv7_kernel/get_torch_devices_info.py,sha256=
-rwkv_ops/rwkv7_kernel/jax_op.py,sha256=
+rwkv_ops/__init__.py,sha256=zhiKsTn4RGCGGy_0VIZUgHjPZK9XlEJHy39bNkwPnH8,747
+rwkv_ops/rwkv7_kernel/__init__.py,sha256=k88BFK_NtUhG_27rK0_b48JCuEEXMb9_L9jGE50astc,6034
+rwkv_ops/rwkv7_kernel/get_jax_devices_info.py,sha256=cMIaNED7d1PvYNSyq8wNI3G7wNvcgdUj9HWRBLuSVM8,6004
+rwkv_ops/rwkv7_kernel/get_torch_devices_info.py,sha256=ZL_rAM6lHB4nTOOU28Xm08qptfuIoijOMi_xwJG3KCo,7380
+rwkv_ops/rwkv7_kernel/jax_op.py,sha256=tyMxvk_EblDaGsePpxw3AhELvolp7LeE5NopUhKw1R0,9107
rwkv_ops/rwkv7_kernel/native_keras_op.py,sha256=QPrXLbqw0chipQg_0jepRp2U19BYpBBFdKZWyaDNNoc,2488
-rwkv_ops/rwkv7_kernel/torch_op.py,sha256=
+rwkv_ops/rwkv7_kernel/torch_op.py,sha256=d6VQM7SS5ynQ_YTmqHzDIn2MLiXkYcMiSJD2eXEkTSg,12277
rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py,sha256=uHsf_1qrtRK62IvhLuzefHGPWpHXmw1p0tqmwlHcptk,346
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py,sha256=2Voq1Bdzn0DFloiLvwINBk7akmxRWIqXIQeyafrJJGg,2138
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py,sha256=rhmglqHIIww7yPzaSBEp9ISxhhxoUbMtV51AUDyhUd8,1425
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py,sha256=JDfVZsMb8yMlMN3sKT3i3l3y1YQiQkyUjnSNyan5Fqc,1888
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py,sha256=g8b_81rIIjxeknYiklRGnox24rAvEvfKRKT-5nI0Euo,1992
-rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py,sha256=
+rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py,sha256=gQnToi1e1GZCvjWsEdWx6WakUN4Lc0JfaBSsSXYdN84,3369
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py,sha256=4SjQ_zTZvFxsBMeWOx0JGFg9EQ4vllvEx30EcvSZJzI,853
rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py,sha256=NoOh2_hA_rdH5bmaNNMAdCgVPfWvQpf-Q8BqF926jrw,667
rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py,sha256=PAMtE6wCW2Hz39oiHLGqhxY77csQAMYdNP2najDO_Jg,1407
@@ -20,7 +20,7 @@ rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py,sha256=CWtotXkVvHz4-rkuOqWh6zK
rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py,sha256=4RJbyUTO23OxwH1rGVxeBiBVZKNHpPL_tJ7MFoDCIts,1475
rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py,sha256=zo6l0ZZUhXFu8wEFD76I0zSqFT9IXFKUKtyeaSwk380,1795
rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py,sha256=0ucN1U0EDTDqcyTPLLcsAX6FLTf2E_3toOY9p81gWYE,1858
-rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py,sha256=
+rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py,sha256=ioPrS0NYQhpFk1j8rAxqtbwpx1CwjJQnrJEBDqVy-As,3283
rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py,sha256=54yoa3NpV64H-koURt-hUWpFHhUjwXpGvXPp2_ETCnw,825
rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py,sha256=hQkpyaa0eUyB4V3UVks7l1_dHwOrbump0FZILityBKw,611
rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py,sha256=gk6QdoT1oq5B8Hp8Ak-SGqHm8CEj3MErUeWcRsaaOQM,1470
@@ -36,8 +36,8 @@ rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py,sha256=pRp_z587PrnpgRVpi031IndyjVI
rwkv_ops/rwkv7_kernel/triton_kernel/utils.py,sha256=TNGlkwGq4t-TOcdVBk_N_vHPLzMFTu_F0V-O1RprIO4,553
rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py,sha256=szaG11q_WmpyhXi6aVWwzizvflCh5wND8wGA_V8afzA,5479
rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py,sha256=jbb19DUTHENU2RIOv_T4m_W1eXMqdRqG0XevIkBOhI4,9438
-rwkv_ops-0.1.
-rwkv_ops-0.1.
-rwkv_ops-0.1.
-rwkv_ops-0.1.
-rwkv_ops-0.1.
+rwkv_ops-0.1.1.dist-info/LICENSE.txt,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+rwkv_ops-0.1.1.dist-info/METADATA,sha256=g2e5rhSz-SFLzyj76FbShCNtgdWAjgTV0ukw3WYR2fo,3608
+rwkv_ops-0.1.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+rwkv_ops-0.1.1.dist-info/top_level.txt,sha256=cVqoKE-WR_e2gHL87-6O4K1kG6-yTJGB2huyr6FmD2I,9
+rwkv_ops-0.1.1.dist-info/RECORD,,
{rwkv_ops-0.1.0.dist-info → rwkv_ops-0.1.1.dist-info}/LICENSE.txt: File without changes
{rwkv_ops-0.1.0.dist-info → rwkv_ops-0.1.1.dist-info}/WHEEL: File without changes
{rwkv_ops-0.1.0.dist-info → rwkv_ops-0.1.1.dist-info}/top_level.txt: File without changes