rwkv-ops 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rwkv-ops might be problematic. Click here for more details.

rwkv_ops/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "0.3.0"
1
+ __version__ = "0.3.2"
2
2
  import os
3
3
 
4
4
  KERNEL_TYPE = os.environ.get("KERNEL_TYPE", "cuda").lower()
@@ -15,6 +15,9 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
15
15
  USE_TRITON_KERNEL = False
16
16
  if keras.config.backend() == "torch":
17
17
  import torch
18
+ if not torch.cuda.is_available():
19
+ from .native_keras_op import generalized_delta_rule
20
+ return generalized_delta_rule,False
18
21
 
19
22
  if KERNEL_TYPE.lower() == "triton":
20
23
  from .torch_op import generalized_delta_rule
@@ -152,10 +155,10 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
152
155
 
153
156
  USE_TRITON_KERNEL = False
154
157
  elif keras.config.backend() == "jax":
155
- from jax.lib import xla_bridge
158
+ import jax
156
159
  import os
157
-
158
- if xla_bridge.get_backend().platform == "gpu":
160
+
161
+ if jax.devices()[0].platform == "gpu":
159
162
  if KERNEL_TYPE.lower() == "triton":
160
163
  os.environ["JAX_LOG_COMPUTATION"] = "0"
161
164
  from .jax_op import generalized_delta_rule
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rwkv-ops
3
- Version: 0.3.0
3
+ Version: 0.3.2
4
4
  Summary: RWKV operators for multiple backends (PyTorch, JAX, Keras)
5
5
  Project-URL: Homepage, https://github.com/pass-lin/rwkv_ops
6
6
  Author-email: pass-lin <qw_lin@qq.com>
@@ -1,4 +1,4 @@
1
- rwkv_ops/__init__.py,sha256=0RQo3fmgbhzVE7PepXXtH09vEtrQqTcEm2cXnp-SZuA,855
1
+ rwkv_ops/__init__.py,sha256=ojPQmkz3yWNqqJwIyjAfsxWB_h3TowtBrJtuRqssEvA,855
2
2
  rwkv_ops/rwkv6_kernel/__init__.py,sha256=ktIzkK6EUc2nonLQnl2NAjJj9kMt02i9zqfjFcnM_NQ,3647
3
3
  rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py,sha256=4SL93Z4mmuQldHtmwqTKcP7M-outTU5Rge2qgDGzwBg,29966
4
4
  rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py,sha256=c3ZSJ9xC6-PKr88pOhjmBximdhwmP1_i7UOcIdKB43c,3354
@@ -15,7 +15,7 @@ rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h,sha256=CMQclcyHaD
15
15
  rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip,sha256=givSxPA7YfKGz75rOtN8TAjTxWWraVNgTGPZfAJsZsQ,20836
16
16
  rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu,sha256=tfRbMQBkl_LT7EVaJ6KoWYcQ902ApCrS6zkjXldFZXY,12770
17
17
  rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp,sha256=cyCTiF--4SQiDJu7Dy_NuEhSe1vyki6JS4I2rsvT714,6659
18
- rwkv_ops/rwkv7_kernel/__init__.py,sha256=Sa4XlP6VrCSMh-IpRyGfvnCzEZdEGUKDN2df-pNBoRw,7994
18
+ rwkv_ops/rwkv7_kernel/__init__.py,sha256=HfoB043qxcIyljNcSd_XtH2UKB6wF2qQlOq9VvXwWRI,8129
19
19
  rwkv_ops/rwkv7_kernel/get_jax_devices_info.py,sha256=cMIaNED7d1PvYNSyq8wNI3G7wNvcgdUj9HWRBLuSVM8,6004
20
20
  rwkv_ops/rwkv7_kernel/get_torch_devices_info.py,sha256=ZL_rAM6lHB4nTOOU28Xm08qptfuIoijOMi_xwJG3KCo,7380
21
21
  rwkv_ops/rwkv7_kernel/jax_op.py,sha256=C7jOvJ-ZWTFfCZBQNzMbqgoVHuDS2QCGlBsGEMM4Fn0,9140
@@ -58,7 +58,7 @@ rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py,sha256=pRp_z587PrnpgRVpi031IndyjVI
58
58
  rwkv_ops/rwkv7_kernel/triton_kernel/utils.py,sha256=TNGlkwGq4t-TOcdVBk_N_vHPLzMFTu_F0V-O1RprIO4,553
59
59
  rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py,sha256=szaG11q_WmpyhXi6aVWwzizvflCh5wND8wGA_V8afzA,5479
60
60
  rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py,sha256=jbb19DUTHENU2RIOv_T4m_W1eXMqdRqG0XevIkBOhI4,9438
61
- rwkv_ops-0.3.0.dist-info/METADATA,sha256=LRNkirkN1YhiWYN8gy-g9vXkhORjfouJ_HVjUqvWYtw,8853
62
- rwkv_ops-0.3.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
63
- rwkv_ops-0.3.0.dist-info/licenses/LICENSE.txt,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
64
- rwkv_ops-0.3.0.dist-info/RECORD,,
61
+ rwkv_ops-0.3.2.dist-info/METADATA,sha256=lkSey3fiZxPrVO05sSb7Q4Q2cAHFgo8-f8RZjmLAWL4,8853
62
+ rwkv_ops-0.3.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
63
+ rwkv_ops-0.3.2.dist-info/licenses/LICENSE.txt,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
64
+ rwkv_ops-0.3.2.dist-info/RECORD,,