brainstate 0.1.0.post20250104__py2.py3-none-any.whl → 0.1.0.post20250105__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/compile/_error_if_test.py +1 -0
- brainstate/compile/_loop_collect_return.py +2 -2
- brainstate/compile/_make_jaxpr_test.py +2 -1
- brainstate/compile/_unvmap.py +2 -1
- brainstate/environ.py +28 -18
- brainstate/environ_test.py +4 -0
- brainstate/event/_fixedprob_mv.py +4 -2
- brainstate/event/_fixedprob_mv_test.py +2 -1
- brainstate/event/_xla_custom_op.py +9 -2
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250105.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250105.dist-info}/RECORD +14 -14
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250105.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250105.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250105.dist-info}/top_level.txt +0 -0
@@ -206,7 +206,7 @@ def scan(
|
|
206
206
|
|
207
207
|
# evaluate jaxpr, get all states #
|
208
208
|
# ------------------------------ #
|
209
|
-
xs_avals = [jax.core.
|
209
|
+
xs_avals = [jax.core.get_aval(x) for x in xs_flat]
|
210
210
|
x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
|
211
211
|
stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
|
212
212
|
state_trace = stateful_fun.get_state_trace()
|
@@ -302,7 +302,7 @@ def checkpointed_scan(
|
|
302
302
|
pbar_runner = None
|
303
303
|
|
304
304
|
# evaluate jaxpr
|
305
|
-
xs_avals = [jax.core.
|
305
|
+
xs_avals = [jax.core.get_aval(x) for x in xs_flat]
|
306
306
|
x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
|
307
307
|
stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
|
308
308
|
state_trace = stateful_fun.get_state_trace()
|
@@ -18,6 +18,7 @@ from __future__ import annotations
|
|
18
18
|
import unittest
|
19
19
|
|
20
20
|
import jax
|
21
|
+
import jax.extend as je
|
21
22
|
import jax.numpy as jnp
|
22
23
|
import pytest
|
23
24
|
|
@@ -84,7 +85,7 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
84
85
|
print(jaxpr)
|
85
86
|
jaxpr, _ = bst.compile.make_jaxpr(f3)(jnp.zeros(1))
|
86
87
|
print(jaxpr)
|
87
|
-
self.assertTrue(jnp.allclose(
|
88
|
+
self.assertTrue(jnp.allclose(je.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
|
88
89
|
f3(jnp.zeros(1))))
|
89
90
|
|
90
91
|
def test_compar_jax_make_jaxpr2(self):
|
brainstate/compile/_unvmap.py
CHANGED
@@ -16,6 +16,7 @@ from __future__ import annotations
|
|
16
16
|
|
17
17
|
import jax
|
18
18
|
import jax.core
|
19
|
+
import jax.extend as je
|
19
20
|
import jax.interpreters.batching as batching
|
20
21
|
import jax.interpreters.mlir as mlir
|
21
22
|
import jax.numpy as jnp
|
@@ -43,7 +44,7 @@ def unvmap(x, op: str = 'any'):
|
|
43
44
|
|
44
45
|
# unvmap_all
|
45
46
|
|
46
|
-
unvmap_all_p =
|
47
|
+
unvmap_all_p = je.core.Primitive("unvmap_all")
|
47
48
|
|
48
49
|
|
49
50
|
def unvmap_all(x):
|
brainstate/environ.py
CHANGED
@@ -249,7 +249,13 @@ def get_precision() -> int:
|
|
249
249
|
The default precision.
|
250
250
|
"""
|
251
251
|
precision = get('precision')
|
252
|
-
|
252
|
+
if precision == 'bf16':
|
253
|
+
return 16
|
254
|
+
if isinstance(precision, int):
|
255
|
+
return precision
|
256
|
+
if isinstance(precision, str):
|
257
|
+
return int(precision)
|
258
|
+
raise ValueError(f'Unsupported precision: {precision}')
|
253
259
|
|
254
260
|
|
255
261
|
def set(
|
@@ -345,8 +351,8 @@ def _set_jax_precision(precision: int | str):
|
|
345
351
|
Args:
|
346
352
|
precision: int. The default precision.
|
347
353
|
"""
|
348
|
-
assert precision in [64, 32, 16, 'bf16', 8], f'Precision must be in [64, 32, 16, "bf16", 8]. But got {precision}.'
|
349
|
-
if precision
|
354
|
+
# assert precision in [64, 32, 16, 'bf16', 8], f'Precision must be in [64, 32, 16, "bf16", 8]. But got {precision}.'
|
355
|
+
if precision in [64, '64']:
|
350
356
|
config.update("jax_enable_x64", True)
|
351
357
|
else:
|
352
358
|
config.update("jax_enable_x64", False)
|
@@ -354,13 +360,13 @@ def _set_jax_precision(precision: int | str):
|
|
354
360
|
|
355
361
|
@functools.lru_cache()
|
356
362
|
def _get_uint(precision: int):
|
357
|
-
if precision
|
363
|
+
if precision in [64, '64']:
|
358
364
|
return np.uint64
|
359
|
-
elif precision
|
365
|
+
elif precision in [32, '32']:
|
360
366
|
return np.uint32
|
361
|
-
elif precision in [16, 'bf16']:
|
367
|
+
elif precision in [16, '16', 'bf16']:
|
362
368
|
return np.uint16
|
363
|
-
elif precision
|
369
|
+
elif precision in [8, '8']:
|
364
370
|
return np.uint8
|
365
371
|
else:
|
366
372
|
raise ValueError(f'Unsupported precision: {precision}')
|
@@ -368,13 +374,13 @@ def _get_uint(precision: int):
|
|
368
374
|
|
369
375
|
@functools.lru_cache()
|
370
376
|
def _get_int(precision: int):
|
371
|
-
if precision
|
377
|
+
if precision in [64, '64']:
|
372
378
|
return np.int64
|
373
|
-
elif precision
|
379
|
+
elif precision in [32, '32']:
|
374
380
|
return np.int32
|
375
|
-
elif precision in [16, 'bf16']:
|
381
|
+
elif precision in [16, '16', 'bf16']:
|
376
382
|
return np.int16
|
377
|
-
elif precision
|
383
|
+
elif precision in [8, '8']:
|
378
384
|
return np.int8
|
379
385
|
else:
|
380
386
|
raise ValueError(f'Unsupported precision: {precision}')
|
@@ -382,25 +388,29 @@ def _get_int(precision: int):
|
|
382
388
|
|
383
389
|
@functools.lru_cache()
|
384
390
|
def _get_float(precision: int):
|
385
|
-
if precision
|
391
|
+
if precision in [64, '64']:
|
386
392
|
return np.float64
|
387
|
-
elif precision
|
393
|
+
elif precision in [32, '32']:
|
388
394
|
return np.float32
|
389
|
-
elif precision
|
395
|
+
elif precision in [16, '16']:
|
390
396
|
return np.float16
|
391
|
-
elif precision
|
397
|
+
elif precision in ['bf16']:
|
392
398
|
return jnp.bfloat16
|
399
|
+
elif precision in [8, '8']:
|
400
|
+
return jnp.float8_e5m2
|
393
401
|
else:
|
394
402
|
raise ValueError(f'Unsupported precision: {precision}')
|
395
403
|
|
396
404
|
|
397
405
|
@functools.lru_cache()
|
398
406
|
def _get_complex(precision: int):
|
399
|
-
if precision == 64:
|
407
|
+
if precision == [64, '64']:
|
400
408
|
return np.complex128
|
401
|
-
elif precision == 32:
|
409
|
+
elif precision == [32, '32']:
|
402
410
|
return np.complex64
|
403
|
-
elif precision in [16, 'bf16']:
|
411
|
+
elif precision in [16, '16', 'bf16']:
|
412
|
+
return np.complex64
|
413
|
+
elif precision == [8, '8']:
|
404
414
|
return np.complex64
|
405
415
|
else:
|
406
416
|
raise ValueError(f'Unsupported precision: {precision}')
|
brainstate/environ_test.py
CHANGED
@@ -32,6 +32,10 @@ class TestEnviron(unittest.TestCase):
|
|
32
32
|
self.assertEqual(a.dtype, jnp.float32)
|
33
33
|
|
34
34
|
with bst.environ.context(precision=16):
|
35
|
+
a = bst.random.randn(1)
|
36
|
+
self.assertEqual(a.dtype, jnp.float16)
|
37
|
+
|
38
|
+
with bst.environ.context(precision='bf16'):
|
35
39
|
a = bst.random.randn(1)
|
36
40
|
self.assertEqual(a.dtype, jnp.bfloat16)
|
37
41
|
|
@@ -24,6 +24,7 @@ import jax.numpy as jnp
|
|
24
24
|
import numpy as np
|
25
25
|
from jax.interpreters import ad
|
26
26
|
|
27
|
+
from brainstate import environ
|
27
28
|
from brainstate._state import ParamState
|
28
29
|
from brainstate.augment import vmap
|
29
30
|
from brainstate.init import param
|
@@ -111,7 +112,7 @@ class FixedProb(Module):
|
|
111
112
|
|
112
113
|
def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
|
113
114
|
if self.n_conn > 1:
|
114
|
-
|
115
|
+
r = event_fixed_prob(
|
115
116
|
spk,
|
116
117
|
self.weight.value,
|
117
118
|
self.indices,
|
@@ -123,7 +124,8 @@ class FixedProb(Module):
|
|
123
124
|
weight = self.weight.value
|
124
125
|
unit = u.get_unit(weight)
|
125
126
|
r = jnp.zeros(spk.shape[:-1] + (self.out_size[-1],), dtype=weight.dtype)
|
126
|
-
|
127
|
+
r = u.maybe_decimal(u.Quantity(r, unit=unit))
|
128
|
+
return u.math.asarray(r, dtype=environ.dftype())
|
127
129
|
|
128
130
|
|
129
131
|
def event_fixed_prob(
|
@@ -128,4 +128,5 @@ class TestFixedProbCSR(parameterized.TestCase):
|
|
128
128
|
|
129
129
|
o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
|
130
130
|
self.assertTrue(jnp.allclose(o1, o2))
|
131
|
-
|
131
|
+
# assert jnp.allclose(r1, r2), f'r1={r1}, r2={r2}'
|
132
|
+
self.assertTrue(jnp.allclose(r1, r2, rtol=1e-4, atol=1e-4))
|
@@ -12,9 +12,13 @@ from jax import tree_util
|
|
12
12
|
from jax.core import Primitive
|
13
13
|
from jax.interpreters import batching, ad
|
14
14
|
from jax.interpreters import xla, mlir
|
15
|
-
from jax.lib import xla_client
|
16
15
|
from jaxlib.hlo_helpers import custom_call
|
17
16
|
|
17
|
+
if jax.__version_info__ < (0, 4, 35):
|
18
|
+
from jax.lib import xla_client
|
19
|
+
else:
|
20
|
+
import jax.extend as je
|
21
|
+
|
18
22
|
numba_installed = importlib.util.find_spec('numba') is not None
|
19
23
|
|
20
24
|
__all__ = [
|
@@ -143,7 +147,10 @@ def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
|
|
143
147
|
xla_c_rule = cfunc(sig)(new_f)
|
144
148
|
target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
|
145
149
|
capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
|
146
|
-
|
150
|
+
if jax.__version_info__ < (0, 4, 35):
|
151
|
+
xla_client.register_custom_call_target(target_name, capsule, "cpu")
|
152
|
+
else:
|
153
|
+
je.ffi.register_ffi_target(target_name, capsule, "cpu", api_version=0)
|
147
154
|
|
148
155
|
# call
|
149
156
|
return custom_call(
|
{brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250105.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.
|
3
|
+
Version: 0.1.0.post20250105
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -2,8 +2,8 @@ brainstate/__init__.py,sha256=A-QKdOvSalsCMxgk80Iz6_xMiUin6con6JaONHfciSY,1526
|
|
2
2
|
brainstate/_state.py,sha256=4aDpLyHGr1VlPXeLSfM3USQG5K4o7orF7IlaBdYrtfE,29098
|
3
3
|
brainstate/_state_test.py,sha256=1boTp1w8DiCFLsPwNtlLrlIqGRpkasAmLid5bv2fgP4,2223
|
4
4
|
brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
|
5
|
-
brainstate/environ.py,sha256=
|
6
|
-
brainstate/environ_test.py,sha256=
|
5
|
+
brainstate/environ.py,sha256=PZnVFWPioUBuWmwCO8wwCKrHQfP3BR-5lYPRl5i5GDA,17698
|
6
|
+
brainstate/environ_test.py,sha256=QD6sPCKNtqemVCGwkdImjMazatrvvLr6YeAVcfUnVVY,2045
|
7
7
|
brainstate/mixin.py,sha256=g7uVUwZphZWsNs9pb48ozG2cDGaj0hs0g3lq8tDk-Sg,11310
|
8
8
|
brainstate/mixin_test.py,sha256=Oq_0fwC9vpXDN4t4dTBhWzLdFDNlcYsrcip14F1yECI,3079
|
9
9
|
brainstate/surrogate.py,sha256=YaY6RJ6kzpuPXWFjaWsxWt2MzJfdm5v_jeOR8V_jPoU,48369
|
@@ -23,17 +23,17 @@ brainstate/compile/_ad_checkpoint_test.py,sha256=R1I76nG4zIqb6g3M_VxWts7rUC1OHJC
|
|
23
23
|
brainstate/compile/_conditions.py,sha256=gApsHKGQrf1QBjoKXDVL7VsoeJ2zFtSc-hFz9nbYcF0,10113
|
24
24
|
brainstate/compile/_conditions_test.py,sha256=s9LF6h9LvigvgxUIugTqvgCHBIU8TXS1Ar1OlIxXfrw,8389
|
25
25
|
brainstate/compile/_error_if.py,sha256=TFvhqITKkRO9m30GdlUP4eEjJvLWQUhjkujXO9zvrWs,2689
|
26
|
-
brainstate/compile/_error_if_test.py,sha256=
|
26
|
+
brainstate/compile/_error_if_test.py,sha256=OdJG483IIdOrCHxtHd49OHfOxCSnSkk7GdAUOzSt8bE,2044
|
27
27
|
brainstate/compile/_jit.py,sha256=3WBXNTALWPYC9rQH0TPH6w4bjG0BpnZt3RAzUQF5kkc,14045
|
28
28
|
brainstate/compile/_jit_test.py,sha256=zD7kck9SQJGmUDolh9P4luKwQ21fBGje1Z4STTEXIuA,4135
|
29
|
-
brainstate/compile/_loop_collect_return.py,sha256=
|
29
|
+
brainstate/compile/_loop_collect_return.py,sha256=XwMnKkMH0xTWB1f6GE4NQNK1R2GXTXCiVgulpkdIpc4,23308
|
30
30
|
brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
|
31
31
|
brainstate/compile/_loop_no_collection.py,sha256=0i31gdQ7sI-d6pvnh08ttUUwdAtpx4uoYhGuf_CyL9s,7343
|
32
32
|
brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
|
33
33
|
brainstate/compile/_make_jaxpr.py,sha256=S5O9KUB3bsxoKcfptlV0MRfKA__Ija37WxkakIRL3z0,33010
|
34
|
-
brainstate/compile/_make_jaxpr_test.py,sha256=
|
34
|
+
brainstate/compile/_make_jaxpr_test.py,sha256=3XaX8LUuG6UjolcD83qDVo5odf8FCDppdr9Q6V0NBs4,4303
|
35
35
|
brainstate/compile/_progress_bar.py,sha256=eInZPjiqzYE6PWxl_or_lBthDNDO0Ov60Uz0DbuBbZQ,4620
|
36
|
-
brainstate/compile/_unvmap.py,sha256=
|
36
|
+
brainstate/compile/_unvmap.py,sha256=0i-NvCLDAUe-effJIIEPVsK4WTPbCDBTgw6AqRvq7mE,4163
|
37
37
|
brainstate/compile/_util.py,sha256=aCvkTV--g4NsqcodTdBAISt4EwgezCbKzNUV58n-Q_Y,6304
|
38
38
|
brainstate/event/__init__.py,sha256=W0ZxbgrcFuYhWTl-GZ0UDoMGfsWmesvG4J_LbTBkex8,937
|
39
39
|
brainstate/event/_csr.py,sha256=QDccbgXUklE2iq1w6WdyaFspXY1165uA9UlPltX16OU,30365
|
@@ -41,14 +41,14 @@ brainstate/event/_csr_mv.py,sha256=HStHvK3KyEMfLsIUslZjgbdU6OsD1yKGrzQOzBXG36M,1
|
|
41
41
|
brainstate/event/_csr_mv_benchmark.py,sha256=xrj2DSWzw0pUHAE1jRBeSRhMW7ogXvDHEdeaZGioNE4,702
|
42
42
|
brainstate/event/_csr_mv_test.py,sha256=WQfAvp_3UeCUGAZjr3_aqQvrB-eYZcFEN4v1PBe9fUQ,4012
|
43
43
|
brainstate/event/_csr_test.py,sha256=v59rnwTy8jrvqjdGzN75kvLg0wLBmRbthaVRKY2f0Uw,2945
|
44
|
-
brainstate/event/_fixedprob_mv.py,sha256=
|
44
|
+
brainstate/event/_fixedprob_mv.py,sha256=nR3lhd87t1Vge435QHnFuDp-UBbWoW0Qk1kbsjRHQyc,25541
|
45
45
|
brainstate/event/_fixedprob_mv_benchmark.py,sha256=_F_8fH5MNMJZHeSqnq9DYMI9OgYr6JIxBKjbsgeWRv4,4720
|
46
|
-
brainstate/event/_fixedprob_mv_test.py,sha256=
|
46
|
+
brainstate/event/_fixedprob_mv_test.py,sha256=pVEarvGbqTjnAbxgMVRTAhkyYbvDnlyCJdeOdDD927w,4283
|
47
47
|
brainstate/event/_linear_mv.py,sha256=O5qbY31GNV1qEDrZ5kvPbA8Ae-bY5JpUgGtqDFNAeV0,11794
|
48
48
|
brainstate/event/_linear_mv_benckmark.py,sha256=hu0WqYMIa3jMoH7Fq9dgxcBjjXGFhghPx9vztyCo1KY,2411
|
49
49
|
brainstate/event/_linear_mv_test.py,sha256=V9w41ZP2vu95CyCdCkm-j9Eftqs2kqmeBY809N1-syY,3736
|
50
50
|
brainstate/event/_misc.py,sha256=8IpPooXjF2m0-tuo3pGHqThq2yLSNmYziy_zdurZ3NI,1040
|
51
|
-
brainstate/event/_xla_custom_op.py,sha256=
|
51
|
+
brainstate/event/_xla_custom_op.py,sha256=f0OrO6CjsJOUNUCRPpHIRmsb_wgNEym0xBl1tcz8ij4,11016
|
52
52
|
brainstate/event/_xla_custom_op_test.py,sha256=rnkGMleXzLfJj4y5QqwfBvCCLTAHe_uabwBDniY-URM,1745
|
53
53
|
brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
|
54
54
|
brainstate/functional/_activations.py,sha256=S0Ok7sq5FTbmJWSejpOCHo1jpKX0gYOLy_TO2IUXM8s,21726
|
@@ -137,8 +137,8 @@ brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7
|
|
137
137
|
brainstate/util/_struct.py,sha256=0exv0oOiSt1hmx20Y4J2-pCGtCTx13WcAlEYSBkyung,17640
|
138
138
|
brainstate/util/_tracers.py,sha256=0r5T4nhxMzI79NtqroqitsdMT4YfpgV5RdYJLS5uJ0w,2285
|
139
139
|
brainstate/util/_visualization.py,sha256=n4ZVz10z7VBqA0cKO6vyHwEMprWJgPeEqtITzDMai2Y,1519
|
140
|
-
brainstate-0.1.0.
|
141
|
-
brainstate-0.1.0.
|
142
|
-
brainstate-0.1.0.
|
143
|
-
brainstate-0.1.0.
|
144
|
-
brainstate-0.1.0.
|
140
|
+
brainstate-0.1.0.post20250105.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
141
|
+
brainstate-0.1.0.post20250105.dist-info/METADATA,sha256=Xec1GNBlHcignyvym-EHzU-JIOUuo3T-IUU2LoCO0sk,3533
|
142
|
+
brainstate-0.1.0.post20250105.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
143
|
+
brainstate-0.1.0.post20250105.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
144
|
+
brainstate-0.1.0.post20250105.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250105.dist-info}/top_level.txt
RENAMED
File without changes
|