brainstate 0.1.0.post20250105__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +77 -44
- brainstate/_state_test.py +0 -17
- brainstate/augment/_eval_shape.py +9 -10
- brainstate/augment/_eval_shape_test.py +1 -1
- brainstate/augment/_mapping.py +265 -277
- brainstate/augment/_mapping_test.py +147 -175
- brainstate/compile/_ad_checkpoint.py +6 -4
- brainstate/compile/_jit.py +37 -28
- brainstate/compile/_loop_collect_return.py +6 -3
- brainstate/compile/_loop_no_collection.py +2 -0
- brainstate/compile/_make_jaxpr.py +7 -3
- brainstate/compile/_progress_bar.py +68 -40
- brainstate/compile/_unvmap.py +6 -3
- brainstate/event/__init__.py +0 -2
- brainstate/event/_csr.py +266 -23
- brainstate/event/_csr_test.py +187 -0
- brainstate/event/_xla_custom_op.py +7 -3
- brainstate/graph/__init__.py +8 -12
- brainstate/graph/_graph_node.py +1 -23
- brainstate/graph/_graph_operation.py +1 -1
- brainstate/graph/_graph_operation_test.py +0 -159
- brainstate/nn/_dyn_impl/_inputs.py +124 -39
- brainstate/nn/_interaction/_conv.py +4 -2
- brainstate/nn/_interaction/_linear.py +84 -10
- brainstate/random/_rand_funs.py +9 -2
- brainstate/random/_rand_seed.py +12 -2
- brainstate/random/_rand_state.py +50 -179
- brainstate/surrogate.py +5 -1
- brainstate/util/__init__.py +0 -4
- brainstate/util/_caller.py +1 -1
- brainstate/util/_dict.py +4 -1
- brainstate/util/_filter.py +1 -1
- brainstate/util/_pretty_repr.py +1 -1
- brainstate/util/_struct.py +1 -1
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +40 -46
- brainstate/event/_csr_mv_test.py +0 -118
- brainstate/graph/_graph_context.py +0 -443
- brainstate/graph/_graph_context_test.py +0 -65
- brainstate/graph/_graph_convert.py +0 -246
- brainstate/util/_tracers.py +0 -68
- brainstate/util/_visualization.py +0 -47
- /brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
brainstate/random/_rand_state.py
CHANGED
@@ -16,7 +16,6 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
|
-
from collections import namedtuple
|
20
19
|
from functools import partial
|
21
20
|
from operator import index
|
22
21
|
from typing import Optional
|
@@ -37,6 +36,8 @@ from ._random_for_unit import uniform_for_unit, permutation_for_unit
|
|
37
36
|
|
38
37
|
__all__ = ['RandomState', 'DEFAULT', ]
|
39
38
|
|
39
|
+
use_prng_key = True
|
40
|
+
|
40
41
|
|
41
42
|
class RandomState(State):
|
42
43
|
"""RandomState that track the random generator state. """
|
@@ -56,12 +57,15 @@ class RandomState(State):
|
|
56
57
|
if seed_or_key is None:
|
57
58
|
seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
|
58
59
|
if isinstance(seed_or_key, int):
|
59
|
-
key = jr.PRNGKey(seed_or_key)
|
60
|
+
key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
|
60
61
|
else:
|
61
|
-
if
|
62
|
-
|
63
|
-
|
64
|
-
|
62
|
+
if jnp.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
|
63
|
+
key = seed_or_key
|
64
|
+
else:
|
65
|
+
if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
|
66
|
+
raise ValueError('key must be an array with dtype uint32. '
|
67
|
+
f'But we got {seed_or_key}')
|
68
|
+
key = seed_or_key
|
65
69
|
super().__init__(key)
|
66
70
|
|
67
71
|
self._backup = None
|
@@ -70,6 +74,9 @@ class RandomState(State):
|
|
70
74
|
return f'{self.__class__.__name__}({self.value})'
|
71
75
|
|
72
76
|
def check_if_deleted(self):
|
77
|
+
if not use_prng_key and isinstance(self._value, np.ndarray):
|
78
|
+
self._value = jr.key(np.random.randint(0, 10000))
|
79
|
+
|
73
80
|
if (
|
74
81
|
isinstance(self._value, jax.Array) and
|
75
82
|
not isinstance(self._value, jax.core.Tracer) and
|
@@ -111,12 +118,19 @@ class RandomState(State):
|
|
111
118
|
if seed_or_key is None:
|
112
119
|
seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
|
113
120
|
if np.size(seed_or_key) == 1:
|
114
|
-
|
121
|
+
if isinstance(seed_or_key, int):
|
122
|
+
key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
|
123
|
+
elif jnp.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
|
124
|
+
key = seed_or_key
|
125
|
+
elif isinstance(seed_or_key, (jnp.ndarray, np.ndarray)) and jnp.issubdtype(seed_or_key.dtype, jnp.integer):
|
126
|
+
key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
|
127
|
+
else:
|
128
|
+
raise ValueError(f'Invalid seed_or_key: {seed_or_key}')
|
115
129
|
else:
|
116
|
-
if len(seed_or_key)
|
117
|
-
|
118
|
-
|
119
|
-
|
130
|
+
if len(seed_or_key) == 2 and seed_or_key.dtype == np.uint32:
|
131
|
+
key = seed_or_key
|
132
|
+
else:
|
133
|
+
raise ValueError(f'Invalid seed_or_key: {seed_or_key}')
|
120
134
|
self.value = key
|
121
135
|
|
122
136
|
def split_key(self, n: Optional[int] = None, backup: bool = False) -> SeedOrKey:
|
@@ -560,15 +574,15 @@ class RandomState(State):
|
|
560
574
|
)
|
561
575
|
return out if unit.is_unitless else u.Quantity(out, unit=unit)
|
562
576
|
|
563
|
-
def _check_p(self,
|
564
|
-
raise ValueError(
|
577
|
+
def _check_p(self, *args, **kwargs):
|
578
|
+
raise ValueError('Parameter p should be within [0, 1], but we got {p}')
|
565
579
|
|
566
580
|
def bernoulli(self,
|
567
581
|
p,
|
568
582
|
size: Optional[Size] = None,
|
569
583
|
key: Optional[SeedOrKey] = None):
|
570
584
|
p = _check_py_seq(p)
|
571
|
-
jit_error_if(jnp.any(jnp.
|
585
|
+
jit_error_if(jnp.any(jnp.logical_or(p < 0, p > 1)), self._check_p, p=p)
|
572
586
|
if size is None:
|
573
587
|
size = jnp.shape(p)
|
574
588
|
key = self.split_key() if key is None else _formalize_key(key)
|
@@ -603,19 +617,27 @@ class RandomState(State):
|
|
603
617
|
samples = jnp.exp(samples)
|
604
618
|
return samples if unit.is_unitless else u.Quantity(samples, unit=unit)
|
605
619
|
|
606
|
-
def binomial(
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
620
|
+
def binomial(
|
621
|
+
self,
|
622
|
+
n,
|
623
|
+
p,
|
624
|
+
size: Optional[Size] = None,
|
625
|
+
key: Optional[SeedOrKey] = None,
|
626
|
+
dtype: DTypeLike = None,
|
627
|
+
check_valid: bool = True,
|
628
|
+
):
|
612
629
|
n = _check_py_seq(n)
|
613
630
|
p = _check_py_seq(p)
|
614
|
-
|
631
|
+
if check_valid:
|
632
|
+
jit_error_if(
|
633
|
+
jnp.any(jnp.logical_or(p < 0, p > 1)),
|
634
|
+
'Parameter p should be within [0, 1], but we got {p}',
|
635
|
+
p=p
|
636
|
+
)
|
615
637
|
if size is None:
|
616
638
|
size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p))
|
617
639
|
key = self.split_key() if key is None else _formalize_key(key)
|
618
|
-
r =
|
640
|
+
r = jr.binomial(key, n, p, shape=_size2shape(size))
|
619
641
|
dtype = dtype or environ.ditype()
|
620
642
|
return jnp.asarray(r, dtype=dtype)
|
621
643
|
|
@@ -1142,8 +1164,13 @@ DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
|
|
1142
1164
|
|
1143
1165
|
def _formalize_key(key):
|
1144
1166
|
if isinstance(key, int):
|
1145
|
-
return jr.PRNGKey(key)
|
1167
|
+
return jr.PRNGKey(key) if use_prng_key else jr.key(key)
|
1146
1168
|
elif isinstance(key, (jax.Array, np.ndarray)):
|
1169
|
+
if jnp.issubdtype(key.dtype, jax.dtypes.prng_key):
|
1170
|
+
return key
|
1171
|
+
if key.size == 1 and jnp.issubdtype(key.dtype, jnp.integer):
|
1172
|
+
return jr.PRNGKey(key) if use_prng_key else jr.key(key)
|
1173
|
+
|
1147
1174
|
if key.dtype != jnp.uint32:
|
1148
1175
|
raise TypeError('key must be a int or an array with two uint32.')
|
1149
1176
|
if key.size != 2:
|
@@ -1216,162 +1243,6 @@ def _const(example, val):
|
|
1216
1243
|
return np.array(val, dtype)
|
1217
1244
|
|
1218
1245
|
|
1219
|
-
_tr_params = namedtuple(
|
1220
|
-
"tr_params", ["c", "b", "a", "alpha", "u_r", "v_r", "m", "log_p", "log1_p", "log_h"]
|
1221
|
-
)
|
1222
|
-
|
1223
|
-
|
1224
|
-
def _get_tr_params(n, p):
|
1225
|
-
# See Table 1. Additionally, we pre-compute log(p), log1(-p) and the
|
1226
|
-
# constant terms, that depend only on (n, p, m) in log(f(k)) (bottom of page 5).
|
1227
|
-
mu = n * p
|
1228
|
-
spq = jnp.sqrt(mu * (1 - p))
|
1229
|
-
c = mu + 0.5
|
1230
|
-
b = 1.15 + 2.53 * spq
|
1231
|
-
a = -0.0873 + 0.0248 * b + 0.01 * p
|
1232
|
-
alpha = (2.83 + 5.1 / b) * spq
|
1233
|
-
u_r = 0.43
|
1234
|
-
v_r = 0.92 - 4.2 / b
|
1235
|
-
m = jnp.floor((n + 1) * p).astype(n.dtype)
|
1236
|
-
log_p = jnp.log(p)
|
1237
|
-
log1_p = jnp.log1p(-p)
|
1238
|
-
log_h = ((m + 0.5) * (jnp.log((m + 1.0) / (n - m + 1.0)) + log1_p - log_p) +
|
1239
|
-
_stirling_approx_tail(m) + _stirling_approx_tail(n - m))
|
1240
|
-
return _tr_params(c, b, a, alpha, u_r, v_r, m, log_p, log1_p, log_h)
|
1241
|
-
|
1242
|
-
|
1243
|
-
def _stirling_approx_tail(k):
|
1244
|
-
precomputed = jnp.array([0.08106146679532726,
|
1245
|
-
0.04134069595540929,
|
1246
|
-
0.02767792568499834,
|
1247
|
-
0.02079067210376509,
|
1248
|
-
0.01664469118982119,
|
1249
|
-
0.01387612882307075,
|
1250
|
-
0.01189670994589177,
|
1251
|
-
0.01041126526197209,
|
1252
|
-
0.009255462182712733,
|
1253
|
-
0.008330563433362871],
|
1254
|
-
dtype=environ.dftype())
|
1255
|
-
kp1 = k + 1
|
1256
|
-
kp1sq = (k + 1) ** 2
|
1257
|
-
return jnp.where(k < 10,
|
1258
|
-
precomputed[k],
|
1259
|
-
(1.0 / 12 - (1.0 / 360 - (1.0 / 1260) / kp1sq) / kp1sq) / kp1)
|
1260
|
-
|
1261
|
-
|
1262
|
-
def _binomial_btrs(key, p, n):
|
1263
|
-
"""
|
1264
|
-
Based on the transformed rejection sampling algorithm (BTRS) from the
|
1265
|
-
following reference:
|
1266
|
-
|
1267
|
-
Hormann, "The Generation of Binonmial Random Variates"
|
1268
|
-
(https://core.ac.uk/download/pdf/11007254.pdf)
|
1269
|
-
"""
|
1270
|
-
|
1271
|
-
def _btrs_body_fn(val):
|
1272
|
-
_, key, _, _ = val
|
1273
|
-
key, key_u, key_v = jr.split(key, 3)
|
1274
|
-
u = jr.uniform(key_u)
|
1275
|
-
v = jr.uniform(key_v)
|
1276
|
-
u = u - 0.5
|
1277
|
-
k = jnp.floor(
|
1278
|
-
(2 * tr_params.a / (0.5 - jnp.abs(u)) + tr_params.b) * u + tr_params.c
|
1279
|
-
).astype(n.dtype)
|
1280
|
-
return k, key, u, v
|
1281
|
-
|
1282
|
-
def _btrs_cond_fn(val):
|
1283
|
-
def accept_fn(k, u, v):
|
1284
|
-
# See acceptance condition in Step 3. (Page 3) of TRS algorithm
|
1285
|
-
# v <= f(k) * g_grad(u) / alpha
|
1286
|
-
|
1287
|
-
m = tr_params.m
|
1288
|
-
log_p = tr_params.log_p
|
1289
|
-
log1_p = tr_params.log1_p
|
1290
|
-
# See: formula for log(f(k)) at bottom of Page 5.
|
1291
|
-
log_f = (
|
1292
|
-
(n + 1.0) * jnp.log((n - m + 1.0) / (n - k + 1.0))
|
1293
|
-
+ (k + 0.5) * (jnp.log((n - k + 1.0) / (k + 1.0)) + log_p - log1_p)
|
1294
|
-
+ (_stirling_approx_tail(k) - _stirling_approx_tail(n - k))
|
1295
|
-
+ tr_params.log_h
|
1296
|
-
)
|
1297
|
-
g = (tr_params.a / (0.5 - jnp.abs(u)) ** 2) + tr_params.b
|
1298
|
-
return jnp.log((v * tr_params.alpha) / g) <= log_f
|
1299
|
-
|
1300
|
-
k, key, u, v = val
|
1301
|
-
early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r)
|
1302
|
-
early_reject = (k < 0) | (k > n)
|
1303
|
-
return lax.cond(
|
1304
|
-
early_accept | early_reject,
|
1305
|
-
(),
|
1306
|
-
lambda _: ~early_accept,
|
1307
|
-
(k, u, v),
|
1308
|
-
lambda x: ~accept_fn(*x),
|
1309
|
-
)
|
1310
|
-
|
1311
|
-
tr_params = _get_tr_params(n, p)
|
1312
|
-
ret = lax.while_loop(
|
1313
|
-
_btrs_cond_fn, _btrs_body_fn, (-1, key, 1.0, 1.0)
|
1314
|
-
) # use k=-1 initially so that cond_fn returns True
|
1315
|
-
return ret[0]
|
1316
|
-
|
1317
|
-
|
1318
|
-
def _binomial_inversion(key, p, n):
|
1319
|
-
def _binom_inv_body_fn(val):
|
1320
|
-
i, key, geom_acc = val
|
1321
|
-
key, key_u = jr.split(key)
|
1322
|
-
u = jr.uniform(key_u)
|
1323
|
-
geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1
|
1324
|
-
geom_acc = geom_acc + geom
|
1325
|
-
return i + 1, key, geom_acc
|
1326
|
-
|
1327
|
-
def _binom_inv_cond_fn(val):
|
1328
|
-
i, _, geom_acc = val
|
1329
|
-
return geom_acc <= n
|
1330
|
-
|
1331
|
-
log1_p = jnp.log1p(-p)
|
1332
|
-
ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.0))
|
1333
|
-
return ret[0]
|
1334
|
-
|
1335
|
-
|
1336
|
-
def _binomial_dispatch(key, p, n):
|
1337
|
-
def dispatch(key, p, n):
|
1338
|
-
is_le_mid = p <= 0.5
|
1339
|
-
pq = jnp.where(is_le_mid, p, 1 - p)
|
1340
|
-
mu = n * pq
|
1341
|
-
k = lax.cond(
|
1342
|
-
mu < 10,
|
1343
|
-
(key, pq, n),
|
1344
|
-
lambda x: _binomial_inversion(*x),
|
1345
|
-
(key, pq, n),
|
1346
|
-
lambda x: _binomial_btrs(*x),
|
1347
|
-
)
|
1348
|
-
return jnp.where(is_le_mid, k, n - k)
|
1349
|
-
|
1350
|
-
# Return 0 for nan `p` or negative `n`, since nan values are not allowed for integer types
|
1351
|
-
cond0 = jnp.isfinite(p) & (n > 0) & (p > 0)
|
1352
|
-
return lax.cond(
|
1353
|
-
cond0 & (p < 1),
|
1354
|
-
(key, p, n),
|
1355
|
-
lambda x: dispatch(*x),
|
1356
|
-
(),
|
1357
|
-
lambda _: jnp.where(cond0, n, 0),
|
1358
|
-
)
|
1359
|
-
|
1360
|
-
|
1361
|
-
@partial(jit, static_argnums=(3,))
|
1362
|
-
def _binomial(key, p, n, shape):
|
1363
|
-
shape = shape or lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
|
1364
|
-
# reshape to map over axis 0
|
1365
|
-
p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
|
1366
|
-
n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
|
1367
|
-
key = jr.split(key, jnp.size(p))
|
1368
|
-
if jax.default_backend() == "cpu":
|
1369
|
-
ret = lax.map(lambda x: _binomial_dispatch(*x), (key, p, n))
|
1370
|
-
else:
|
1371
|
-
ret = vmap(lambda *x: _binomial_dispatch(*x))(key, p, n)
|
1372
|
-
return jnp.reshape(ret, shape)
|
1373
|
-
|
1374
|
-
|
1375
1246
|
@partial(jit, static_argnums=(2,))
|
1376
1247
|
def _categorical(key, p, shape):
|
1377
1248
|
# this implementation is fast when event shape is small, and slow otherwise
|
brainstate/surrogate.py
CHANGED
@@ -19,9 +19,13 @@ from __future__ import annotations
|
|
19
19
|
import jax
|
20
20
|
import jax.numpy as jnp
|
21
21
|
import jax.scipy as sci
|
22
|
-
from jax.core import Primitive
|
23
22
|
from jax.interpreters import batching, ad, mlir
|
24
23
|
|
24
|
+
if jax.__version_info__ < (0, 4, 38):
|
25
|
+
from jax.core import Primitive
|
26
|
+
else:
|
27
|
+
from jax.extend.core import Primitive
|
28
|
+
|
25
29
|
__all__ = [
|
26
30
|
'Surrogate',
|
27
31
|
'Sigmoid',
|
brainstate/util/__init__.py
CHANGED
@@ -27,8 +27,6 @@ from ._scaling import *
|
|
27
27
|
from ._scaling import __all__ as _mem_scale_all
|
28
28
|
from ._struct import *
|
29
29
|
from ._struct import __all__ as _struct_all
|
30
|
-
from ._visualization import *
|
31
|
-
from ._visualization import __all__ as _visualization_all
|
32
30
|
|
33
31
|
__all__ = (
|
34
32
|
_others_all
|
@@ -38,7 +36,6 @@ __all__ = (
|
|
38
36
|
+ _struct_all
|
39
37
|
+ _error_all
|
40
38
|
+ _mapping_all
|
41
|
-
+ _visualization_all
|
42
39
|
)
|
43
40
|
del (
|
44
41
|
_others_all,
|
@@ -48,5 +45,4 @@ del (
|
|
48
45
|
_struct_all,
|
49
46
|
_error_all,
|
50
47
|
_mapping_all,
|
51
|
-
_visualization_all,
|
52
48
|
)
|
brainstate/util/_caller.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
2
|
# The credit should go to the Flax authors.
|
3
3
|
#
|
4
|
-
# Copyright 2024 The Flax Authors
|
4
|
+
# Copyright 2024 The Flax Authors.
|
5
5
|
#
|
6
6
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
# you may not use this file except in compliance with the License.
|
brainstate/util/_dict.py
CHANGED
@@ -1,4 +1,7 @@
|
|
1
|
-
#
|
1
|
+
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
|
+
# The credit should go to the Flax authors.
|
3
|
+
#
|
4
|
+
# Copyright 2024 The Flax Authors.
|
2
5
|
#
|
3
6
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
7
|
# you may not use this file except in compliance with the License.
|
brainstate/util/_filter.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
2
|
# The credit should go to the Flax authors.
|
3
3
|
#
|
4
|
-
# Copyright 2024 The Flax Authors
|
4
|
+
# Copyright 2024 The Flax Authors.
|
5
5
|
#
|
6
6
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
# you may not use this file except in compliance with the License.
|
brainstate/util/_pretty_repr.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
2
|
# The credit should go to the Flax authors.
|
3
3
|
#
|
4
|
-
# Copyright 2024 The Flax Authors
|
4
|
+
# Copyright 2024 The Flax Authors.
|
5
5
|
#
|
6
6
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
# you may not use this file except in compliance with the License.
|
brainstate/util/_struct.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
2
|
# The credit should go to the Flax authors.
|
3
3
|
#
|
4
|
-
# Copyright 2024 The Flax Authors
|
4
|
+
# Copyright 2024 The Flax Authors.
|
5
5
|
#
|
6
6
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
# you may not use this file except in compliance with the License.
|
{brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.
|
3
|
+
Version: 0.1.0.post20250120
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -19,6 +19,7 @@ Classifier: Programming Language :: Python :: 3.9
|
|
19
19
|
Classifier: Programming Language :: Python :: 3.10
|
20
20
|
Classifier: Programming Language :: Python :: 3.11
|
21
21
|
Classifier: Programming Language :: Python :: 3.12
|
22
|
+
Classifier: Programming Language :: Python :: 3.13
|
22
23
|
Classifier: License :: OSI Approved :: Apache Software License
|
23
24
|
Classifier: Programming Language :: Python
|
24
25
|
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
|
@@ -1,46 +1,45 @@
|
|
1
1
|
brainstate/__init__.py,sha256=A-QKdOvSalsCMxgk80Iz6_xMiUin6con6JaONHfciSY,1526
|
2
|
-
brainstate/_state.py,sha256=
|
3
|
-
brainstate/_state_test.py,sha256=
|
2
|
+
brainstate/_state.py,sha256=GZ46liHZSHbAHQEuELvOeoJ27P9xiZDz06G2AASjAjA,29142
|
3
|
+
brainstate/_state_test.py,sha256=rJUFRSXEqrrl4qANRewY9mnDlzSbtHwBIGeZ0ku-8Dg,1650
|
4
4
|
brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
|
5
5
|
brainstate/environ.py,sha256=PZnVFWPioUBuWmwCO8wwCKrHQfP3BR-5lYPRl5i5GDA,17698
|
6
6
|
brainstate/environ_test.py,sha256=QD6sPCKNtqemVCGwkdImjMazatrvvLr6YeAVcfUnVVY,2045
|
7
7
|
brainstate/mixin.py,sha256=g7uVUwZphZWsNs9pb48ozG2cDGaj0hs0g3lq8tDk-Sg,11310
|
8
8
|
brainstate/mixin_test.py,sha256=Oq_0fwC9vpXDN4t4dTBhWzLdFDNlcYsrcip14F1yECI,3079
|
9
|
-
brainstate/surrogate.py,sha256=
|
9
|
+
brainstate/surrogate.py,sha256=t4SzVwUVMAPtC-O1vFbuE9F4265wgAv7ud77ufIJsuk,48464
|
10
10
|
brainstate/transform.py,sha256=cxbymTlJ6uHvJWEEYXzFUkAySs_TbUTHakt0NQgWJ3s,808
|
11
11
|
brainstate/typing.py,sha256=Qh-LBzm6oG4rSXv4V5qB8SNYcoOR7bASoK_iQxnlafk,10467
|
12
12
|
brainstate/augment/__init__.py,sha256=BtXIBel7GbttmfBX6grxOxl0IiOJxLEa7qCGAXumamE,1286
|
13
13
|
brainstate/augment/_autograd.py,sha256=o9ivoEY7BmtdM1XmzdMmeRXpj6Tvn5xNB8LSGp2HKC8,25238
|
14
14
|
brainstate/augment/_autograd_test.py,sha256=S2eEgrwTzdSi3u2nKE3u37WSThosLwx1WCP9ptJAGKo,44060
|
15
|
-
brainstate/augment/_eval_shape.py,sha256=
|
16
|
-
brainstate/augment/_eval_shape_test.py,sha256=
|
17
|
-
brainstate/augment/_mapping.py,sha256=
|
18
|
-
brainstate/augment/_mapping_test.py,sha256=
|
15
|
+
brainstate/augment/_eval_shape.py,sha256=ObCgsZ704kLduB1dbjJZh5nVQYEkLR5ebK74V5NV42k,3892
|
16
|
+
brainstate/augment/_eval_shape_test.py,sha256=LFOJx7CWltmRLXdGY175UebLwtEMz2CzJ_gLqMZsJTw,1393
|
17
|
+
brainstate/augment/_mapping.py,sha256=nU6Y7fSnYXyQSILXU2QT-O73Fm3pnwOmgUoDaHqjve8,21544
|
18
|
+
brainstate/augment/_mapping_test.py,sha256=_KFhE3CXItwpbZ1gJfrDu3yUtX0YbfPUuHJG_G_BXEs,8963
|
19
19
|
brainstate/augment/_random.py,sha256=rkB4w4BkKsz9p8lTk31kVHvlVPJSvtGk8REn936KI_4,3071
|
20
20
|
brainstate/compile/__init__.py,sha256=qZZIYoyEl51IFkFu-Hb-bP3PAEHo94HlTDf57P2ze08,1858
|
21
|
-
brainstate/compile/_ad_checkpoint.py,sha256=
|
21
|
+
brainstate/compile/_ad_checkpoint.py,sha256=K6I4vnznDsqC9cUeCnez9UdV9r_toGA3zHezoHLA6mI,9377
|
22
22
|
brainstate/compile/_ad_checkpoint_test.py,sha256=R1I76nG4zIqb6g3M_VxWts7rUC1OHJCjtQhPkcbXodk,1746
|
23
23
|
brainstate/compile/_conditions.py,sha256=gApsHKGQrf1QBjoKXDVL7VsoeJ2zFtSc-hFz9nbYcF0,10113
|
24
24
|
brainstate/compile/_conditions_test.py,sha256=s9LF6h9LvigvgxUIugTqvgCHBIU8TXS1Ar1OlIxXfrw,8389
|
25
25
|
brainstate/compile/_error_if.py,sha256=TFvhqITKkRO9m30GdlUP4eEjJvLWQUhjkujXO9zvrWs,2689
|
26
26
|
brainstate/compile/_error_if_test.py,sha256=OdJG483IIdOrCHxtHd49OHfOxCSnSkk7GdAUOzSt8bE,2044
|
27
|
-
brainstate/compile/_jit.py,sha256=
|
27
|
+
brainstate/compile/_jit.py,sha256=itAWENKfJvnlaWl_uSy8lHTK8K1in89F_ZXXwp-EGRM,13944
|
28
28
|
brainstate/compile/_jit_test.py,sha256=zD7kck9SQJGmUDolh9P4luKwQ21fBGje1Z4STTEXIuA,4135
|
29
|
-
brainstate/compile/_loop_collect_return.py,sha256=
|
29
|
+
brainstate/compile/_loop_collect_return.py,sha256=TrKBZhtQecTtuiVz_HOeyepde-znzjlyk0Te53-AvOE,23492
|
30
30
|
brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
|
31
|
-
brainstate/compile/_loop_no_collection.py,sha256=
|
31
|
+
brainstate/compile/_loop_no_collection.py,sha256=qto2__Zt2PJntkjB9AXEgraGLvNUJS483BhCXjJyqv0,7495
|
32
32
|
brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
|
33
|
-
brainstate/compile/_make_jaxpr.py,sha256=
|
33
|
+
brainstate/compile/_make_jaxpr.py,sha256=DQf_80w3p0wi2Gb9P6_tLMJ0Oadgyr_jWkVjus0MSjw,33205
|
34
34
|
brainstate/compile/_make_jaxpr_test.py,sha256=3XaX8LUuG6UjolcD83qDVo5odf8FCDppdr9Q6V0NBs4,4303
|
35
|
-
brainstate/compile/_progress_bar.py,sha256=
|
36
|
-
brainstate/compile/_unvmap.py,sha256=
|
35
|
+
brainstate/compile/_progress_bar.py,sha256=0oVlZ4kW_ZMciJjOR_ebj3PNe_XkCMkoQpv-HUUdoF0,5554
|
36
|
+
brainstate/compile/_unvmap.py,sha256=EY4rbqCzzPOiaRwpWTiyBwb5dVkYFnacHhBZUZObxPI,4255
|
37
37
|
brainstate/compile/_util.py,sha256=aCvkTV--g4NsqcodTdBAISt4EwgezCbKzNUV58n-Q_Y,6304
|
38
|
-
brainstate/event/__init__.py,sha256=
|
39
|
-
brainstate/event/_csr.py,sha256=
|
38
|
+
brainstate/event/__init__.py,sha256=gSEem-1oTHgy99Mjm3uumTXVd93tLVl0c4dUgRpoifk,895
|
39
|
+
brainstate/event/_csr.py,sha256=PYKw8CGNgQ24MxQDoeBZTrPuC7Z-GetXQld9KiTbNYw,40063
|
40
|
+
brainstate/event/_csr_benchmark.py,sha256=xrj2DSWzw0pUHAE1jRBeSRhMW7ogXvDHEdeaZGioNE4,702
|
40
41
|
brainstate/event/_csr_mv.py,sha256=HStHvK3KyEMfLsIUslZjgbdU6OsD1yKGrzQOzBXG36M,10266
|
41
|
-
brainstate/event/
|
42
|
-
brainstate/event/_csr_mv_test.py,sha256=WQfAvp_3UeCUGAZjr3_aqQvrB-eYZcFEN4v1PBe9fUQ,4012
|
43
|
-
brainstate/event/_csr_test.py,sha256=v59rnwTy8jrvqjdGzN75kvLg0wLBmRbthaVRKY2f0Uw,2945
|
42
|
+
brainstate/event/_csr_test.py,sha256=_iXwUFq90GU7npVOUnlI4NA27RJ8zyCZBxe7NDH803o,9533
|
44
43
|
brainstate/event/_fixedprob_mv.py,sha256=nR3lhd87t1Vge435QHnFuDp-UBbWoW0Qk1kbsjRHQyc,25541
|
45
44
|
brainstate/event/_fixedprob_mv_benchmark.py,sha256=_F_8fH5MNMJZHeSqnq9DYMI9OgYr6JIxBKjbsgeWRv4,4720
|
46
45
|
brainstate/event/_fixedprob_mv_test.py,sha256=pVEarvGbqTjnAbxgMVRTAhkyYbvDnlyCJdeOdDD927w,4283
|
@@ -48,7 +47,7 @@ brainstate/event/_linear_mv.py,sha256=O5qbY31GNV1qEDrZ5kvPbA8Ae-bY5JpUgGtqDFNAeV
|
|
48
47
|
brainstate/event/_linear_mv_benckmark.py,sha256=hu0WqYMIa3jMoH7Fq9dgxcBjjXGFhghPx9vztyCo1KY,2411
|
49
48
|
brainstate/event/_linear_mv_test.py,sha256=V9w41ZP2vu95CyCdCkm-j9Eftqs2kqmeBY809N1-syY,3736
|
50
49
|
brainstate/event/_misc.py,sha256=8IpPooXjF2m0-tuo3pGHqThq2yLSNmYziy_zdurZ3NI,1040
|
51
|
-
brainstate/event/_xla_custom_op.py,sha256=
|
50
|
+
brainstate/event/_xla_custom_op.py,sha256=wF_nKgLUv1IGd8OY89MYqIvyZITl8UcrVysJWFugJxY,11093
|
52
51
|
brainstate/event/_xla_custom_op_test.py,sha256=rnkGMleXzLfJj4y5QqwfBvCCLTAHe_uabwBDniY-URM,1745
|
53
52
|
brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
|
54
53
|
brainstate/functional/_activations.py,sha256=S0Ok7sq5FTbmJWSejpOCHo1jpKX0gYOLy_TO2IUXM8s,21726
|
@@ -56,14 +55,11 @@ brainstate/functional/_activations_test.py,sha256=T___RlSrIfXwlkw8dg5A9EZMTZGDzv
|
|
56
55
|
brainstate/functional/_normalization.py,sha256=i2EV7hSsqcNdcYRX2wAxjq8doHwyN9eNJTGTaPt03xE,2605
|
57
56
|
brainstate/functional/_others.py,sha256=_u_Ys-LiLzDAP4zJggVwaVvirgoS3jvhXMREoS6JOkM,1737
|
58
57
|
brainstate/functional/_spikes.py,sha256=QY-2ayJkgkGELcq-bftPEaf_hJptVf_SP3fY36QvlZc,2678
|
59
|
-
brainstate/graph/__init__.py,sha256=
|
60
|
-
brainstate/graph/
|
61
|
-
brainstate/graph/_graph_context_test.py,sha256=IYpjqbXwSFF65XL0ZbdPeC1jYyEHLpQVrhuFeJXH4GM,2409
|
62
|
-
brainstate/graph/_graph_convert.py,sha256=llSREtGQrIggkD0wmxUbYKuSveLW4ihDZME6Ab-mRTQ,9147
|
63
|
-
brainstate/graph/_graph_node.py,sha256=mmZ0jhZev8ReNJhVLgWqYJEedEDtJHxhwxRv4ytQVNo,9268
|
58
|
+
brainstate/graph/__init__.py,sha256=fyvQMlAUY3QYTzvDzz5TDoWS2XQwZ6P3ic6BtysZyHM,1026
|
59
|
+
brainstate/graph/_graph_node.py,sha256=swAokZLKswSTaq2WEhyLIs38sy_67C6maHI6T3e1hvY,8339
|
64
60
|
brainstate/graph/_graph_node_test.py,sha256=BFGfdzZFDHI0XK7hHotSVWKt3em1taGvn8FHF9NCXx8,2702
|
65
|
-
brainstate/graph/_graph_operation.py,sha256=
|
66
|
-
brainstate/graph/_graph_operation_test.py,sha256=
|
61
|
+
brainstate/graph/_graph_operation.py,sha256=cIwGo3ICgtce2fmdn917r81evMFjJIKeW9doaQK4DD8,64111
|
62
|
+
brainstate/graph/_graph_operation_test.py,sha256=zjvpKjQAFWtw8YZuqOk_jmlZNb_-E8oPyNx57dyc8jI,18556
|
67
63
|
brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
|
68
64
|
brainstate/init/_base.py,sha256=B_NLS9aKNrvuj5NAlSgBbQTVev7IRvzcx8vH0J-Gq2w,1671
|
69
65
|
brainstate/init/_generic.py,sha256=sGOvd_atpxLWqqZKobTfAiMiYRnDC19PBNHdQy_igFM,8028
|
@@ -83,7 +79,7 @@ brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=cTbIn41EPYG0h3ICzKBXxpgB6wwA2
|
|
83
79
|
brainstate/nn/_dyn_impl/_dynamics_neuron_test.py,sha256=Tfzrzu7udGrLJGnqItiLWe5WT0dgduvYOgzGCnaPJQg,6317
|
84
80
|
brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=MsbPyaiDyjungyzuK2b3exRGaMpZgmsmmNHNLjgxQKw,15269
|
85
81
|
brainstate/nn/_dyn_impl/_dynamics_synapse_test.py,sha256=t5i-HV0ii9sUNzWTEv04o26QVtQ-mCdMJcFq2MD755A,4981
|
86
|
-
brainstate/nn/_dyn_impl/_inputs.py,sha256=
|
82
|
+
brainstate/nn/_dyn_impl/_inputs.py,sha256=UNoGxKIKXwPnhelljDowqAWlV6ds7aBBkEbvdy2oDI4,11302
|
87
83
|
brainstate/nn/_dyn_impl/_projection_alignpost.py,sha256=PNC1Tzx_SF2DHAHeJCufXzO_Q4qLoBpWABI45B3GRuc,876
|
88
84
|
brainstate/nn/_dyn_impl/_rate_rnns.py,sha256=dz_yT_6hJVhKulcjIARbGtmMzZqISws96CtBc6o5GOo,14768
|
89
85
|
brainstate/nn/_dyn_impl/_rate_rnns_test.py,sha256=gNgtr-a4ZiU1XF9wFG1HiJ9fLosfWchVR9Zn1x39xt4,2452
|
@@ -102,10 +98,10 @@ brainstate/nn/_elementwise/_dropout_test.py,sha256=ZzNvjFf46NpKWGBIcT6O0lKOBGpxO
|
|
102
98
|
brainstate/nn/_elementwise/_elementwise.py,sha256=om-KpwDTk5yFG5KBYXXHquRLV7s28_FJjk-omvyMyvQ,33342
|
103
99
|
brainstate/nn/_elementwise/_elementwise_test.py,sha256=SZI9jB39sZ5SO1dpWGW-PhodthwN0GU9FY1nqf2fWcs,5341
|
104
100
|
brainstate/nn/_interaction/__init__.py,sha256=TTY_SeNrdx4VnUSw6vdyl02OHdS9Qs15cWBp6kjsyNQ,1289
|
105
|
-
brainstate/nn/_interaction/_conv.py,sha256=
|
101
|
+
brainstate/nn/_interaction/_conv.py,sha256=lwyxTVsJVPiKlZcgB6iqE64aX7AOJzplDSj4y6-m18o,18592
|
106
102
|
brainstate/nn/_interaction/_conv_test.py,sha256=fHXRFYnDghFiKre63RqMwIE_gbPKdK34UPhKOz-J3qU,8695
|
107
103
|
brainstate/nn/_interaction/_embedding.py,sha256=iK0I1ExKWFa_QzV9UDGj32Ljsmdr1g_LlAtMcusebxU,2187
|
108
|
-
brainstate/nn/_interaction/_linear.py,sha256=
|
104
|
+
brainstate/nn/_interaction/_linear.py,sha256=EnkOk1oE79rvRIjU6HBllxUpVOEcQQCj4vtavo9AJjI,14767
|
109
105
|
brainstate/nn/_interaction/_linear_test.py,sha256=QfCR8SBBed9OnSY-AmQ0kJqoggDA3Xem0dRJ0BusxLU,2872
|
110
106
|
brainstate/nn/_interaction/_normalizations.py,sha256=7YDzkmO_iqd70fH_wawb60Bu8eGOdvZq23emP-b68Hc,37440
|
111
107
|
brainstate/nn/_interaction/_normalizations_test.py,sha256=2p1Jf8nA999VYGWbvOZfKYlKk6UmL0vaEB76xkXxkXw,2438
|
@@ -119,26 +115,24 @@ brainstate/optim/_optax_optimizer.py,sha256=SuXV_xUBfhOw1_C2J5TIpy3dXDtI9VJFaSML
|
|
119
115
|
brainstate/optim/_optax_optimizer_test.py,sha256=DAomE8Eu3dn4gh1S3EZ_u4pW4rhcl16vWPbnDcN3Rs4,1762
|
120
116
|
brainstate/optim/_sgd_optimizer.py,sha256=NVKYhGcw2D1ksNWUIXZcj-74LUaan8XL3EERk-EHMRI,46008
|
121
117
|
brainstate/random/__init__.py,sha256=c5q-RC3grRIjx-HBb2IhKZpi_xzbFmUUxzRAzqfREic,1045
|
122
|
-
brainstate/random/_rand_funs.py,sha256=
|
118
|
+
brainstate/random/_rand_funs.py,sha256=WaelvEpeQb6Vuqt4eNgsAtd7GI8BqgEdVYbXgtCOd54,137682
|
123
119
|
brainstate/random/_rand_funs_test.py,sha256=abO5lSoPBgBcg6ecFE1qnCg98__QGa68GSYC5pQW5QI,19438
|
124
|
-
brainstate/random/_rand_seed.py,sha256=
|
120
|
+
brainstate/random/_rand_seed.py,sha256=MHA9znbdJW9ujx73onDRrAOI684_0FmGfqczBsSXYQg,5985
|
125
121
|
brainstate/random/_rand_seed_test.py,sha256=Qibcs-ZqCvj1LuucmQ8H00B_HBNhf2f6un0aUdNZNTw,1518
|
126
|
-
brainstate/random/_rand_state.py,sha256=
|
122
|
+
brainstate/random/_rand_state.py,sha256=nuoQ8GU1MfJPRNN-ZmRQsggVjoyPhaEdZmwM7_4-Q3c,55206
|
127
123
|
brainstate/random/_random_for_unit.py,sha256=kGp4EUX19MXJ9Govoivbg8N0bddqOldKEI2h_TbdONY,2057
|
128
|
-
brainstate/util/__init__.py,sha256
|
129
|
-
brainstate/util/_caller.py,sha256=
|
130
|
-
brainstate/util/_dict.py,sha256=
|
124
|
+
brainstate/util/__init__.py,sha256=-FWEuSKXG3mWxYphGFAy3UEuVe39lFs1GruluzdXDoI,1502
|
125
|
+
brainstate/util/_caller.py,sha256=T3bzu7-09r-6EOrU6Muca_aMXSQua_X2lXjEqb-w39w,2782
|
126
|
+
brainstate/util/_dict.py,sha256=Yapug-_RZQYjvd8cZ3v90_MX7rUYJDBzBnZJT6a0NXY,26178
|
131
127
|
brainstate/util/_dict_test.py,sha256=Dn0TdjX6wLBXaTD4jfYTu6cKfFHwKSxi4_3bX7kB_IA,5621
|
132
128
|
brainstate/util/_error.py,sha256=eyZ8PGFixqe2K5OEfjSDzI-2tU0ieYQoUpBP7yStlPQ,878
|
133
|
-
brainstate/util/_filter.py,sha256=
|
129
|
+
brainstate/util/_filter.py,sha256=1-bvFHdjeehvXeHTrCEp8xr25lopKe8d3XZGCNegq0s,4970
|
134
130
|
brainstate/util/_others.py,sha256=jsPZwP-v_5HRV-LB5F0NUsiqr04y8bmGIsu_JMyVcbQ,14762
|
135
|
-
brainstate/util/_pretty_repr.py,sha256=
|
131
|
+
brainstate/util/_pretty_repr.py,sha256=bDpU4gbkS4B8cXBkiN8kBQNmruxiJzDRF-eIqzyeYnM,5716
|
136
132
|
brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
|
137
|
-
brainstate/util/_struct.py,sha256=
|
138
|
-
brainstate/
|
139
|
-
brainstate/
|
140
|
-
brainstate-0.1.0.
|
141
|
-
brainstate-0.1.0.
|
142
|
-
brainstate-0.1.0.
|
143
|
-
brainstate-0.1.0.post20250105.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
144
|
-
brainstate-0.1.0.post20250105.dist-info/RECORD,,
|
133
|
+
brainstate/util/_struct.py,sha256=KMMHcshOM20gYhSahNzWLxsTt-Rt3AeX3Uz26-rP9vI,17619
|
134
|
+
brainstate-0.1.0.post20250120.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
135
|
+
brainstate-0.1.0.post20250120.dist-info/METADATA,sha256=vUyr4XjiyAW68waFKMray9EEFHTqjqRp5GlqAG8LsKY,3585
|
136
|
+
brainstate-0.1.0.post20250120.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
137
|
+
brainstate-0.1.0.post20250120.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
138
|
+
brainstate-0.1.0.post20250120.dist-info/RECORD,,
|