brainstate 0.1.0.post20250102__py2.py3-none-any.whl → 0.1.0.post20250105__py2.py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
brainstate/compile/_error_if_test.py CHANGED
@@ -32,6 +32,7 @@ class TestJitError(unittest.TestCase):
         def err_f(x):
             raise ValueError(f'error: {x}')
 
+        bst.compile.jit_error_if(False, err_f, 1.)
         with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
             bst.compile.jit_error_if(True, err_f, 1.)
 
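The added line exercises the false branch: `jit_error_if` must be a no-op when the predicate is false and raise only when it is true. As a rough illustration of the underlying pattern in plain JAX (brainstate's actual implementation may differ), a conditional error can be raised from jitted code through a host callback:

    import jax

    def _raise(x):
        raise ValueError(f'error: {x}')

    @jax.jit
    def guarded(pred, x):
        # The branch is picked at runtime; the error branch calls back into
        # Python, and the raised exception surfaces as an XlaRuntimeError.
        jax.lax.cond(pred,
                     lambda v: jax.debug.callback(_raise, v),
                     lambda v: None,
                     x)
        return x

    guarded(False, 1.0)  # no-op: the error branch is never taken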
brainstate/compile/_loop_collect_return.py CHANGED
@@ -206,7 +206,7 @@ def scan(
 
     # evaluate jaxpr, get all states #
     # ------------------------------ #
-    xs_avals = [jax.core.raise_to_shaped(jax.core.get_aval(x)) for x in xs_flat]
+    xs_avals = [jax.core.get_aval(x) for x in xs_flat]
     x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
     stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
     state_trace = stateful_fun.get_state_trace()
@@ -302,7 +302,7 @@ def checkpointed_scan(
     pbar_runner = None
 
     # evaluate jaxpr
-    xs_avals = [jax.core.raise_to_shaped(jax.core.get_aval(x)) for x in xs_flat]
+    xs_avals = [jax.core.get_aval(x) for x in xs_flat]
    x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
     stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
     state_trace = stateful_fun.get_state_trace()
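Both scan variants drop the `jax.core.raise_to_shaped` wrapper, which recent JAX releases deprecated: `get_aval` already returns a shaped abstract value. A minimal sketch of the aval machinery these lines rely on:

    import jax
    import jax.numpy as jnp

    aval = jax.core.get_aval(jnp.zeros((3, 4)))   # already a ShapedArray
    print(aval)                                   # float32[3,4]
    # mapped_aval strips the scanned-over leading axis of length 3
    print(jax.core.mapped_aval(3, 0, aval))       # float32[4]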
brainstate/compile/_make_jaxpr_test.py CHANGED
@@ -18,6 +18,7 @@ from __future__ import annotations
 import unittest
 
 import jax
+import jax.extend as je
 import jax.numpy as jnp
 import pytest
 
@@ -84,7 +85,7 @@ class TestMakeJaxpr(unittest.TestCase):
         print(jaxpr)
         jaxpr, _ = bst.compile.make_jaxpr(f3)(jnp.zeros(1))
         print(jaxpr)
-        self.assertTrue(jnp.allclose(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
+        self.assertTrue(jnp.allclose(je.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
                                      f3(jnp.zeros(1))))
 
     def test_compar_jax_make_jaxpr2(self):
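`jaxpr_as_fun` is one of the symbols that moved from `jax.core` to `jax.extend.core` in newer JAX releases, hence the switch to `je.core`. A rough usage sketch (not brainstate-specific):

    import jax
    import jax.extend as je
    import jax.numpy as jnp

    closed_jaxpr = jax.make_jaxpr(lambda x: x * 2.0)(jnp.zeros(1))
    fn = je.core.jaxpr_as_fun(closed_jaxpr)  # callable that re-evaluates the jaxpr
    print(fn(jnp.ones(1)))                   # [Array([2.], dtype=float32)]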
brainstate/compile/_unvmap.py CHANGED
@@ -16,6 +16,7 @@ from __future__ import annotations
 
 import jax
 import jax.core
+import jax.extend as je
 import jax.interpreters.batching as batching
 import jax.interpreters.mlir as mlir
 import jax.numpy as jnp
@@ -43,7 +44,7 @@ def unvmap(x, op: str = 'any'):
 
 # unvmap_all
 
-unvmap_all_p = jax.core.Primitive("unvmap_all")
+unvmap_all_p = je.core.Primitive("unvmap_all")
 
 
 def unvmap_all(x):
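`Primitive` likewise moved to `jax.extend.core`. For context, a minimal custom primitive under the new import path might look like this (an illustrative sketch, not the module's actual primitive):

    import jax.extend as je

    identity_p = je.core.Primitive("my_identity")
    identity_p.def_impl(lambda x: x)            # concrete evaluation rule
    identity_p.def_abstract_eval(lambda a: a)   # shape/dtype propagation rule
    print(identity_p.bind(1.0))                 # 1.0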
brainstate/environ.py CHANGED
@@ -31,13 +31,12 @@ from jax import config, devices, numpy as jnp
 from jax.typing import DTypeLike
 
 from .mixin import Mode
-from .util import MemScaling
 
 __all__ = [
     # functions for environment settings
     'set', 'context', 'get', 'all', 'set_host_device_count', 'set_platform',
     # functions for getting default behaviors
-    'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
+    'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_precision',
     # functions for default data types
     'dftype', 'ditype', 'dutype', 'dctype',
     # others
@@ -94,7 +93,7 @@ def context(**kwargs):
                          'Please use set_host_device_count() or set() for the global setting.')
 
     if 'precision' in kwargs:
-        last_precision = get_precision()
+        last_precision = _get_precision()
         _set_jax_precision(kwargs['precision'])
 
     try:
@@ -203,17 +202,6 @@ def get_mode() -> Mode:
     return get('mode')
 
 
-def get_mem_scaling() -> MemScaling:
-    """Get the default computing membrane_scaling.
-
-    Returns
-    -------
-    membrane_scaling: MemScaling
-      The default computing membrane_scaling.
-    """
-    return get('mem_scaling')
-
-
 def get_platform() -> str:
     """Get the computing platform.
 
@@ -239,7 +227,7 @@ def get_host_device_count():
     return int(match.group(1)) if match else 1
 
 
-def get_precision() -> int:
+def _get_precision() -> int | str:
     """
     Get the default precision.
 
@@ -251,11 +239,29 @@ def get_precision() -> int:
     return get('precision')
 
 
+def get_precision() -> int:
+    """
+    Get the default precision.
+
+    Returns
+    -------
+    precision: int
+      The default precision.
+    """
+    precision = get('precision')
+    if precision == 'bf16':
+        return 16
+    if isinstance(precision, int):
+        return precision
+    if isinstance(precision, str):
+        return int(precision)
+    raise ValueError(f'Unsupported precision: {precision}')
+
+
 def set(
     platform: str = None,
     host_device_count: int = None,
-    mem_scaling: MemScaling = None,
-    precision: int = None,
+    precision: int | str = None,
     mode: Mode = None,
     **kwargs
 ):
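The split keeps the raw setting (possibly a string such as `'bf16'`) available internally via `_get_precision`, while the public `get_precision` always normalizes to an int. Roughly, using the package's own API as exercised by its tests:

    import brainstate as bst

    bst.environ.set(precision='bf16')
    print(bst.environ.get_precision())  # 16 (bfloat16 is 16 bits wide)
    print(bst.environ.dftype())         # bfloat16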
@@ -267,8 +273,7 @@ def set(
     Args:
       platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
       host_device_count: int. The number of host devices.
-      mem_scaling: MemScaling. The membrane scaling.
-      precision: int. The default precision.
+      precision: int, str. The default precision.
       mode: Mode. The computing mode.
       **kwargs: dict. Other environment settings.
     """
@@ -276,9 +281,6 @@ def set(
         set_platform(platform)
     if host_device_count is not None:
         set_host_device_count(host_device_count)
-    if mem_scaling is not None:
-        assert isinstance(mem_scaling, MemScaling), 'mem_scaling must be a MemScaling instance.'
-        kwargs['mem_scaling'] = mem_scaling
     if precision is not None:
         _set_jax_precision(precision)
         kwargs['precision'] = precision
@@ -342,15 +344,15 @@ def set_platform(platform: str):
     DFAULT.functions['platform'](platform)
 
 
-def _set_jax_precision(precision: int):
+def _set_jax_precision(precision: int | str):
     """
     Set the default precision.
 
     Args:
       precision: int. The default precision.
     """
-    assert precision in [64, 32, 16, 8], f'Precision must be in [64, 32, 16, 8]. But got {precision}.'
-    if precision == 64:
+    # assert precision in [64, 32, 16, 'bf16', 8], f'Precision must be in [64, 32, 16, "bf16", 8]. But got {precision}.'
+    if precision in [64, '64']:
         config.update("jax_enable_x64", True)
     else:
         config.update("jax_enable_x64", False)
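Only 64-bit precision enables JAX's x64 mode; every other setting leaves it off. The flag this toggles is plain JAX config:

    import jax
    import jax.numpy as jnp

    jax.config.update("jax_enable_x64", True)
    print(jnp.zeros(1).dtype)   # float64
    jax.config.update("jax_enable_x64", False)
    print(jnp.zeros(1).dtype)   # float32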
@@ -358,13 +360,13 @@ def _set_jax_precision(precision: int):
 
 @functools.lru_cache()
 def _get_uint(precision: int):
-    if precision == 64:
+    if precision in [64, '64']:
         return np.uint64
-    elif precision == 32:
+    elif precision in [32, '32']:
         return np.uint32
-    elif precision == 16:
+    elif precision in [16, '16', 'bf16']:
         return np.uint16
-    elif precision == 8:
+    elif precision in [8, '8']:
         return np.uint8
     else:
         raise ValueError(f'Unsupported precision: {precision}')
@@ -372,13 +374,13 @@ def _get_uint(precision: int):
 
 @functools.lru_cache()
 def _get_int(precision: int):
-    if precision == 64:
+    if precision in [64, '64']:
         return np.int64
-    elif precision == 32:
+    elif precision in [32, '32']:
         return np.int32
-    elif precision == 16:
+    elif precision in [16, '16', 'bf16']:
         return np.int16
-    elif precision == 8:
+    elif precision in [8, '8']:
         return np.int8
     else:
         raise ValueError(f'Unsupported precision: {precision}')
@@ -386,25 +388,30 @@ def _get_int(precision: int):
 
 @functools.lru_cache()
 def _get_float(precision: int):
-    if precision == 64:
+    if precision in [64, '64']:
         return np.float64
-    elif precision == 32:
+    elif precision in [32, '32']:
         return np.float32
-    elif precision == 16:
+    elif precision in [16, '16']:
+        return np.float16
+    elif precision in ['bf16']:
         return jnp.bfloat16
-        # return np.float16
+    elif precision in [8, '8']:
+        return jnp.float8_e5m2
     else:
         raise ValueError(f'Unsupported precision: {precision}')
 
 
 @functools.lru_cache()
 def _get_complex(precision: int):
-    if precision == 64:
+    if precision == [64, '64']:
         return np.complex128
-    elif precision == 32:
+    elif precision == [32, '32']:
+        return np.complex64
+    elif precision in [16, '16', 'bf16']:
+        return np.complex64
+    elif precision == [8, '8']:
         return np.complex64
-    elif precision == 16:
-        return np.complex32
     else:
         raise ValueError(f'Unsupported precision: {precision}')
 
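Two details worth noting in this hunk. NumPy has no `np.complex32`, so the removed 16-bit branch of `_get_complex` could never have returned; all sub-64-bit precisions now map to `np.complex64`. Note also that the new `_get_complex` compares with list equality (`precision == [64, '64']`) where the sibling helpers use membership (`in`), so as released only the 16-bit branch can actually match. Quick check of the dtype fact:

    import numpy as np

    print(hasattr(np, 'complex32'))  # False: the old branch would raise AttributeError
    print(np.complex64(1 + 2j))      # (1+2j): smallest complex dtype available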
@@ -430,7 +437,7 @@ def dftype() -> DTypeLike:
     float_dtype: DTypeLike
       The default floating data type.
     """
-    return _get_float(get_precision())
+    return _get_float(_get_precision())
 
 
 def ditype() -> DTypeLike:
@@ -455,7 +462,7 @@ def ditype() -> DTypeLike:
     int_dtype: DTypeLike
       The default integer data type.
     """
-    return _get_int(get_precision())
+    return _get_int(_get_precision())
 
 
 def dutype() -> DTypeLike:
@@ -481,7 +488,7 @@ def dutype() -> DTypeLike:
     uint_dtype: DTypeLike
       The default unsigned integer data type.
     """
-    return _get_uint(get_precision())
+    return _get_uint(_get_precision())
 
 
 def dctype() -> DTypeLike:
@@ -506,7 +513,7 @@ def dctype() -> DTypeLike:
     complex_dtype: DTypeLike
       The default complex data type.
     """
-    return _get_complex(get_precision())
+    return _get_complex(_get_precision())
 
 
 def tolerance():
brainstate/environ_test.py CHANGED
@@ -32,6 +32,10 @@ class TestEnviron(unittest.TestCase):
             self.assertEqual(a.dtype, jnp.float32)
 
         with bst.environ.context(precision=16):
+            a = bst.random.randn(1)
+            self.assertEqual(a.dtype, jnp.float16)
+
+        with bst.environ.context(precision='bf16'):
             a = bst.random.randn(1)
             self.assertEqual(a.dtype, jnp.bfloat16)
 
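The test now distinguishes `precision=16` (IEEE float16) from `precision='bf16'` (bfloat16), matching the new `_get_float` mapping. Usage follows the test itself:

    import brainstate as bst
    import jax.numpy as jnp

    with bst.environ.context(precision='bf16'):
        assert bst.random.randn(1).dtype == jnp.bfloat16
    # the previous precision is restored on exit from the context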
brainstate/event/_fixedprob_mv.py CHANGED
@@ -24,6 +24,7 @@ import jax.numpy as jnp
 import numpy as np
 from jax.interpreters import ad
 
+from brainstate import environ
 from brainstate._state import ParamState
 from brainstate.augment import vmap
 from brainstate.init import param
@@ -111,7 +112,7 @@ class FixedProb(Module):
 
     def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
         if self.n_conn > 1:
-            return event_fixed_prob(
+            r = event_fixed_prob(
                 spk,
                 self.weight.value,
                 self.indices,
@@ -123,7 +124,8 @@ class FixedProb(Module):
             weight = self.weight.value
             unit = u.get_unit(weight)
             r = jnp.zeros(spk.shape[:-1] + (self.out_size[-1],), dtype=weight.dtype)
-            return u.maybe_decimal(u.Quantity(r, unit=unit))
+            r = u.maybe_decimal(u.Quantity(r, unit=unit))
+        return u.math.asarray(r, dtype=environ.dftype())
 
 
 def event_fixed_prob(
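The refactor routes both branches through a single final cast, so `update` always returns an array in the environment's default float dtype. A rough sketch of the cast in isolation (using `brainunit` as `u`, as the module does):

    import brainstate as bst
    import brainunit as u
    import jax.numpy as jnp

    r = jnp.zeros(4, dtype=jnp.float32)
    out = u.math.asarray(r, dtype=bst.environ.dftype())
    print(out.dtype)  # the environment default, e.g. float32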
brainstate/event/_fixedprob_mv_test.py CHANGED
@@ -128,4 +128,5 @@ class TestFixedProbCSR(parameterized.TestCase):
 
         o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
         self.assertTrue(jnp.allclose(o1, o2))
-        self.assertTrue(jnp.allclose(r1, r2))
+        # assert jnp.allclose(r1, r2), f'r1={r1}, r2={r2}'
+        self.assertTrue(jnp.allclose(r1, r2, rtol=1e-4, atol=1e-4))
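The exact-default comparison of the two JVP results is relaxed to explicit tolerances; two lowerings of the same product can differ in the last bits of float32 accumulation. For reference:

    import jax.numpy as jnp

    a = jnp.float32(1.0001)
    b = jnp.float32(1.0)
    print(jnp.allclose(a, b))                        # False under default tolerances
    print(jnp.allclose(a, b, rtol=1e-4, atol=1e-4))  # True with the loosened bounds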
brainstate/event/_xla_custom_op.py CHANGED
@@ -12,9 +12,13 @@ from jax import tree_util
 from jax.core import Primitive
 from jax.interpreters import batching, ad
 from jax.interpreters import xla, mlir
-from jax.lib import xla_client
 from jaxlib.hlo_helpers import custom_call
 
+if jax.__version_info__ < (0, 4, 35):
+    from jax.lib import xla_client
+else:
+    import jax.extend as je
+
 numba_installed = importlib.util.find_spec('numba') is not None
 
 __all__ = [
@@ -143,7 +147,10 @@ def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
     xla_c_rule = cfunc(sig)(new_f)
     target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
     capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
-    xla_client.register_custom_call_target(target_name, capsule, "cpu")
+    if jax.__version_info__ < (0, 4, 35):
+        xla_client.register_custom_call_target(target_name, capsule, "cpu")
+    else:
+        je.ffi.register_ffi_target(target_name, capsule, "cpu", api_version=0)
 
     # call
     return custom_call(
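Newer JAX releases drop the `jax.lib.xla_client` registration path, so the module gates on `jax.__version_info__` and registers the Numba-compiled capsule through `jax.extend.ffi` instead, with `api_version=0` selecting the legacy (untyped) custom-call ABI. The gating pattern in isolation:

    import jax

    if jax.__version_info__ < (0, 4, 35):
        from jax.lib import xla_client
        register = xla_client.register_custom_call_target
    else:
        import jax.extend as je

        def register(name, fn, platform):
            # api_version=0 keeps the legacy custom-call calling convention
            je.ffi.register_ffi_target(name, fn, platform, api_version=0)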
brainstate-0.1.0.post20250105.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: brainstate
-Version: 0.1.0.post20250102
+Version: 0.1.0.post20250105
 Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
 Home-page: https://github.com/chaobrain/brainstate
 Author: BrainState Developers
brainstate-0.1.0.post20250105.dist-info/RECORD CHANGED
@@ -2,8 +2,8 @@ brainstate/__init__.py,sha256=A-QKdOvSalsCMxgk80Iz6_xMiUin6con6JaONHfciSY,1526
 brainstate/_state.py,sha256=4aDpLyHGr1VlPXeLSfM3USQG5K4o7orF7IlaBdYrtfE,29098
 brainstate/_state_test.py,sha256=1boTp1w8DiCFLsPwNtlLrlIqGRpkasAmLid5bv2fgP4,2223
 brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
-brainstate/environ.py,sha256=G6r_rqfbofRbjFFalRu_DHaL7ruFTeLRXBQDXM6P-tQ,17477
-brainstate/environ_test.py,sha256=jXX3nR1CO74aow5YqfqSd73isj9MWgHQxrwSsEjTDY8,1901
+brainstate/environ.py,sha256=PZnVFWPioUBuWmwCO8wwCKrHQfP3BR-5lYPRl5i5GDA,17698
+brainstate/environ_test.py,sha256=QD6sPCKNtqemVCGwkdImjMazatrvvLr6YeAVcfUnVVY,2045
 brainstate/mixin.py,sha256=g7uVUwZphZWsNs9pb48ozG2cDGaj0hs0g3lq8tDk-Sg,11310
 brainstate/mixin_test.py,sha256=Oq_0fwC9vpXDN4t4dTBhWzLdFDNlcYsrcip14F1yECI,3079
 brainstate/surrogate.py,sha256=YaY6RJ6kzpuPXWFjaWsxWt2MzJfdm5v_jeOR8V_jPoU,48369
@@ -23,17 +23,17 @@ brainstate/compile/_ad_checkpoint_test.py,sha256=R1I76nG4zIqb6g3M_VxWts7rUC1OHJC
 brainstate/compile/_conditions.py,sha256=gApsHKGQrf1QBjoKXDVL7VsoeJ2zFtSc-hFz9nbYcF0,10113
 brainstate/compile/_conditions_test.py,sha256=s9LF6h9LvigvgxUIugTqvgCHBIU8TXS1Ar1OlIxXfrw,8389
 brainstate/compile/_error_if.py,sha256=TFvhqITKkRO9m30GdlUP4eEjJvLWQUhjkujXO9zvrWs,2689
-brainstate/compile/_error_if_test.py,sha256=SJmAfosVoGd4vhfFtb1IvjeFVW914bfTccCg6DoLWYk,1992
+brainstate/compile/_error_if_test.py,sha256=OdJG483IIdOrCHxtHd49OHfOxCSnSkk7GdAUOzSt8bE,2044
 brainstate/compile/_jit.py,sha256=3WBXNTALWPYC9rQH0TPH6w4bjG0BpnZt3RAzUQF5kkc,14045
 brainstate/compile/_jit_test.py,sha256=zD7kck9SQJGmUDolh9P4luKwQ21fBGje1Z4STTEXIuA,4135
-brainstate/compile/_loop_collect_return.py,sha256=_iOVPytbctgyaIOxQZH3A2ZbsSoT7VXnFk6Q6R8-gvA,23360
+brainstate/compile/_loop_collect_return.py,sha256=XwMnKkMH0xTWB1f6GE4NQNK1R2GXTXCiVgulpkdIpc4,23308
 brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
 brainstate/compile/_loop_no_collection.py,sha256=0i31gdQ7sI-d6pvnh08ttUUwdAtpx4uoYhGuf_CyL9s,7343
 brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
 brainstate/compile/_make_jaxpr.py,sha256=S5O9KUB3bsxoKcfptlV0MRfKA__Ija37WxkakIRL3z0,33010
-brainstate/compile/_make_jaxpr_test.py,sha256=qJUtkyj50JQ6f4UJbOLhvRdkbNn3NSKibFL9jESdQkA,4279
+brainstate/compile/_make_jaxpr_test.py,sha256=3XaX8LUuG6UjolcD83qDVo5odf8FCDppdr9Q6V0NBs4,4303
 brainstate/compile/_progress_bar.py,sha256=eInZPjiqzYE6PWxl_or_lBthDNDO0Ov60Uz0DbuBbZQ,4620
-brainstate/compile/_unvmap.py,sha256=ewbLLNXiI_dBsEBaVzSS0BEXNol22sd9gMzk606lSkM,4139
+brainstate/compile/_unvmap.py,sha256=0i-NvCLDAUe-effJIIEPVsK4WTPbCDBTgw6AqRvq7mE,4163
 brainstate/compile/_util.py,sha256=aCvkTV--g4NsqcodTdBAISt4EwgezCbKzNUV58n-Q_Y,6304
 brainstate/event/__init__.py,sha256=W0ZxbgrcFuYhWTl-GZ0UDoMGfsWmesvG4J_LbTBkex8,937
 brainstate/event/_csr.py,sha256=QDccbgXUklE2iq1w6WdyaFspXY1165uA9UlPltX16OU,30365
@@ -41,14 +41,14 @@ brainstate/event/_csr_mv.py,sha256=HStHvK3KyEMfLsIUslZjgbdU6OsD1yKGrzQOzBXG36M,1
 brainstate/event/_csr_mv_benchmark.py,sha256=xrj2DSWzw0pUHAE1jRBeSRhMW7ogXvDHEdeaZGioNE4,702
 brainstate/event/_csr_mv_test.py,sha256=WQfAvp_3UeCUGAZjr3_aqQvrB-eYZcFEN4v1PBe9fUQ,4012
 brainstate/event/_csr_test.py,sha256=v59rnwTy8jrvqjdGzN75kvLg0wLBmRbthaVRKY2f0Uw,2945
-brainstate/event/_fixedprob_mv.py,sha256=HP5uyFwue5ZNhsU71ZedMQ-Kp5-st89aLKGNhCBmBRA,25457
+brainstate/event/_fixedprob_mv.py,sha256=nR3lhd87t1Vge435QHnFuDp-UBbWoW0Qk1kbsjRHQyc,25541
 brainstate/event/_fixedprob_mv_benchmark.py,sha256=_F_8fH5MNMJZHeSqnq9DYMI9OgYr6JIxBKjbsgeWRv4,4720
-brainstate/event/_fixedprob_mv_test.py,sha256=jijrtJ5fnwcLmA7Tjd3vDlzwfbftmLoVTN4-MPuogVc,4201
+brainstate/event/_fixedprob_mv_test.py,sha256=pVEarvGbqTjnAbxgMVRTAhkyYbvDnlyCJdeOdDD927w,4283
 brainstate/event/_linear_mv.py,sha256=O5qbY31GNV1qEDrZ5kvPbA8Ae-bY5JpUgGtqDFNAeV0,11794
 brainstate/event/_linear_mv_benckmark.py,sha256=hu0WqYMIa3jMoH7Fq9dgxcBjjXGFhghPx9vztyCo1KY,2411
 brainstate/event/_linear_mv_test.py,sha256=V9w41ZP2vu95CyCdCkm-j9Eftqs2kqmeBY809N1-syY,3736
 brainstate/event/_misc.py,sha256=8IpPooXjF2m0-tuo3pGHqThq2yLSNmYziy_zdurZ3NI,1040
-brainstate/event/_xla_custom_op.py,sha256=sR06amc_3mbQDq_ONdxHwr_O8ZKj7S5SNNgo62ILR3U,10797
+brainstate/event/_xla_custom_op.py,sha256=f0OrO6CjsJOUNUCRPpHIRmsb_wgNEym0xBl1tcz8ij4,11016
 brainstate/event/_xla_custom_op_test.py,sha256=rnkGMleXzLfJj4y5QqwfBvCCLTAHe_uabwBDniY-URM,1745
 brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
 brainstate/functional/_activations.py,sha256=S0Ok7sq5FTbmJWSejpOCHo1jpKX0gYOLy_TO2IUXM8s,21726
@@ -137,8 +137,8 @@ brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7
 brainstate/util/_struct.py,sha256=0exv0oOiSt1hmx20Y4J2-pCGtCTx13WcAlEYSBkyung,17640
 brainstate/util/_tracers.py,sha256=0r5T4nhxMzI79NtqroqitsdMT4YfpgV5RdYJLS5uJ0w,2285
 brainstate/util/_visualization.py,sha256=n4ZVz10z7VBqA0cKO6vyHwEMprWJgPeEqtITzDMai2Y,1519
-brainstate-0.1.0.post20250102.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
-brainstate-0.1.0.post20250102.dist-info/METADATA,sha256=KtQbKvFh7Z_WjFQ4e0s3IkEvtsag6MpOZEvCJQ5Mj5k,3533
-brainstate-0.1.0.post20250102.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
-brainstate-0.1.0.post20250102.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
-brainstate-0.1.0.post20250102.dist-info/RECORD,,
+brainstate-0.1.0.post20250105.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
+brainstate-0.1.0.post20250105.dist-info/METADATA,sha256=Xec1GNBlHcignyvym-EHzU-JIOUuo3T-IUU2LoCO0sk,3533
+brainstate-0.1.0.post20250105.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
+brainstate-0.1.0.post20250105.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
+brainstate-0.1.0.post20250105.dist-info/RECORD,,