brainstate 0.1.0.post20250105__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. brainstate/_state.py +77 -44
  2. brainstate/_state_test.py +0 -17
  3. brainstate/augment/_eval_shape.py +9 -10
  4. brainstate/augment/_eval_shape_test.py +1 -1
  5. brainstate/augment/_mapping.py +265 -277
  6. brainstate/augment/_mapping_test.py +147 -175
  7. brainstate/compile/_ad_checkpoint.py +6 -4
  8. brainstate/compile/_jit.py +37 -28
  9. brainstate/compile/_loop_collect_return.py +6 -3
  10. brainstate/compile/_loop_no_collection.py +2 -0
  11. brainstate/compile/_make_jaxpr.py +7 -3
  12. brainstate/compile/_progress_bar.py +68 -40
  13. brainstate/compile/_unvmap.py +6 -3
  14. brainstate/event/__init__.py +0 -2
  15. brainstate/event/_csr.py +266 -23
  16. brainstate/event/_csr_test.py +187 -0
  17. brainstate/event/_xla_custom_op.py +7 -3
  18. brainstate/graph/__init__.py +8 -12
  19. brainstate/graph/_graph_node.py +1 -23
  20. brainstate/graph/_graph_operation.py +1 -1
  21. brainstate/graph/_graph_operation_test.py +0 -159
  22. brainstate/nn/_dyn_impl/_inputs.py +124 -39
  23. brainstate/nn/_interaction/_conv.py +4 -2
  24. brainstate/nn/_interaction/_linear.py +84 -10
  25. brainstate/random/_rand_funs.py +9 -2
  26. brainstate/random/_rand_seed.py +12 -2
  27. brainstate/random/_rand_state.py +50 -179
  28. brainstate/surrogate.py +5 -1
  29. brainstate/util/__init__.py +0 -4
  30. brainstate/util/_caller.py +1 -1
  31. brainstate/util/_dict.py +4 -1
  32. brainstate/util/_filter.py +1 -1
  33. brainstate/util/_pretty_repr.py +1 -1
  34. brainstate/util/_struct.py +1 -1
  35. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
  36. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +40 -46
  37. brainstate/event/_csr_mv_test.py +0 -118
  38. brainstate/graph/_graph_context.py +0 -443
  39. brainstate/graph/_graph_context_test.py +0 -65
  40. brainstate/graph/_graph_convert.py +0 -246
  41. brainstate/util/_tracers.py +0 -68
  42. brainstate/util/_visualization.py +0 -47
  43. /brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
  44. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
  45. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
  46. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,6 @@
16
16
  # -*- coding: utf-8 -*-
17
17
  from __future__ import annotations
18
18
 
19
- from collections import namedtuple
20
19
  from functools import partial
21
20
  from operator import index
22
21
  from typing import Optional
@@ -37,6 +36,8 @@ from ._random_for_unit import uniform_for_unit, permutation_for_unit
37
36
 
38
37
  __all__ = ['RandomState', 'DEFAULT', ]
39
38
 
39
+ use_prng_key = True
40
+
40
41
 
41
42
  class RandomState(State):
42
43
  """RandomState that track the random generator state. """
@@ -56,12 +57,15 @@ class RandomState(State):
56
57
  if seed_or_key is None:
57
58
  seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
58
59
  if isinstance(seed_or_key, int):
59
- key = jr.PRNGKey(seed_or_key)
60
+ key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
60
61
  else:
61
- if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
62
- raise ValueError('key must be an array with dtype uint32. '
63
- f'But we got {seed_or_key}')
64
- key = seed_or_key
62
+ if jnp.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
63
+ key = seed_or_key
64
+ else:
65
+ if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
66
+ raise ValueError('key must be an array with dtype uint32. '
67
+ f'But we got {seed_or_key}')
68
+ key = seed_or_key
65
69
  super().__init__(key)
66
70
 
67
71
  self._backup = None
@@ -70,6 +74,9 @@ class RandomState(State):
70
74
  return f'{self.__class__.__name__}({self.value})'
71
75
 
72
76
  def check_if_deleted(self):
77
+ if not use_prng_key and isinstance(self._value, np.ndarray):
78
+ self._value = jr.key(np.random.randint(0, 10000))
79
+
73
80
  if (
74
81
  isinstance(self._value, jax.Array) and
75
82
  not isinstance(self._value, jax.core.Tracer) and
@@ -111,12 +118,19 @@ class RandomState(State):
111
118
  if seed_or_key is None:
112
119
  seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
113
120
  if np.size(seed_or_key) == 1:
114
- key = jr.PRNGKey(seed_or_key)
121
+ if isinstance(seed_or_key, int):
122
+ key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
123
+ elif jnp.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
124
+ key = seed_or_key
125
+ elif isinstance(seed_or_key, (jnp.ndarray, np.ndarray)) and jnp.issubdtype(seed_or_key.dtype, jnp.integer):
126
+ key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
127
+ else:
128
+ raise ValueError(f'Invalid seed_or_key: {seed_or_key}')
115
129
  else:
116
- if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
117
- raise ValueError('key must be an array with dtype uint32. '
118
- f'But we got {seed_or_key}')
119
- key = seed_or_key
130
+ if len(seed_or_key) == 2 and seed_or_key.dtype == np.uint32:
131
+ key = seed_or_key
132
+ else:
133
+ raise ValueError(f'Invalid seed_or_key: {seed_or_key}')
120
134
  self.value = key
121
135
 
122
136
  def split_key(self, n: Optional[int] = None, backup: bool = False) -> SeedOrKey:
@@ -560,15 +574,15 @@ class RandomState(State):
560
574
  )
561
575
  return out if unit.is_unitless else u.Quantity(out, unit=unit)
562
576
 
563
- def _check_p(self, p):
564
- raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
577
+ def _check_p(self, *args, **kwargs):
578
+ raise ValueError('Parameter p should be within [0, 1], but we got {p}')
565
579
 
566
580
  def bernoulli(self,
567
581
  p,
568
582
  size: Optional[Size] = None,
569
583
  key: Optional[SeedOrKey] = None):
570
584
  p = _check_py_seq(p)
571
- jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
585
+ jit_error_if(jnp.any(jnp.logical_or(p < 0, p > 1)), self._check_p, p=p)
572
586
  if size is None:
573
587
  size = jnp.shape(p)
574
588
  key = self.split_key() if key is None else _formalize_key(key)
@@ -603,19 +617,27 @@ class RandomState(State):
603
617
  samples = jnp.exp(samples)
604
618
  return samples if unit.is_unitless else u.Quantity(samples, unit=unit)
605
619
 
606
- def binomial(self,
607
- n,
608
- p,
609
- size: Optional[Size] = None,
610
- key: Optional[SeedOrKey] = None,
611
- dtype: DTypeLike = None):
620
+ def binomial(
621
+ self,
622
+ n,
623
+ p,
624
+ size: Optional[Size] = None,
625
+ key: Optional[SeedOrKey] = None,
626
+ dtype: DTypeLike = None,
627
+ check_valid: bool = True,
628
+ ):
612
629
  n = _check_py_seq(n)
613
630
  p = _check_py_seq(p)
614
- jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
631
+ if check_valid:
632
+ jit_error_if(
633
+ jnp.any(jnp.logical_or(p < 0, p > 1)),
634
+ 'Parameter p should be within [0, 1], but we got {p}',
635
+ p=p
636
+ )
615
637
  if size is None:
616
638
  size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p))
617
639
  key = self.split_key() if key is None else _formalize_key(key)
618
- r = _binomial(key, p, n, shape=_size2shape(size))
640
+ r = jr.binomial(key, n, p, shape=_size2shape(size))
619
641
  dtype = dtype or environ.ditype()
620
642
  return jnp.asarray(r, dtype=dtype)
621
643
 
@@ -1142,8 +1164,13 @@ DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
1142
1164
 
1143
1165
  def _formalize_key(key):
1144
1166
  if isinstance(key, int):
1145
- return jr.PRNGKey(key)
1167
+ return jr.PRNGKey(key) if use_prng_key else jr.key(key)
1146
1168
  elif isinstance(key, (jax.Array, np.ndarray)):
1169
+ if jnp.issubdtype(key.dtype, jax.dtypes.prng_key):
1170
+ return key
1171
+ if key.size == 1 and jnp.issubdtype(key.dtype, jnp.integer):
1172
+ return jr.PRNGKey(key) if use_prng_key else jr.key(key)
1173
+
1147
1174
  if key.dtype != jnp.uint32:
1148
1175
  raise TypeError('key must be a int or an array with two uint32.')
1149
1176
  if key.size != 2:
@@ -1216,162 +1243,6 @@ def _const(example, val):
1216
1243
  return np.array(val, dtype)
1217
1244
 
1218
1245
 
1219
- _tr_params = namedtuple(
1220
- "tr_params", ["c", "b", "a", "alpha", "u_r", "v_r", "m", "log_p", "log1_p", "log_h"]
1221
- )
1222
-
1223
-
1224
- def _get_tr_params(n, p):
1225
- # See Table 1. Additionally, we pre-compute log(p), log1(-p) and the
1226
- # constant terms, that depend only on (n, p, m) in log(f(k)) (bottom of page 5).
1227
- mu = n * p
1228
- spq = jnp.sqrt(mu * (1 - p))
1229
- c = mu + 0.5
1230
- b = 1.15 + 2.53 * spq
1231
- a = -0.0873 + 0.0248 * b + 0.01 * p
1232
- alpha = (2.83 + 5.1 / b) * spq
1233
- u_r = 0.43
1234
- v_r = 0.92 - 4.2 / b
1235
- m = jnp.floor((n + 1) * p).astype(n.dtype)
1236
- log_p = jnp.log(p)
1237
- log1_p = jnp.log1p(-p)
1238
- log_h = ((m + 0.5) * (jnp.log((m + 1.0) / (n - m + 1.0)) + log1_p - log_p) +
1239
- _stirling_approx_tail(m) + _stirling_approx_tail(n - m))
1240
- return _tr_params(c, b, a, alpha, u_r, v_r, m, log_p, log1_p, log_h)
1241
-
1242
-
1243
- def _stirling_approx_tail(k):
1244
- precomputed = jnp.array([0.08106146679532726,
1245
- 0.04134069595540929,
1246
- 0.02767792568499834,
1247
- 0.02079067210376509,
1248
- 0.01664469118982119,
1249
- 0.01387612882307075,
1250
- 0.01189670994589177,
1251
- 0.01041126526197209,
1252
- 0.009255462182712733,
1253
- 0.008330563433362871],
1254
- dtype=environ.dftype())
1255
- kp1 = k + 1
1256
- kp1sq = (k + 1) ** 2
1257
- return jnp.where(k < 10,
1258
- precomputed[k],
1259
- (1.0 / 12 - (1.0 / 360 - (1.0 / 1260) / kp1sq) / kp1sq) / kp1)
1260
-
1261
-
1262
- def _binomial_btrs(key, p, n):
1263
- """
1264
- Based on the transformed rejection sampling algorithm (BTRS) from the
1265
- following reference:
1266
-
1267
- Hormann, "The Generation of Binonmial Random Variates"
1268
- (https://core.ac.uk/download/pdf/11007254.pdf)
1269
- """
1270
-
1271
- def _btrs_body_fn(val):
1272
- _, key, _, _ = val
1273
- key, key_u, key_v = jr.split(key, 3)
1274
- u = jr.uniform(key_u)
1275
- v = jr.uniform(key_v)
1276
- u = u - 0.5
1277
- k = jnp.floor(
1278
- (2 * tr_params.a / (0.5 - jnp.abs(u)) + tr_params.b) * u + tr_params.c
1279
- ).astype(n.dtype)
1280
- return k, key, u, v
1281
-
1282
- def _btrs_cond_fn(val):
1283
- def accept_fn(k, u, v):
1284
- # See acceptance condition in Step 3. (Page 3) of TRS algorithm
1285
- # v <= f(k) * g_grad(u) / alpha
1286
-
1287
- m = tr_params.m
1288
- log_p = tr_params.log_p
1289
- log1_p = tr_params.log1_p
1290
- # See: formula for log(f(k)) at bottom of Page 5.
1291
- log_f = (
1292
- (n + 1.0) * jnp.log((n - m + 1.0) / (n - k + 1.0))
1293
- + (k + 0.5) * (jnp.log((n - k + 1.0) / (k + 1.0)) + log_p - log1_p)
1294
- + (_stirling_approx_tail(k) - _stirling_approx_tail(n - k))
1295
- + tr_params.log_h
1296
- )
1297
- g = (tr_params.a / (0.5 - jnp.abs(u)) ** 2) + tr_params.b
1298
- return jnp.log((v * tr_params.alpha) / g) <= log_f
1299
-
1300
- k, key, u, v = val
1301
- early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r)
1302
- early_reject = (k < 0) | (k > n)
1303
- return lax.cond(
1304
- early_accept | early_reject,
1305
- (),
1306
- lambda _: ~early_accept,
1307
- (k, u, v),
1308
- lambda x: ~accept_fn(*x),
1309
- )
1310
-
1311
- tr_params = _get_tr_params(n, p)
1312
- ret = lax.while_loop(
1313
- _btrs_cond_fn, _btrs_body_fn, (-1, key, 1.0, 1.0)
1314
- ) # use k=-1 initially so that cond_fn returns True
1315
- return ret[0]
1316
-
1317
-
1318
- def _binomial_inversion(key, p, n):
1319
- def _binom_inv_body_fn(val):
1320
- i, key, geom_acc = val
1321
- key, key_u = jr.split(key)
1322
- u = jr.uniform(key_u)
1323
- geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1
1324
- geom_acc = geom_acc + geom
1325
- return i + 1, key, geom_acc
1326
-
1327
- def _binom_inv_cond_fn(val):
1328
- i, _, geom_acc = val
1329
- return geom_acc <= n
1330
-
1331
- log1_p = jnp.log1p(-p)
1332
- ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.0))
1333
- return ret[0]
1334
-
1335
-
1336
- def _binomial_dispatch(key, p, n):
1337
- def dispatch(key, p, n):
1338
- is_le_mid = p <= 0.5
1339
- pq = jnp.where(is_le_mid, p, 1 - p)
1340
- mu = n * pq
1341
- k = lax.cond(
1342
- mu < 10,
1343
- (key, pq, n),
1344
- lambda x: _binomial_inversion(*x),
1345
- (key, pq, n),
1346
- lambda x: _binomial_btrs(*x),
1347
- )
1348
- return jnp.where(is_le_mid, k, n - k)
1349
-
1350
- # Return 0 for nan `p` or negative `n`, since nan values are not allowed for integer types
1351
- cond0 = jnp.isfinite(p) & (n > 0) & (p > 0)
1352
- return lax.cond(
1353
- cond0 & (p < 1),
1354
- (key, p, n),
1355
- lambda x: dispatch(*x),
1356
- (),
1357
- lambda _: jnp.where(cond0, n, 0),
1358
- )
1359
-
1360
-
1361
- @partial(jit, static_argnums=(3,))
1362
- def _binomial(key, p, n, shape):
1363
- shape = shape or lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
1364
- # reshape to map over axis 0
1365
- p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
1366
- n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
1367
- key = jr.split(key, jnp.size(p))
1368
- if jax.default_backend() == "cpu":
1369
- ret = lax.map(lambda x: _binomial_dispatch(*x), (key, p, n))
1370
- else:
1371
- ret = vmap(lambda *x: _binomial_dispatch(*x))(key, p, n)
1372
- return jnp.reshape(ret, shape)
1373
-
1374
-
1375
1246
  @partial(jit, static_argnums=(2,))
1376
1247
  def _categorical(key, p, shape):
1377
1248
  # this implementation is fast when event shape is small, and slow otherwise
brainstate/surrogate.py CHANGED
@@ -19,9 +19,13 @@ from __future__ import annotations
19
19
  import jax
20
20
  import jax.numpy as jnp
21
21
  import jax.scipy as sci
22
- from jax.core import Primitive
23
22
  from jax.interpreters import batching, ad, mlir
24
23
 
24
+ if jax.__version_info__ < (0, 4, 38):
25
+ from jax.core import Primitive
26
+ else:
27
+ from jax.extend.core import Primitive
28
+
25
29
  __all__ = [
26
30
  'Surrogate',
27
31
  'Sigmoid',
@@ -27,8 +27,6 @@ from ._scaling import *
27
27
  from ._scaling import __all__ as _mem_scale_all
28
28
  from ._struct import *
29
29
  from ._struct import __all__ as _struct_all
30
- from ._visualization import *
31
- from ._visualization import __all__ as _visualization_all
32
30
 
33
31
  __all__ = (
34
32
  _others_all
@@ -38,7 +36,6 @@ __all__ = (
38
36
  + _struct_all
39
37
  + _error_all
40
38
  + _mapping_all
41
- + _visualization_all
42
39
  )
43
40
  del (
44
41
  _others_all,
@@ -48,5 +45,4 @@ del (
48
45
  _struct_all,
49
46
  _error_all,
50
47
  _mapping_all,
51
- _visualization_all,
52
48
  )
@@ -1,7 +1,7 @@
1
1
  # The file is adapted from the Flax library (https://github.com/google/flax).
2
2
  # The credit should go to the Flax authors.
3
3
  #
4
- # Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
4
+ # Copyright 2024 The Flax Authors.
5
5
  #
6
6
  # Licensed under the Apache License, Version 2.0 (the "License");
7
7
  # you may not use this file except in compliance with the License.
brainstate/util/_dict.py CHANGED
@@ -1,4 +1,7 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # The file is adapted from the Flax library (https://github.com/google/flax).
2
+ # The credit should go to the Flax authors.
3
+ #
4
+ # Copyright 2024 The Flax Authors.
2
5
  #
3
6
  # Licensed under the Apache License, Version 2.0 (the "License");
4
7
  # you may not use this file except in compliance with the License.
@@ -1,7 +1,7 @@
1
1
  # The file is adapted from the Flax library (https://github.com/google/flax).
2
2
  # The credit should go to the Flax authors.
3
3
  #
4
- # Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
4
+ # Copyright 2024 The Flax Authors.
5
5
  #
6
6
  # Licensed under the Apache License, Version 2.0 (the "License");
7
7
  # you may not use this file except in compliance with the License.
@@ -1,7 +1,7 @@
1
1
  # The file is adapted from the Flax library (https://github.com/google/flax).
2
2
  # The credit should go to the Flax authors.
3
3
  #
4
- # Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
4
+ # Copyright 2024 The Flax Authors.
5
5
  #
6
6
  # Licensed under the Apache License, Version 2.0 (the "License");
7
7
  # you may not use this file except in compliance with the License.
@@ -1,7 +1,7 @@
1
1
  # The file is adapted from the Flax library (https://github.com/google/flax).
2
2
  # The credit should go to the Flax authors.
3
3
  #
4
- # Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
4
+ # Copyright 2024 The Flax Authors.
5
5
  #
6
6
  # Licensed under the Apache License, Version 2.0 (the "License");
7
7
  # you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.1.0.post20250105
3
+ Version: 0.1.0.post20250120
4
4
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
5
  Home-page: https://github.com/chaobrain/brainstate
6
6
  Author: BrainState Developers
@@ -19,6 +19,7 @@ Classifier: Programming Language :: Python :: 3.9
19
19
  Classifier: Programming Language :: Python :: 3.10
20
20
  Classifier: Programming Language :: Python :: 3.11
21
21
  Classifier: Programming Language :: Python :: 3.12
22
+ Classifier: Programming Language :: Python :: 3.13
22
23
  Classifier: License :: OSI Approved :: Apache Software License
23
24
  Classifier: Programming Language :: Python
24
25
  Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
@@ -1,46 +1,45 @@
1
1
  brainstate/__init__.py,sha256=A-QKdOvSalsCMxgk80Iz6_xMiUin6con6JaONHfciSY,1526
2
- brainstate/_state.py,sha256=4aDpLyHGr1VlPXeLSfM3USQG5K4o7orF7IlaBdYrtfE,29098
3
- brainstate/_state_test.py,sha256=1boTp1w8DiCFLsPwNtlLrlIqGRpkasAmLid5bv2fgP4,2223
2
+ brainstate/_state.py,sha256=GZ46liHZSHbAHQEuELvOeoJ27P9xiZDz06G2AASjAjA,29142
3
+ brainstate/_state_test.py,sha256=rJUFRSXEqrrl4qANRewY9mnDlzSbtHwBIGeZ0ku-8Dg,1650
4
4
  brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
5
5
  brainstate/environ.py,sha256=PZnVFWPioUBuWmwCO8wwCKrHQfP3BR-5lYPRl5i5GDA,17698
6
6
  brainstate/environ_test.py,sha256=QD6sPCKNtqemVCGwkdImjMazatrvvLr6YeAVcfUnVVY,2045
7
7
  brainstate/mixin.py,sha256=g7uVUwZphZWsNs9pb48ozG2cDGaj0hs0g3lq8tDk-Sg,11310
8
8
  brainstate/mixin_test.py,sha256=Oq_0fwC9vpXDN4t4dTBhWzLdFDNlcYsrcip14F1yECI,3079
9
- brainstate/surrogate.py,sha256=YaY6RJ6kzpuPXWFjaWsxWt2MzJfdm5v_jeOR8V_jPoU,48369
9
+ brainstate/surrogate.py,sha256=t4SzVwUVMAPtC-O1vFbuE9F4265wgAv7ud77ufIJsuk,48464
10
10
  brainstate/transform.py,sha256=cxbymTlJ6uHvJWEEYXzFUkAySs_TbUTHakt0NQgWJ3s,808
11
11
  brainstate/typing.py,sha256=Qh-LBzm6oG4rSXv4V5qB8SNYcoOR7bASoK_iQxnlafk,10467
12
12
  brainstate/augment/__init__.py,sha256=BtXIBel7GbttmfBX6grxOxl0IiOJxLEa7qCGAXumamE,1286
13
13
  brainstate/augment/_autograd.py,sha256=o9ivoEY7BmtdM1XmzdMmeRXpj6Tvn5xNB8LSGp2HKC8,25238
14
14
  brainstate/augment/_autograd_test.py,sha256=S2eEgrwTzdSi3u2nKE3u37WSThosLwx1WCP9ptJAGKo,44060
15
- brainstate/augment/_eval_shape.py,sha256=dGlRVHOAZ9LSRZsFi1erxgEWHrnhBO3Kq3WW11-Hvng,3819
16
- brainstate/augment/_eval_shape_test.py,sha256=1nnxbU7hPRbZPQWNWbQ518pw-H7FGDKKnQpZGBY9uRI,1390
17
- brainstate/augment/_mapping.py,sha256=cpxzVGCEYnP5jPqrowYoPXciw_-QR2F3wggrRj1OCPc,21850
18
- brainstate/augment/_mapping_test.py,sha256=TEAecjZmTSDCfxARgrzcDJ2dW1Yz_sCITmFiA9FGrhk,9455
15
+ brainstate/augment/_eval_shape.py,sha256=ObCgsZ704kLduB1dbjJZh5nVQYEkLR5ebK74V5NV42k,3892
16
+ brainstate/augment/_eval_shape_test.py,sha256=LFOJx7CWltmRLXdGY175UebLwtEMz2CzJ_gLqMZsJTw,1393
17
+ brainstate/augment/_mapping.py,sha256=nU6Y7fSnYXyQSILXU2QT-O73Fm3pnwOmgUoDaHqjve8,21544
18
+ brainstate/augment/_mapping_test.py,sha256=_KFhE3CXItwpbZ1gJfrDu3yUtX0YbfPUuHJG_G_BXEs,8963
19
19
  brainstate/augment/_random.py,sha256=rkB4w4BkKsz9p8lTk31kVHvlVPJSvtGk8REn936KI_4,3071
20
20
  brainstate/compile/__init__.py,sha256=qZZIYoyEl51IFkFu-Hb-bP3PAEHo94HlTDf57P2ze08,1858
21
- brainstate/compile/_ad_checkpoint.py,sha256=5zJ1ENeTU4FzRY_uNpr85NhKfuicMMjcIbhu6-bSM4k,9451
21
+ brainstate/compile/_ad_checkpoint.py,sha256=K6I4vnznDsqC9cUeCnez9UdV9r_toGA3zHezoHLA6mI,9377
22
22
  brainstate/compile/_ad_checkpoint_test.py,sha256=R1I76nG4zIqb6g3M_VxWts7rUC1OHJCjtQhPkcbXodk,1746
23
23
  brainstate/compile/_conditions.py,sha256=gApsHKGQrf1QBjoKXDVL7VsoeJ2zFtSc-hFz9nbYcF0,10113
24
24
  brainstate/compile/_conditions_test.py,sha256=s9LF6h9LvigvgxUIugTqvgCHBIU8TXS1Ar1OlIxXfrw,8389
25
25
  brainstate/compile/_error_if.py,sha256=TFvhqITKkRO9m30GdlUP4eEjJvLWQUhjkujXO9zvrWs,2689
26
26
  brainstate/compile/_error_if_test.py,sha256=OdJG483IIdOrCHxtHd49OHfOxCSnSkk7GdAUOzSt8bE,2044
27
- brainstate/compile/_jit.py,sha256=3WBXNTALWPYC9rQH0TPH6w4bjG0BpnZt3RAzUQF5kkc,14045
27
+ brainstate/compile/_jit.py,sha256=itAWENKfJvnlaWl_uSy8lHTK8K1in89F_ZXXwp-EGRM,13944
28
28
  brainstate/compile/_jit_test.py,sha256=zD7kck9SQJGmUDolh9P4luKwQ21fBGje1Z4STTEXIuA,4135
29
- brainstate/compile/_loop_collect_return.py,sha256=XwMnKkMH0xTWB1f6GE4NQNK1R2GXTXCiVgulpkdIpc4,23308
29
+ brainstate/compile/_loop_collect_return.py,sha256=TrKBZhtQecTtuiVz_HOeyepde-znzjlyk0Te53-AvOE,23492
30
30
  brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
31
- brainstate/compile/_loop_no_collection.py,sha256=0i31gdQ7sI-d6pvnh08ttUUwdAtpx4uoYhGuf_CyL9s,7343
31
+ brainstate/compile/_loop_no_collection.py,sha256=qto2__Zt2PJntkjB9AXEgraGLvNUJS483BhCXjJyqv0,7495
32
32
  brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
33
- brainstate/compile/_make_jaxpr.py,sha256=S5O9KUB3bsxoKcfptlV0MRfKA__Ija37WxkakIRL3z0,33010
33
+ brainstate/compile/_make_jaxpr.py,sha256=DQf_80w3p0wi2Gb9P6_tLMJ0Oadgyr_jWkVjus0MSjw,33205
34
34
  brainstate/compile/_make_jaxpr_test.py,sha256=3XaX8LUuG6UjolcD83qDVo5odf8FCDppdr9Q6V0NBs4,4303
35
- brainstate/compile/_progress_bar.py,sha256=eInZPjiqzYE6PWxl_or_lBthDNDO0Ov60Uz0DbuBbZQ,4620
36
- brainstate/compile/_unvmap.py,sha256=0i-NvCLDAUe-effJIIEPVsK4WTPbCDBTgw6AqRvq7mE,4163
35
+ brainstate/compile/_progress_bar.py,sha256=0oVlZ4kW_ZMciJjOR_ebj3PNe_XkCMkoQpv-HUUdoF0,5554
36
+ brainstate/compile/_unvmap.py,sha256=EY4rbqCzzPOiaRwpWTiyBwb5dVkYFnacHhBZUZObxPI,4255
37
37
  brainstate/compile/_util.py,sha256=aCvkTV--g4NsqcodTdBAISt4EwgezCbKzNUV58n-Q_Y,6304
38
- brainstate/event/__init__.py,sha256=W0ZxbgrcFuYhWTl-GZ0UDoMGfsWmesvG4J_LbTBkex8,937
39
- brainstate/event/_csr.py,sha256=QDccbgXUklE2iq1w6WdyaFspXY1165uA9UlPltX16OU,30365
38
+ brainstate/event/__init__.py,sha256=gSEem-1oTHgy99Mjm3uumTXVd93tLVl0c4dUgRpoifk,895
39
+ brainstate/event/_csr.py,sha256=PYKw8CGNgQ24MxQDoeBZTrPuC7Z-GetXQld9KiTbNYw,40063
40
+ brainstate/event/_csr_benchmark.py,sha256=xrj2DSWzw0pUHAE1jRBeSRhMW7ogXvDHEdeaZGioNE4,702
40
41
  brainstate/event/_csr_mv.py,sha256=HStHvK3KyEMfLsIUslZjgbdU6OsD1yKGrzQOzBXG36M,10266
41
- brainstate/event/_csr_mv_benchmark.py,sha256=xrj2DSWzw0pUHAE1jRBeSRhMW7ogXvDHEdeaZGioNE4,702
42
- brainstate/event/_csr_mv_test.py,sha256=WQfAvp_3UeCUGAZjr3_aqQvrB-eYZcFEN4v1PBe9fUQ,4012
43
- brainstate/event/_csr_test.py,sha256=v59rnwTy8jrvqjdGzN75kvLg0wLBmRbthaVRKY2f0Uw,2945
42
+ brainstate/event/_csr_test.py,sha256=_iXwUFq90GU7npVOUnlI4NA27RJ8zyCZBxe7NDH803o,9533
44
43
  brainstate/event/_fixedprob_mv.py,sha256=nR3lhd87t1Vge435QHnFuDp-UBbWoW0Qk1kbsjRHQyc,25541
45
44
  brainstate/event/_fixedprob_mv_benchmark.py,sha256=_F_8fH5MNMJZHeSqnq9DYMI9OgYr6JIxBKjbsgeWRv4,4720
46
45
  brainstate/event/_fixedprob_mv_test.py,sha256=pVEarvGbqTjnAbxgMVRTAhkyYbvDnlyCJdeOdDD927w,4283
@@ -48,7 +47,7 @@ brainstate/event/_linear_mv.py,sha256=O5qbY31GNV1qEDrZ5kvPbA8Ae-bY5JpUgGtqDFNAeV
48
47
  brainstate/event/_linear_mv_benckmark.py,sha256=hu0WqYMIa3jMoH7Fq9dgxcBjjXGFhghPx9vztyCo1KY,2411
49
48
  brainstate/event/_linear_mv_test.py,sha256=V9w41ZP2vu95CyCdCkm-j9Eftqs2kqmeBY809N1-syY,3736
50
49
  brainstate/event/_misc.py,sha256=8IpPooXjF2m0-tuo3pGHqThq2yLSNmYziy_zdurZ3NI,1040
51
- brainstate/event/_xla_custom_op.py,sha256=f0OrO6CjsJOUNUCRPpHIRmsb_wgNEym0xBl1tcz8ij4,11016
50
+ brainstate/event/_xla_custom_op.py,sha256=wF_nKgLUv1IGd8OY89MYqIvyZITl8UcrVysJWFugJxY,11093
52
51
  brainstate/event/_xla_custom_op_test.py,sha256=rnkGMleXzLfJj4y5QqwfBvCCLTAHe_uabwBDniY-URM,1745
53
52
  brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
54
53
  brainstate/functional/_activations.py,sha256=S0Ok7sq5FTbmJWSejpOCHo1jpKX0gYOLy_TO2IUXM8s,21726
@@ -56,14 +55,11 @@ brainstate/functional/_activations_test.py,sha256=T___RlSrIfXwlkw8dg5A9EZMTZGDzv
56
55
  brainstate/functional/_normalization.py,sha256=i2EV7hSsqcNdcYRX2wAxjq8doHwyN9eNJTGTaPt03xE,2605
57
56
  brainstate/functional/_others.py,sha256=_u_Ys-LiLzDAP4zJggVwaVvirgoS3jvhXMREoS6JOkM,1737
58
57
  brainstate/functional/_spikes.py,sha256=QY-2ayJkgkGELcq-bftPEaf_hJptVf_SP3fY36QvlZc,2678
59
- brainstate/graph/__init__.py,sha256=ZUD-gY9txkr0yQ7VRqRSptzGMX3J_ZZ_VsJkxp8EfpY,1334
60
- brainstate/graph/_graph_context.py,sha256=J3WmCPrNyYiv-i75QSWQQdg0qZ6jmRy3gIzz3SuDBLI,15429
61
- brainstate/graph/_graph_context_test.py,sha256=IYpjqbXwSFF65XL0ZbdPeC1jYyEHLpQVrhuFeJXH4GM,2409
62
- brainstate/graph/_graph_convert.py,sha256=llSREtGQrIggkD0wmxUbYKuSveLW4ihDZME6Ab-mRTQ,9147
63
- brainstate/graph/_graph_node.py,sha256=mmZ0jhZev8ReNJhVLgWqYJEedEDtJHxhwxRv4ytQVNo,9268
58
+ brainstate/graph/__init__.py,sha256=fyvQMlAUY3QYTzvDzz5TDoWS2XQwZ6P3ic6BtysZyHM,1026
59
+ brainstate/graph/_graph_node.py,sha256=swAokZLKswSTaq2WEhyLIs38sy_67C6maHI6T3e1hvY,8339
64
60
  brainstate/graph/_graph_node_test.py,sha256=BFGfdzZFDHI0XK7hHotSVWKt3em1taGvn8FHF9NCXx8,2702
65
- brainstate/graph/_graph_operation.py,sha256=xM-pvYMLHb8e2lAmDVgTTThAf5gd4hHymfrqOJwpHeo,64132
66
- brainstate/graph/_graph_operation_test.py,sha256=ADyyuMk2xidEkkFNpGvUbvEtRmUj-tqOI4cF3eRuakM,24678
61
+ brainstate/graph/_graph_operation.py,sha256=cIwGo3ICgtce2fmdn917r81evMFjJIKeW9doaQK4DD8,64111
62
+ brainstate/graph/_graph_operation_test.py,sha256=zjvpKjQAFWtw8YZuqOk_jmlZNb_-E8oPyNx57dyc8jI,18556
67
63
  brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
68
64
  brainstate/init/_base.py,sha256=B_NLS9aKNrvuj5NAlSgBbQTVev7IRvzcx8vH0J-Gq2w,1671
69
65
  brainstate/init/_generic.py,sha256=sGOvd_atpxLWqqZKobTfAiMiYRnDC19PBNHdQy_igFM,8028
@@ -83,7 +79,7 @@ brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=cTbIn41EPYG0h3ICzKBXxpgB6wwA2
83
79
  brainstate/nn/_dyn_impl/_dynamics_neuron_test.py,sha256=Tfzrzu7udGrLJGnqItiLWe5WT0dgduvYOgzGCnaPJQg,6317
84
80
  brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=MsbPyaiDyjungyzuK2b3exRGaMpZgmsmmNHNLjgxQKw,15269
85
81
  brainstate/nn/_dyn_impl/_dynamics_synapse_test.py,sha256=t5i-HV0ii9sUNzWTEv04o26QVtQ-mCdMJcFq2MD755A,4981
86
- brainstate/nn/_dyn_impl/_inputs.py,sha256=pkcAVt_o5kQF_BGCTZZ-NUQpHgjlFHHPwtYC0fJkAA0,9099
82
+ brainstate/nn/_dyn_impl/_inputs.py,sha256=UNoGxKIKXwPnhelljDowqAWlV6ds7aBBkEbvdy2oDI4,11302
87
83
  brainstate/nn/_dyn_impl/_projection_alignpost.py,sha256=PNC1Tzx_SF2DHAHeJCufXzO_Q4qLoBpWABI45B3GRuc,876
88
84
  brainstate/nn/_dyn_impl/_rate_rnns.py,sha256=dz_yT_6hJVhKulcjIARbGtmMzZqISws96CtBc6o5GOo,14768
89
85
  brainstate/nn/_dyn_impl/_rate_rnns_test.py,sha256=gNgtr-a4ZiU1XF9wFG1HiJ9fLosfWchVR9Zn1x39xt4,2452
@@ -102,10 +98,10 @@ brainstate/nn/_elementwise/_dropout_test.py,sha256=ZzNvjFf46NpKWGBIcT6O0lKOBGpxO
102
98
  brainstate/nn/_elementwise/_elementwise.py,sha256=om-KpwDTk5yFG5KBYXXHquRLV7s28_FJjk-omvyMyvQ,33342
103
99
  brainstate/nn/_elementwise/_elementwise_test.py,sha256=SZI9jB39sZ5SO1dpWGW-PhodthwN0GU9FY1nqf2fWcs,5341
104
100
  brainstate/nn/_interaction/__init__.py,sha256=TTY_SeNrdx4VnUSw6vdyl02OHdS9Qs15cWBp6kjsyNQ,1289
105
- brainstate/nn/_interaction/_conv.py,sha256=LgWO4TeKRru07UEUga3YX6xog6WHtOvKdKtgxGnHUvw,18512
101
+ brainstate/nn/_interaction/_conv.py,sha256=lwyxTVsJVPiKlZcgB6iqE64aX7AOJzplDSj4y6-m18o,18592
106
102
  brainstate/nn/_interaction/_conv_test.py,sha256=fHXRFYnDghFiKre63RqMwIE_gbPKdK34UPhKOz-J3qU,8695
107
103
  brainstate/nn/_interaction/_embedding.py,sha256=iK0I1ExKWFa_QzV9UDGj32Ljsmdr1g_LlAtMcusebxU,2187
108
- brainstate/nn/_interaction/_linear.py,sha256=0G3OOmbgBwB22JAb2AWr4btvBOrFHBi4xWYKkd0SpOk,12148
104
+ brainstate/nn/_interaction/_linear.py,sha256=EnkOk1oE79rvRIjU6HBllxUpVOEcQQCj4vtavo9AJjI,14767
109
105
  brainstate/nn/_interaction/_linear_test.py,sha256=QfCR8SBBed9OnSY-AmQ0kJqoggDA3Xem0dRJ0BusxLU,2872
110
106
  brainstate/nn/_interaction/_normalizations.py,sha256=7YDzkmO_iqd70fH_wawb60Bu8eGOdvZq23emP-b68Hc,37440
111
107
  brainstate/nn/_interaction/_normalizations_test.py,sha256=2p1Jf8nA999VYGWbvOZfKYlKk6UmL0vaEB76xkXxkXw,2438
@@ -119,26 +115,24 @@ brainstate/optim/_optax_optimizer.py,sha256=SuXV_xUBfhOw1_C2J5TIpy3dXDtI9VJFaSML
119
115
  brainstate/optim/_optax_optimizer_test.py,sha256=DAomE8Eu3dn4gh1S3EZ_u4pW4rhcl16vWPbnDcN3Rs4,1762
120
116
  brainstate/optim/_sgd_optimizer.py,sha256=NVKYhGcw2D1ksNWUIXZcj-74LUaan8XL3EERk-EHMRI,46008
121
117
  brainstate/random/__init__.py,sha256=c5q-RC3grRIjx-HBb2IhKZpi_xzbFmUUxzRAzqfREic,1045
122
- brainstate/random/_rand_funs.py,sha256=y3K39RZLoMAPv3mllnZNLGYTI30EHg6fFzlxAWOzvt0,137597
118
+ brainstate/random/_rand_funs.py,sha256=WaelvEpeQb6Vuqt4eNgsAtd7GI8BqgEdVYbXgtCOd54,137682
123
119
  brainstate/random/_rand_funs_test.py,sha256=abO5lSoPBgBcg6ecFE1qnCg98__QGa68GSYC5pQW5QI,19438
124
- brainstate/random/_rand_seed.py,sha256=dfq-_vb-4YEEWL3Bkcm_VaVkxU2bkuOsIs3NlZy4BeE,5480
120
+ brainstate/random/_rand_seed.py,sha256=MHA9znbdJW9ujx73onDRrAOI684_0FmGfqczBsSXYQg,5985
125
121
  brainstate/random/_rand_seed_test.py,sha256=Qibcs-ZqCvj1LuucmQ8H00B_HBNhf2f6un0aUdNZNTw,1518
126
- brainstate/random/_rand_state.py,sha256=EDTg4QCM1HxaLd_2F31mu-qTlQR5NFPDiKvI-QhgnOQ,59444
122
+ brainstate/random/_rand_state.py,sha256=nuoQ8GU1MfJPRNN-ZmRQsggVjoyPhaEdZmwM7_4-Q3c,55206
127
123
  brainstate/random/_random_for_unit.py,sha256=kGp4EUX19MXJ9Govoivbg8N0bddqOldKEI2h_TbdONY,2057
128
- brainstate/util/__init__.py,sha256=EpAGQiukTUIVJWCfp106-aU1-jo7cIsCBBlvweXmwYY,1643
129
- brainstate/util/_caller.py,sha256=88-vWwIjOvNBIHkZSrBKfImeCJbyIPJqZ16pceC83yQ,2803
130
- brainstate/util/_dict.py,sha256=jv4QuHGNY3mgvmyS_2j6-HSex1lwcGlIjgaCACYpCtg,26077
124
+ brainstate/util/__init__.py,sha256=-FWEuSKXG3mWxYphGFAy3UEuVe39lFs1GruluzdXDoI,1502
125
+ brainstate/util/_caller.py,sha256=T3bzu7-09r-6EOrU6Muca_aMXSQua_X2lXjEqb-w39w,2782
126
+ brainstate/util/_dict.py,sha256=Yapug-_RZQYjvd8cZ3v90_MX7rUYJDBzBnZJT6a0NXY,26178
131
127
  brainstate/util/_dict_test.py,sha256=Dn0TdjX6wLBXaTD4jfYTu6cKfFHwKSxi4_3bX7kB_IA,5621
132
128
  brainstate/util/_error.py,sha256=eyZ8PGFixqe2K5OEfjSDzI-2tU0ieYQoUpBP7yStlPQ,878
133
- brainstate/util/_filter.py,sha256=Jy_lUu1EC4HWgbJFkY6nhsTRc3xqONwZq36qopl7wRQ,4991
129
+ brainstate/util/_filter.py,sha256=1-bvFHdjeehvXeHTrCEp8xr25lopKe8d3XZGCNegq0s,4970
134
130
  brainstate/util/_others.py,sha256=jsPZwP-v_5HRV-LB5F0NUsiqr04y8bmGIsu_JMyVcbQ,14762
135
- brainstate/util/_pretty_repr.py,sha256=NYEBCo2iz9Potx-IR7uZZzt2aLQW_94vH79fGusiC2A,5737
131
+ brainstate/util/_pretty_repr.py,sha256=bDpU4gbkS4B8cXBkiN8kBQNmruxiJzDRF-eIqzyeYnM,5716
136
132
  brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
137
- brainstate/util/_struct.py,sha256=0exv0oOiSt1hmx20Y4J2-pCGtCTx13WcAlEYSBkyung,17640
138
- brainstate/util/_tracers.py,sha256=0r5T4nhxMzI79NtqroqitsdMT4YfpgV5RdYJLS5uJ0w,2285
139
- brainstate/util/_visualization.py,sha256=n4ZVz10z7VBqA0cKO6vyHwEMprWJgPeEqtITzDMai2Y,1519
140
- brainstate-0.1.0.post20250105.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
141
- brainstate-0.1.0.post20250105.dist-info/METADATA,sha256=Xec1GNBlHcignyvym-EHzU-JIOUuo3T-IUU2LoCO0sk,3533
142
- brainstate-0.1.0.post20250105.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
143
- brainstate-0.1.0.post20250105.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
144
- brainstate-0.1.0.post20250105.dist-info/RECORD,,
133
+ brainstate/util/_struct.py,sha256=KMMHcshOM20gYhSahNzWLxsTt-Rt3AeX3Uz26-rP9vI,17619
134
+ brainstate-0.1.0.post20250120.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
135
+ brainstate-0.1.0.post20250120.dist-info/METADATA,sha256=vUyr4XjiyAW68waFKMray9EEFHTqjqRp5GlqAG8LsKY,3585
136
+ brainstate-0.1.0.post20250120.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
137
+ brainstate-0.1.0.post20250120.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
138
+ brainstate-0.1.0.post20250120.dist-info/RECORD,,