rwkv-ops 0.2__tar.gz → 0.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rwkv-ops might be problematic. Click here for more details.
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/PKG-INFO +1 -1
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/__init__.py +1 -1
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/__init__.py +0 -6
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/jax_op.py +4 -3
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/torch_op.py +24 -10
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops.egg-info/PKG-INFO +1 -1
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/setup.py +1 -1
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/LICENSE.txt +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/README.md +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv6_kernel/__init__.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/native_keras_op.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops.egg-info/SOURCES.txt +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops.egg-info/dependency_links.txt +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops.egg-info/requires.txt +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/rwkv_ops.egg-info/top_level.txt +0 -0
- {rwkv_ops-0.2 → rwkv_ops-0.2.2}/setup.cfg +0 -0
|
@@ -132,14 +132,8 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
|
|
|
132
132
|
USE_KERNEL = False
|
|
133
133
|
elif keras.config.backend() == "jax":
|
|
134
134
|
from jax.lib import xla_bridge
|
|
135
|
-
import jax
|
|
136
135
|
import os
|
|
137
|
-
import logging
|
|
138
136
|
|
|
139
|
-
logging.basicConfig(level=logging.ERROR)
|
|
140
|
-
os.environ["TRITON_LOG_LEVEL"] = "ERROR" # 只显示错误级别的日志
|
|
141
|
-
os.environ["TRITON_DISABLE_AUTOTUNE"] = "1" # 禁用自动调优日志
|
|
142
|
-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # 禁用自动调优日志
|
|
143
137
|
if (
|
|
144
138
|
xla_bridge.get_backend().platform == "gpu"
|
|
145
139
|
and KERNEL_TYPE.lower() == "triton"
|
|
@@ -14,10 +14,9 @@ from .jax_kernel.chunk_o_fwd import chunk_dplr_fwd_o
|
|
|
14
14
|
from .jax_kernel.wy_fast_bwd import chunk_dplr_bwd_wy
|
|
15
15
|
from .jax_kernel.wy_fast_fwd import prepare_wy_repr_fwd
|
|
16
16
|
from .jax_kernel.cumsum import chunk_rwkv6_fwd_cumsum
|
|
17
|
-
|
|
17
|
+
from jax.ad_checkpoint import checkpoint_policies
|
|
18
18
|
CHUNKSIZE = 16
|
|
19
19
|
|
|
20
|
-
|
|
21
20
|
def chunk_dplr_fwd(
|
|
22
21
|
q: jax.Array,
|
|
23
22
|
k: jax.Array,
|
|
@@ -156,7 +155,7 @@ def chunk_dplr_fwd_jax(
|
|
|
156
155
|
output_final_state=True,
|
|
157
156
|
)
|
|
158
157
|
cache = (r, k, v, a, b, gk, initial_state)
|
|
159
|
-
return
|
|
158
|
+
return (o, state), cache
|
|
160
159
|
|
|
161
160
|
|
|
162
161
|
def chunk_dplr_bwd(
|
|
@@ -378,3 +377,5 @@ def generalized_delta_rule(
|
|
|
378
377
|
if output_final_state:
|
|
379
378
|
return jnp.asarray(o, DTYPE), final_state
|
|
380
379
|
return jnp.asarray(o, DTYPE)
|
|
380
|
+
|
|
381
|
+
|
|
@@ -473,16 +473,30 @@ def generalized_delta_rule(
|
|
|
473
473
|
a = transpose_head(a, head_first)
|
|
474
474
|
b = transpose_head(b, head_first)
|
|
475
475
|
w = transpose_head(w, head_first)
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
476
|
+
if w.device.type == "cuda":
|
|
477
|
+
out, state = chunk_rwkv7(
|
|
478
|
+
r=r,
|
|
479
|
+
k=k,
|
|
480
|
+
v=v,
|
|
481
|
+
a=a,
|
|
482
|
+
b=b,
|
|
483
|
+
w=w,
|
|
484
|
+
initial_state=initial_state,
|
|
485
|
+
output_final_state=output_final_state,
|
|
486
|
+
)
|
|
487
|
+
else:
|
|
488
|
+
from ops.native_keras_op import generalized_delta_rule
|
|
489
|
+
|
|
490
|
+
out, state = generalized_delta_rule(
|
|
491
|
+
r=r,
|
|
492
|
+
k=k,
|
|
493
|
+
v=v,
|
|
494
|
+
a=a,
|
|
495
|
+
b=b,
|
|
496
|
+
w=w,
|
|
497
|
+
initial_state=initial_state,
|
|
498
|
+
output_final_state=output_final_state,
|
|
499
|
+
)
|
|
486
500
|
out = transpose_head(out, head_first)
|
|
487
501
|
if output_final_state:
|
|
488
502
|
return out, cast(state, dtype)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|