brainstate 0.1.0.post20250104__py2.py3-none-any.whl → 0.1.0.post20250105__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -32,6 +32,7 @@ class TestJitError(unittest.TestCase):
32
32
  def err_f(x):
33
33
  raise ValueError(f'error: {x}')
34
34
 
35
+ bst.compile.jit_error_if(False, err_f, 1.)
35
36
  with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
36
37
  bst.compile.jit_error_if(True, err_f, 1.)
37
38
 
@@ -206,7 +206,7 @@ def scan(
206
206
 
207
207
  # evaluate jaxpr, get all states #
208
208
  # ------------------------------ #
209
- xs_avals = [jax.core.raise_to_shaped(jax.core.get_aval(x)) for x in xs_flat]
209
+ xs_avals = [jax.core.get_aval(x) for x in xs_flat]
210
210
  x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
211
211
  stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
212
212
  state_trace = stateful_fun.get_state_trace()
@@ -302,7 +302,7 @@ def checkpointed_scan(
302
302
  pbar_runner = None
303
303
 
304
304
  # evaluate jaxpr
305
- xs_avals = [jax.core.raise_to_shaped(jax.core.get_aval(x)) for x in xs_flat]
305
+ xs_avals = [jax.core.get_aval(x) for x in xs_flat]
306
306
  x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
307
307
  stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
308
308
  state_trace = stateful_fun.get_state_trace()
@@ -18,6 +18,7 @@ from __future__ import annotations
18
18
  import unittest
19
19
 
20
20
  import jax
21
+ import jax.extend as je
21
22
  import jax.numpy as jnp
22
23
  import pytest
23
24
 
@@ -84,7 +85,7 @@ class TestMakeJaxpr(unittest.TestCase):
84
85
  print(jaxpr)
85
86
  jaxpr, _ = bst.compile.make_jaxpr(f3)(jnp.zeros(1))
86
87
  print(jaxpr)
87
- self.assertTrue(jnp.allclose(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
88
+ self.assertTrue(jnp.allclose(je.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
88
89
  f3(jnp.zeros(1))))
89
90
 
90
91
  def test_compar_jax_make_jaxpr2(self):
@@ -16,6 +16,7 @@ from __future__ import annotations
16
16
 
17
17
  import jax
18
18
  import jax.core
19
+ import jax.extend as je
19
20
  import jax.interpreters.batching as batching
20
21
  import jax.interpreters.mlir as mlir
21
22
  import jax.numpy as jnp
@@ -43,7 +44,7 @@ def unvmap(x, op: str = 'any'):
43
44
 
44
45
  # unvmap_all
45
46
 
46
- unvmap_all_p = jax.core.Primitive("unvmap_all")
47
+ unvmap_all_p = je.core.Primitive("unvmap_all")
47
48
 
48
49
 
49
50
  def unvmap_all(x):
brainstate/environ.py CHANGED
@@ -249,7 +249,13 @@ def get_precision() -> int:
249
249
  The default precision.
250
250
  """
251
251
  precision = get('precision')
252
- return 16 if precision == 'bf16' else precision
252
+ if precision == 'bf16':
253
+ return 16
254
+ if isinstance(precision, int):
255
+ return precision
256
+ if isinstance(precision, str):
257
+ return int(precision)
258
+ raise ValueError(f'Unsupported precision: {precision}')
253
259
 
254
260
 
255
261
  def set(
@@ -345,8 +351,8 @@ def _set_jax_precision(precision: int | str):
345
351
  Args:
346
352
  precision: int. The default precision.
347
353
  """
348
- assert precision in [64, 32, 16, 'bf16', 8], f'Precision must be in [64, 32, 16, "bf16", 8]. But got {precision}.'
349
- if precision == 64:
354
+ # assert precision in [64, 32, 16, 'bf16', 8], f'Precision must be in [64, 32, 16, "bf16", 8]. But got {precision}.'
355
+ if precision in [64, '64']:
350
356
  config.update("jax_enable_x64", True)
351
357
  else:
352
358
  config.update("jax_enable_x64", False)
@@ -354,13 +360,13 @@ def _set_jax_precision(precision: int | str):
354
360
 
355
361
  @functools.lru_cache()
356
362
  def _get_uint(precision: int):
357
- if precision == 64:
363
+ if precision in [64, '64']:
358
364
  return np.uint64
359
- elif precision == 32:
365
+ elif precision in [32, '32']:
360
366
  return np.uint32
361
- elif precision in [16, 'bf16']:
367
+ elif precision in [16, '16', 'bf16']:
362
368
  return np.uint16
363
- elif precision == 8:
369
+ elif precision in [8, '8']:
364
370
  return np.uint8
365
371
  else:
366
372
  raise ValueError(f'Unsupported precision: {precision}')
@@ -368,13 +374,13 @@ def _get_uint(precision: int):
368
374
 
369
375
  @functools.lru_cache()
370
376
  def _get_int(precision: int):
371
- if precision == 64:
377
+ if precision in [64, '64']:
372
378
  return np.int64
373
- elif precision == 32:
379
+ elif precision in [32, '32']:
374
380
  return np.int32
375
- elif precision in [16, 'bf16']:
381
+ elif precision in [16, '16', 'bf16']:
376
382
  return np.int16
377
- elif precision == 8:
383
+ elif precision in [8, '8']:
378
384
  return np.int8
379
385
  else:
380
386
  raise ValueError(f'Unsupported precision: {precision}')
@@ -382,25 +388,29 @@ def _get_int(precision: int):
382
388
 
383
389
  @functools.lru_cache()
384
390
  def _get_float(precision: int):
385
- if precision == 64:
391
+ if precision in [64, '64']:
386
392
  return np.float64
387
- elif precision == 32:
393
+ elif precision in [32, '32']:
388
394
  return np.float32
389
- elif precision == 16:
395
+ elif precision in [16, '16']:
390
396
  return np.float16
391
- elif precision == 'bf16':
397
+ elif precision in ['bf16']:
392
398
  return jnp.bfloat16
399
+ elif precision in [8, '8']:
400
+ return jnp.float8_e5m2
393
401
  else:
394
402
  raise ValueError(f'Unsupported precision: {precision}')
395
403
 
396
404
 
397
405
  @functools.lru_cache()
398
406
  def _get_complex(precision: int):
399
- if precision == 64:
407
+ if precision == [64, '64']:
400
408
  return np.complex128
401
- elif precision == 32:
409
+ elif precision == [32, '32']:
402
410
  return np.complex64
403
- elif precision in [16, 'bf16']:
411
+ elif precision in [16, '16', 'bf16']:
412
+ return np.complex64
413
+ elif precision == [8, '8']:
404
414
  return np.complex64
405
415
  else:
406
416
  raise ValueError(f'Unsupported precision: {precision}')
@@ -32,6 +32,10 @@ class TestEnviron(unittest.TestCase):
32
32
  self.assertEqual(a.dtype, jnp.float32)
33
33
 
34
34
  with bst.environ.context(precision=16):
35
+ a = bst.random.randn(1)
36
+ self.assertEqual(a.dtype, jnp.float16)
37
+
38
+ with bst.environ.context(precision='bf16'):
35
39
  a = bst.random.randn(1)
36
40
  self.assertEqual(a.dtype, jnp.bfloat16)
37
41
 
@@ -24,6 +24,7 @@ import jax.numpy as jnp
24
24
  import numpy as np
25
25
  from jax.interpreters import ad
26
26
 
27
+ from brainstate import environ
27
28
  from brainstate._state import ParamState
28
29
  from brainstate.augment import vmap
29
30
  from brainstate.init import param
@@ -111,7 +112,7 @@ class FixedProb(Module):
111
112
 
112
113
  def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
113
114
  if self.n_conn > 1:
114
- return event_fixed_prob(
115
+ r = event_fixed_prob(
115
116
  spk,
116
117
  self.weight.value,
117
118
  self.indices,
@@ -123,7 +124,8 @@ class FixedProb(Module):
123
124
  weight = self.weight.value
124
125
  unit = u.get_unit(weight)
125
126
  r = jnp.zeros(spk.shape[:-1] + (self.out_size[-1],), dtype=weight.dtype)
126
- return u.maybe_decimal(u.Quantity(r, unit=unit))
127
+ r = u.maybe_decimal(u.Quantity(r, unit=unit))
128
+ return u.math.asarray(r, dtype=environ.dftype())
127
129
 
128
130
 
129
131
  def event_fixed_prob(
@@ -128,4 +128,5 @@ class TestFixedProbCSR(parameterized.TestCase):
128
128
 
129
129
  o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
130
130
  self.assertTrue(jnp.allclose(o1, o2))
131
- self.assertTrue(jnp.allclose(r1, r2))
131
+ # assert jnp.allclose(r1, r2), f'r1={r1}, r2={r2}'
132
+ self.assertTrue(jnp.allclose(r1, r2, rtol=1e-4, atol=1e-4))
@@ -12,9 +12,13 @@ from jax import tree_util
12
12
  from jax.core import Primitive
13
13
  from jax.interpreters import batching, ad
14
14
  from jax.interpreters import xla, mlir
15
- from jax.lib import xla_client
16
15
  from jaxlib.hlo_helpers import custom_call
17
16
 
17
+ if jax.__version_info__ < (0, 4, 35):
18
+ from jax.lib import xla_client
19
+ else:
20
+ import jax.extend as je
21
+
18
22
  numba_installed = importlib.util.find_spec('numba') is not None
19
23
 
20
24
  __all__ = [
@@ -143,7 +147,10 @@ def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
143
147
  xla_c_rule = cfunc(sig)(new_f)
144
148
  target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
145
149
  capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
146
- xla_client.register_custom_call_target(target_name, capsule, "cpu")
150
+ if jax.__version_info__ < (0, 4, 35):
151
+ xla_client.register_custom_call_target(target_name, capsule, "cpu")
152
+ else:
153
+ je.ffi.register_ffi_target(target_name, capsule, "cpu", api_version=0)
147
154
 
148
155
  # call
149
156
  return custom_call(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.1.0.post20250104
3
+ Version: 0.1.0.post20250105
4
4
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
5
  Home-page: https://github.com/chaobrain/brainstate
6
6
  Author: BrainState Developers
@@ -2,8 +2,8 @@ brainstate/__init__.py,sha256=A-QKdOvSalsCMxgk80Iz6_xMiUin6con6JaONHfciSY,1526
2
2
  brainstate/_state.py,sha256=4aDpLyHGr1VlPXeLSfM3USQG5K4o7orF7IlaBdYrtfE,29098
3
3
  brainstate/_state_test.py,sha256=1boTp1w8DiCFLsPwNtlLrlIqGRpkasAmLid5bv2fgP4,2223
4
4
  brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
5
- brainstate/environ.py,sha256=wCVlSjavJ9OXc1STOucg-VfXeK9KE443a1SDhkK9lA8,17270
6
- brainstate/environ_test.py,sha256=jXX3nR1CO74aow5YqfqSd73isj9MWgHQxrwSsEjTDY8,1901
5
+ brainstate/environ.py,sha256=PZnVFWPioUBuWmwCO8wwCKrHQfP3BR-5lYPRl5i5GDA,17698
6
+ brainstate/environ_test.py,sha256=QD6sPCKNtqemVCGwkdImjMazatrvvLr6YeAVcfUnVVY,2045
7
7
  brainstate/mixin.py,sha256=g7uVUwZphZWsNs9pb48ozG2cDGaj0hs0g3lq8tDk-Sg,11310
8
8
  brainstate/mixin_test.py,sha256=Oq_0fwC9vpXDN4t4dTBhWzLdFDNlcYsrcip14F1yECI,3079
9
9
  brainstate/surrogate.py,sha256=YaY6RJ6kzpuPXWFjaWsxWt2MzJfdm5v_jeOR8V_jPoU,48369
@@ -23,17 +23,17 @@ brainstate/compile/_ad_checkpoint_test.py,sha256=R1I76nG4zIqb6g3M_VxWts7rUC1OHJC
23
23
  brainstate/compile/_conditions.py,sha256=gApsHKGQrf1QBjoKXDVL7VsoeJ2zFtSc-hFz9nbYcF0,10113
24
24
  brainstate/compile/_conditions_test.py,sha256=s9LF6h9LvigvgxUIugTqvgCHBIU8TXS1Ar1OlIxXfrw,8389
25
25
  brainstate/compile/_error_if.py,sha256=TFvhqITKkRO9m30GdlUP4eEjJvLWQUhjkujXO9zvrWs,2689
26
- brainstate/compile/_error_if_test.py,sha256=SJmAfosVoGd4vhfFtb1IvjeFVW914bfTccCg6DoLWYk,1992
26
+ brainstate/compile/_error_if_test.py,sha256=OdJG483IIdOrCHxtHd49OHfOxCSnSkk7GdAUOzSt8bE,2044
27
27
  brainstate/compile/_jit.py,sha256=3WBXNTALWPYC9rQH0TPH6w4bjG0BpnZt3RAzUQF5kkc,14045
28
28
  brainstate/compile/_jit_test.py,sha256=zD7kck9SQJGmUDolh9P4luKwQ21fBGje1Z4STTEXIuA,4135
29
- brainstate/compile/_loop_collect_return.py,sha256=_iOVPytbctgyaIOxQZH3A2ZbsSoT7VXnFk6Q6R8-gvA,23360
29
+ brainstate/compile/_loop_collect_return.py,sha256=XwMnKkMH0xTWB1f6GE4NQNK1R2GXTXCiVgulpkdIpc4,23308
30
30
  brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
31
31
  brainstate/compile/_loop_no_collection.py,sha256=0i31gdQ7sI-d6pvnh08ttUUwdAtpx4uoYhGuf_CyL9s,7343
32
32
  brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
33
33
  brainstate/compile/_make_jaxpr.py,sha256=S5O9KUB3bsxoKcfptlV0MRfKA__Ija37WxkakIRL3z0,33010
34
- brainstate/compile/_make_jaxpr_test.py,sha256=qJUtkyj50JQ6f4UJbOLhvRdkbNn3NSKibFL9jESdQkA,4279
34
+ brainstate/compile/_make_jaxpr_test.py,sha256=3XaX8LUuG6UjolcD83qDVo5odf8FCDppdr9Q6V0NBs4,4303
35
35
  brainstate/compile/_progress_bar.py,sha256=eInZPjiqzYE6PWxl_or_lBthDNDO0Ov60Uz0DbuBbZQ,4620
36
- brainstate/compile/_unvmap.py,sha256=ewbLLNXiI_dBsEBaVzSS0BEXNol22sd9gMzk606lSkM,4139
36
+ brainstate/compile/_unvmap.py,sha256=0i-NvCLDAUe-effJIIEPVsK4WTPbCDBTgw6AqRvq7mE,4163
37
37
  brainstate/compile/_util.py,sha256=aCvkTV--g4NsqcodTdBAISt4EwgezCbKzNUV58n-Q_Y,6304
38
38
  brainstate/event/__init__.py,sha256=W0ZxbgrcFuYhWTl-GZ0UDoMGfsWmesvG4J_LbTBkex8,937
39
39
  brainstate/event/_csr.py,sha256=QDccbgXUklE2iq1w6WdyaFspXY1165uA9UlPltX16OU,30365
@@ -41,14 +41,14 @@ brainstate/event/_csr_mv.py,sha256=HStHvK3KyEMfLsIUslZjgbdU6OsD1yKGrzQOzBXG36M,1
41
41
  brainstate/event/_csr_mv_benchmark.py,sha256=xrj2DSWzw0pUHAE1jRBeSRhMW7ogXvDHEdeaZGioNE4,702
42
42
  brainstate/event/_csr_mv_test.py,sha256=WQfAvp_3UeCUGAZjr3_aqQvrB-eYZcFEN4v1PBe9fUQ,4012
43
43
  brainstate/event/_csr_test.py,sha256=v59rnwTy8jrvqjdGzN75kvLg0wLBmRbthaVRKY2f0Uw,2945
44
- brainstate/event/_fixedprob_mv.py,sha256=HP5uyFwue5ZNhsU71ZedMQ-Kp5-st89aLKGNhCBmBRA,25457
44
+ brainstate/event/_fixedprob_mv.py,sha256=nR3lhd87t1Vge435QHnFuDp-UBbWoW0Qk1kbsjRHQyc,25541
45
45
  brainstate/event/_fixedprob_mv_benchmark.py,sha256=_F_8fH5MNMJZHeSqnq9DYMI9OgYr6JIxBKjbsgeWRv4,4720
46
- brainstate/event/_fixedprob_mv_test.py,sha256=jijrtJ5fnwcLmA7Tjd3vDlzwfbftmLoVTN4-MPuogVc,4201
46
+ brainstate/event/_fixedprob_mv_test.py,sha256=pVEarvGbqTjnAbxgMVRTAhkyYbvDnlyCJdeOdDD927w,4283
47
47
  brainstate/event/_linear_mv.py,sha256=O5qbY31GNV1qEDrZ5kvPbA8Ae-bY5JpUgGtqDFNAeV0,11794
48
48
  brainstate/event/_linear_mv_benckmark.py,sha256=hu0WqYMIa3jMoH7Fq9dgxcBjjXGFhghPx9vztyCo1KY,2411
49
49
  brainstate/event/_linear_mv_test.py,sha256=V9w41ZP2vu95CyCdCkm-j9Eftqs2kqmeBY809N1-syY,3736
50
50
  brainstate/event/_misc.py,sha256=8IpPooXjF2m0-tuo3pGHqThq2yLSNmYziy_zdurZ3NI,1040
51
- brainstate/event/_xla_custom_op.py,sha256=sR06amc_3mbQDq_ONdxHwr_O8ZKj7S5SNNgo62ILR3U,10797
51
+ brainstate/event/_xla_custom_op.py,sha256=f0OrO6CjsJOUNUCRPpHIRmsb_wgNEym0xBl1tcz8ij4,11016
52
52
  brainstate/event/_xla_custom_op_test.py,sha256=rnkGMleXzLfJj4y5QqwfBvCCLTAHe_uabwBDniY-URM,1745
53
53
  brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
54
54
  brainstate/functional/_activations.py,sha256=S0Ok7sq5FTbmJWSejpOCHo1jpKX0gYOLy_TO2IUXM8s,21726
@@ -137,8 +137,8 @@ brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7
137
137
  brainstate/util/_struct.py,sha256=0exv0oOiSt1hmx20Y4J2-pCGtCTx13WcAlEYSBkyung,17640
138
138
  brainstate/util/_tracers.py,sha256=0r5T4nhxMzI79NtqroqitsdMT4YfpgV5RdYJLS5uJ0w,2285
139
139
  brainstate/util/_visualization.py,sha256=n4ZVz10z7VBqA0cKO6vyHwEMprWJgPeEqtITzDMai2Y,1519
140
- brainstate-0.1.0.post20250104.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
141
- brainstate-0.1.0.post20250104.dist-info/METADATA,sha256=pGOGZBCF8q5La6-DIO1hdldslPnnt4bfjVnOVWgHRU4,3533
142
- brainstate-0.1.0.post20250104.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
143
- brainstate-0.1.0.post20250104.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
144
- brainstate-0.1.0.post20250104.dist-info/RECORD,,
140
+ brainstate-0.1.0.post20250105.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
141
+ brainstate-0.1.0.post20250105.dist-info/METADATA,sha256=Xec1GNBlHcignyvym-EHzU-JIOUuo3T-IUU2LoCO0sk,3533
142
+ brainstate-0.1.0.post20250105.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
143
+ brainstate-0.1.0.post20250105.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
144
+ brainstate-0.1.0.post20250105.dist-info/RECORD,,