rwkv-ops 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rwkv-ops might be problematic. Click here for more details.

rwkv_ops/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "0.3.1"
1
+ __version__ = "0.3.2"
2
2
  import os
3
3
 
4
4
  KERNEL_TYPE = os.environ.get("KERNEL_TYPE", "cuda").lower()
@@ -155,10 +155,10 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
155
155
 
156
156
  USE_TRITON_KERNEL = False
157
157
  elif keras.config.backend() == "jax":
158
- from jax.lib import xla_bridge
158
+ import jax
159
159
  import os
160
-
161
- if xla_bridge.get_backend().platform == "gpu":
160
+
161
+ if jax.devices()[0].platform == "gpu":
162
162
  if KERNEL_TYPE.lower() == "triton":
163
163
  os.environ["JAX_LOG_COMPUTATION"] = "0"
164
164
  from .jax_op import generalized_delta_rule
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rwkv-ops
3
- Version: 0.3.1
3
+ Version: 0.3.2
4
4
  Summary: RWKV operators for multiple backends (PyTorch, JAX, Keras)
5
5
  Project-URL: Homepage, https://github.com/pass-lin/rwkv_ops
6
6
  Author-email: pass-lin <qw_lin@qq.com>
@@ -1,4 +1,4 @@
1
- rwkv_ops/__init__.py,sha256=aQr_HZZ-tB6b_HTjnrbRNyPnExrq0qYReK7-UTqdFTA,855
1
+ rwkv_ops/__init__.py,sha256=ojPQmkz3yWNqqJwIyjAfsxWB_h3TowtBrJtuRqssEvA,855
2
2
  rwkv_ops/rwkv6_kernel/__init__.py,sha256=ktIzkK6EUc2nonLQnl2NAjJj9kMt02i9zqfjFcnM_NQ,3647
3
3
  rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py,sha256=4SL93Z4mmuQldHtmwqTKcP7M-outTU5Rge2qgDGzwBg,29966
4
4
  rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py,sha256=c3ZSJ9xC6-PKr88pOhjmBximdhwmP1_i7UOcIdKB43c,3354
@@ -15,7 +15,7 @@ rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h,sha256=CMQclcyHaD
15
15
  rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip,sha256=givSxPA7YfKGz75rOtN8TAjTxWWraVNgTGPZfAJsZsQ,20836
16
16
  rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu,sha256=tfRbMQBkl_LT7EVaJ6KoWYcQ902ApCrS6zkjXldFZXY,12770
17
17
  rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp,sha256=cyCTiF--4SQiDJu7Dy_NuEhSe1vyki6JS4I2rsvT714,6659
18
- rwkv_ops/rwkv7_kernel/__init__.py,sha256=1TFr3m5kLC0Ldv0HEnzt58_g84e7KD81R6s8TpmiWmI,8148
18
+ rwkv_ops/rwkv7_kernel/__init__.py,sha256=HfoB043qxcIyljNcSd_XtH2UKB6wF2qQlOq9VvXwWRI,8129
19
19
  rwkv_ops/rwkv7_kernel/get_jax_devices_info.py,sha256=cMIaNED7d1PvYNSyq8wNI3G7wNvcgdUj9HWRBLuSVM8,6004
20
20
  rwkv_ops/rwkv7_kernel/get_torch_devices_info.py,sha256=ZL_rAM6lHB4nTOOU28Xm08qptfuIoijOMi_xwJG3KCo,7380
21
21
  rwkv_ops/rwkv7_kernel/jax_op.py,sha256=C7jOvJ-ZWTFfCZBQNzMbqgoVHuDS2QCGlBsGEMM4Fn0,9140
@@ -58,7 +58,7 @@ rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py,sha256=pRp_z587PrnpgRVpi031IndyjVI
58
58
  rwkv_ops/rwkv7_kernel/triton_kernel/utils.py,sha256=TNGlkwGq4t-TOcdVBk_N_vHPLzMFTu_F0V-O1RprIO4,553
59
59
  rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py,sha256=szaG11q_WmpyhXi6aVWwzizvflCh5wND8wGA_V8afzA,5479
60
60
  rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py,sha256=jbb19DUTHENU2RIOv_T4m_W1eXMqdRqG0XevIkBOhI4,9438
61
- rwkv_ops-0.3.1.dist-info/METADATA,sha256=cCo4QhAeyVGbAMMO7FC0PxskVRYPmKU30FUYbrqoVnw,8853
62
- rwkv_ops-0.3.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
63
- rwkv_ops-0.3.1.dist-info/licenses/LICENSE.txt,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
64
- rwkv_ops-0.3.1.dist-info/RECORD,,
61
+ rwkv_ops-0.3.2.dist-info/METADATA,sha256=lkSey3fiZxPrVO05sSb7Q4Q2cAHFgo8-f8RZjmLAWL4,8853
62
+ rwkv_ops-0.3.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
63
+ rwkv_ops-0.3.2.dist-info/licenses/LICENSE.txt,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
64
+ rwkv_ops-0.3.2.dist-info/RECORD,,