brainstate 0.1.0.post20250208__py2.py3-none-any.whl → 0.1.0.post20250210__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +20 -13
- brainstate/compile/_jit.py +0 -1
- brainstate/compile/_make_jaxpr.py +1 -1
- brainstate/compile/_util.py +1 -1
- brainstate/graph/_graph_operation.py +1 -1
- brainstate/nn/_dyn_impl/_inputs.py +41 -3
- brainstate/nn/_module.py +1 -2
- {brainstate-0.1.0.post20250208.dist-info → brainstate-0.1.0.post20250210.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20250208.dist-info → brainstate-0.1.0.post20250210.dist-info}/RECORD +12 -12
- {brainstate-0.1.0.post20250208.dist-info → brainstate-0.1.0.post20250210.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250208.dist-info → brainstate-0.1.0.post20250210.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250208.dist-info → brainstate-0.1.0.post20250210.dist-info}/top_level.txt +0 -0
brainstate/_state.py
CHANGED
@@ -261,9 +261,8 @@ class State(Generic[A], PrettyRepr):
|
|
261
261
|
"""
|
262
262
|
The data and its value.
|
263
263
|
"""
|
264
|
-
self.check_if_deleted()
|
265
264
|
record_state_value_read(self)
|
266
|
-
return self.
|
265
|
+
return self._read_value()
|
267
266
|
|
268
267
|
@value.setter
|
269
268
|
def value(self, v) -> None:
|
@@ -273,7 +272,14 @@ class State(Generic[A], PrettyRepr):
|
|
273
272
|
Args:
|
274
273
|
v: The value.
|
275
274
|
"""
|
276
|
-
|
275
|
+
# NOTE: the following order is important
|
276
|
+
|
277
|
+
if isinstance(v, State): # value checking
|
278
|
+
raise ValueError('Cannot set value to a State, ' 'use `copy_from` method instead')
|
279
|
+
self._check_value_tree(v) # check the tree structure
|
280
|
+
record_state_value_write(self) # record the value by the stack (>= level)
|
281
|
+
self._been_writen = True # set the flag
|
282
|
+
self._write_value(v) # write the value
|
277
283
|
|
278
284
|
@property
|
279
285
|
def stack_level(self):
|
@@ -295,17 +301,18 @@ class State(Generic[A], PrettyRepr):
|
|
295
301
|
"""
|
296
302
|
self._level = level
|
297
303
|
|
298
|
-
def
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
self.
|
303
|
-
|
304
|
-
|
305
|
-
|
304
|
+
def _read_value(self) -> PyTree[ArrayLike]:
|
305
|
+
"""
|
306
|
+
The interface to customize the value reading.
|
307
|
+
"""
|
308
|
+
self.check_if_deleted()
|
309
|
+
return self._value
|
310
|
+
|
311
|
+
def _write_value(self, v) -> None:
|
312
|
+
"""
|
313
|
+
The interface to customize the value writing.
|
314
|
+
"""
|
306
315
|
self._value = v
|
307
|
-
# set flag
|
308
|
-
self._been_writen = True
|
309
316
|
|
310
317
|
def restore_value(self, v) -> None:
|
311
318
|
"""
|
brainstate/compile/_jit.py
CHANGED
@@ -82,7 +82,6 @@ def _get_jitted_fun(
|
|
82
82
|
return fun.fun(*args, **params)
|
83
83
|
|
84
84
|
# compile the function and get the state trace
|
85
|
-
# print('Compiling ...')
|
86
85
|
state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
|
87
86
|
read_state_vals = state_trace.get_read_state_values(True)
|
88
87
|
|
@@ -499,7 +499,7 @@ class StatefulFunction(object):
|
|
499
499
|
state_vals, out = self.jaxpr_call([st.value for st in state_trace.states], *args, **kwargs)
|
500
500
|
for st, written, val in zip(state_trace.states, state_trace.been_writen, state_vals):
|
501
501
|
if written:
|
502
|
-
st.
|
502
|
+
st.value = val
|
503
503
|
else:
|
504
504
|
st.restore_value(val)
|
505
505
|
return out
|
brainstate/compile/_util.py
CHANGED
@@ -31,7 +31,7 @@ def write_back_state_values(
|
|
31
31
|
assert len(state_trace.states) == len(state_trace.been_writen) == len(read_state_vals) == len(write_state_vals)
|
32
32
|
for st, write, val_r, val_w in zip(state_trace.states, state_trace.been_writen, read_state_vals, write_state_vals):
|
33
33
|
if write:
|
34
|
-
st.
|
34
|
+
st.value = val_w
|
35
35
|
else:
|
36
36
|
st.restore_value(val_r)
|
37
37
|
|
@@ -609,7 +609,7 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
|
|
609
609
|
variable.update_from_ref(value)
|
610
610
|
elif isinstance(value, State):
|
611
611
|
if value._been_writen:
|
612
|
-
variable.
|
612
|
+
variable.value = value.value
|
613
613
|
else:
|
614
614
|
variable.restore_value(value.value)
|
615
615
|
else:
|
@@ -216,9 +216,37 @@ def poisson_input(
|
|
216
216
|
weight: u.Quantity,
|
217
217
|
target: State,
|
218
218
|
indices: Optional[Union[np.ndarray, jax.Array]] = None,
|
219
|
+
refractory: Optional[Union[jax.Array]] = None,
|
219
220
|
):
|
220
221
|
"""
|
221
|
-
Poisson
|
222
|
+
Generates Poisson-distributed input spikes to a target state variable.
|
223
|
+
|
224
|
+
This function simulates Poisson input to a given state, updating the target
|
225
|
+
variable with generated spikes based on the specified frequency, number of inputs,
|
226
|
+
and synaptic weight. The input can be applied to specific indices of the target
|
227
|
+
or to the entire target if indices are not provided.
|
228
|
+
|
229
|
+
Parameters
|
230
|
+
----------
|
231
|
+
freq : u.Quantity[u.Hz]
|
232
|
+
The frequency of the Poisson input in Hertz.
|
233
|
+
num_input : int
|
234
|
+
The number of input channels or neurons generating the Poisson spikes.
|
235
|
+
weight : u.Quantity
|
236
|
+
The synaptic weight applied to each spike.
|
237
|
+
target : State
|
238
|
+
The target state variable to which the Poisson input is applied.
|
239
|
+
indices : Optional[Union[np.ndarray, jax.Array]], optional
|
240
|
+
Specific indices of the target to apply the input. If None, the input is applied
|
241
|
+
to the entire target.
|
242
|
+
refractory : Optional[Union[jax.Array]], optional
|
243
|
+
A boolean array indicating which parts of the target are in a refractory state
|
244
|
+
and should not be updated. Should be the same length as the target.
|
245
|
+
|
246
|
+
Returns
|
247
|
+
-------
|
248
|
+
None
|
249
|
+
The function updates the target state in place with the generated Poisson input.
|
222
250
|
"""
|
223
251
|
freq = maybe_state(freq)
|
224
252
|
weight = maybe_state(weight)
|
@@ -291,7 +319,7 @@ def poisson_input(
|
|
291
319
|
# )
|
292
320
|
|
293
321
|
# update target variable
|
294
|
-
|
322
|
+
data = jax.tree.map(
|
295
323
|
lambda tar, x: tar + x * weight,
|
296
324
|
target.value,
|
297
325
|
inp,
|
@@ -358,9 +386,19 @@ def poisson_input(
|
|
358
386
|
# )
|
359
387
|
|
360
388
|
# update target variable
|
361
|
-
|
389
|
+
data = jax.tree.map(
|
362
390
|
lambda x, tar: tar.at[indices].add(x * weight),
|
363
391
|
inp,
|
364
392
|
tar_val,
|
365
393
|
is_leaf=u.math.is_quantity
|
366
394
|
)
|
395
|
+
|
396
|
+
if refractory is not None:
|
397
|
+
target.value = jax.tree.map(
|
398
|
+
lambda x, tar: u.math.where(refractory, tar, x),
|
399
|
+
data,
|
400
|
+
tar_val,
|
401
|
+
is_leaf=u.math.is_quantity
|
402
|
+
)
|
403
|
+
else:
|
404
|
+
target.value = data
|
brainstate/nn/_module.py
CHANGED
@@ -294,8 +294,7 @@ class Sequential(Module):
|
|
294
294
|
# the input and output shape
|
295
295
|
if first.in_size is not None:
|
296
296
|
self.in_size = first.in_size
|
297
|
-
|
298
|
-
self.out_size = tuple(in_size)
|
297
|
+
self.out_size = tuple(in_size)
|
299
298
|
|
300
299
|
def update(self, x):
|
301
300
|
"""Update function of a sequential model.
|
{brainstate-0.1.0.post20250208.dist-info → brainstate-0.1.0.post20250210.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.
|
3
|
+
Version: 0.1.0.post20250210
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -1,5 +1,5 @@
|
|
1
1
|
brainstate/__init__.py,sha256=AkZyyFkn4fB8g2aT6Rc2MO1xICPpUZuDtdze-eUQNc0,1496
|
2
|
-
brainstate/_state.py,sha256=
|
2
|
+
brainstate/_state.py,sha256=Ol-FqHWQnIKmylXHjdsY5izKQhIb0bUw3_UL-7zj4WA,29447
|
3
3
|
brainstate/_state_test.py,sha256=rJUFRSXEqrrl4qANRewY9mnDlzSbtHwBIGeZ0ku-8Dg,1650
|
4
4
|
brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
|
5
5
|
brainstate/environ.py,sha256=PZnVFWPioUBuWmwCO8wwCKrHQfP3BR-5lYPRl5i5GDA,17698
|
@@ -24,17 +24,17 @@ brainstate/compile/_conditions.py,sha256=gApsHKGQrf1QBjoKXDVL7VsoeJ2zFtSc-hFz9nb
|
|
24
24
|
brainstate/compile/_conditions_test.py,sha256=s9LF6h9LvigvgxUIugTqvgCHBIU8TXS1Ar1OlIxXfrw,8389
|
25
25
|
brainstate/compile/_error_if.py,sha256=TFvhqITKkRO9m30GdlUP4eEjJvLWQUhjkujXO9zvrWs,2689
|
26
26
|
brainstate/compile/_error_if_test.py,sha256=OdJG483IIdOrCHxtHd49OHfOxCSnSkk7GdAUOzSt8bE,2044
|
27
|
-
brainstate/compile/_jit.py,sha256
|
27
|
+
brainstate/compile/_jit.py,sha256=-Y8fyy8gc7qQT2ti4-N-74hjP_6C-D8YC5h-1unEKuI,13910
|
28
28
|
brainstate/compile/_jit_test.py,sha256=zD7kck9SQJGmUDolh9P4luKwQ21fBGje1Z4STTEXIuA,4135
|
29
29
|
brainstate/compile/_loop_collect_return.py,sha256=TrKBZhtQecTtuiVz_HOeyepde-znzjlyk0Te53-AvOE,23492
|
30
30
|
brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
|
31
31
|
brainstate/compile/_loop_no_collection.py,sha256=qto2__Zt2PJntkjB9AXEgraGLvNUJS483BhCXjJyqv0,7495
|
32
32
|
brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
|
33
|
-
brainstate/compile/_make_jaxpr.py,sha256=
|
33
|
+
brainstate/compile/_make_jaxpr.py,sha256=Q-nwm-ibBN0ube4ZjATp924pUkrXuaeT0XgSstqkI40,33304
|
34
34
|
brainstate/compile/_make_jaxpr_test.py,sha256=3gwdiutn_PJyiweu3oPEXumxEVHKaE2xDGvkwZy2GEo,4367
|
35
35
|
brainstate/compile/_progress_bar.py,sha256=5pCMCEmbTO5XmKtzRUJGA178tuBznWKuh9Kw00wAL1I,7524
|
36
36
|
brainstate/compile/_unvmap.py,sha256=CJA6D9lUcBfvdLrpFVvC2AdTJqe9uY0Ht6PltQJyr4U,4228
|
37
|
-
brainstate/compile/_util.py,sha256=
|
37
|
+
brainstate/compile/_util.py,sha256=iKk51BHAerBFj2BTxPNdjsk3MZQiXenzpCr7Ys0iYWg,6299
|
38
38
|
brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
|
39
39
|
brainstate/functional/_activations.py,sha256=S0Ok7sq5FTbmJWSejpOCHo1jpKX0gYOLy_TO2IUXM8s,21726
|
40
40
|
brainstate/functional/_activations_test.py,sha256=T___RlSrIfXwlkw8dg5A9EZMTZGDzv3a2evUwq_nYFg,13034
|
@@ -44,7 +44,7 @@ brainstate/functional/_spikes.py,sha256=QY-2ayJkgkGELcq-bftPEaf_hJptVf_SP3fY36Qv
|
|
44
44
|
brainstate/graph/__init__.py,sha256=noo4TjBg6iEhjjwk0sAGUhR7Ge-z8Vnc2rLYUvnqttw,1295
|
45
45
|
brainstate/graph/_graph_node.py,sha256=swAokZLKswSTaq2WEhyLIs38sy_67C6maHI6T3e1hvY,8339
|
46
46
|
brainstate/graph/_graph_node_test.py,sha256=BFGfdzZFDHI0XK7hHotSVWKt3em1taGvn8FHF9NCXx8,2702
|
47
|
-
brainstate/graph/_graph_operation.py,sha256=
|
47
|
+
brainstate/graph/_graph_operation.py,sha256=UtBNP7hvxa-5i99LQJStXbFhUbl3icdfTq1oF4MeH1g,64106
|
48
48
|
brainstate/graph/_graph_operation_test.py,sha256=zjvpKjQAFWtw8YZuqOk_jmlZNb_-E8oPyNx57dyc8jI,18556
|
49
49
|
brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
|
50
50
|
brainstate/init/_base.py,sha256=B_NLS9aKNrvuj5NAlSgBbQTVev7IRvzcx8vH0J-Gq2w,1671
|
@@ -57,7 +57,7 @@ brainstate/nn/__init__.py,sha256=rxURT8J1XfBn3Vh3Dx_WzVADWn9zVriIty5KZEG-x6o,162
|
|
57
57
|
brainstate/nn/_collective_ops.py,sha256=sSjIIs1MvZA30XFFmK7iL1D_sCeh7hFd3PanCH6kgZo,6779
|
58
58
|
brainstate/nn/_exp_euler.py,sha256=yjkfSllFxGWKEAlHo5AzBizzkFj6FEVDKmFV6E2g214,3521
|
59
59
|
brainstate/nn/_exp_euler_test.py,sha256=clwRD8QR71k1jn6NrACMDEUcFMh0J9RTosoPnlYWUkw,1242
|
60
|
-
brainstate/nn/_module.py,sha256=
|
60
|
+
brainstate/nn/_module.py,sha256=HDLPvLfB7jat2VT3gBu0MxA7vfzK7xgowemitHX8Cgo,10835
|
61
61
|
brainstate/nn/_module_test.py,sha256=V4ZhiY_zYPvArkB2eeOTtZcgQrtlRyXKMbS1AJH4vC8,8893
|
62
62
|
brainstate/nn/metrics.py,sha256=iupHjSRTHYY-HmEPBC4tXWrZfF4zh1ek2NwSAA0gnwE,14738
|
63
63
|
brainstate/nn/_dyn_impl/__init__.py,sha256=Oazar7h89dp1WA2Vx4Tj7gCBhxJKH4LAUEABkBEG7vU,1462
|
@@ -65,7 +65,7 @@ brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=cTbIn41EPYG0h3ICzKBXxpgB6wwA2
|
|
65
65
|
brainstate/nn/_dyn_impl/_dynamics_neuron_test.py,sha256=Tfzrzu7udGrLJGnqItiLWe5WT0dgduvYOgzGCnaPJQg,6317
|
66
66
|
brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=MsbPyaiDyjungyzuK2b3exRGaMpZgmsmmNHNLjgxQKw,15269
|
67
67
|
brainstate/nn/_dyn_impl/_dynamics_synapse_test.py,sha256=t5i-HV0ii9sUNzWTEv04o26QVtQ-mCdMJcFq2MD755A,4981
|
68
|
-
brainstate/nn/_dyn_impl/_inputs.py,sha256=
|
68
|
+
brainstate/nn/_dyn_impl/_inputs.py,sha256=x4bcp7fo5SI5TC4TmyARngK-PE7OvGHprJ17Levs0ls,12937
|
69
69
|
brainstate/nn/_dyn_impl/_projection_alignpost.py,sha256=PNC1Tzx_SF2DHAHeJCufXzO_Q4qLoBpWABI45B3GRuc,876
|
70
70
|
brainstate/nn/_dyn_impl/_rate_rnns.py,sha256=dz_yT_6hJVhKulcjIARbGtmMzZqISws96CtBc6o5GOo,14768
|
71
71
|
brainstate/nn/_dyn_impl/_rate_rnns_test.py,sha256=gNgtr-a4ZiU1XF9wFG1HiJ9fLosfWchVR9Zn1x39xt4,2452
|
@@ -117,8 +117,8 @@ brainstate/util/_others.py,sha256=jsPZwP-v_5HRV-LB5F0NUsiqr04y8bmGIsu_JMyVcbQ,14
|
|
117
117
|
brainstate/util/_pretty_repr.py,sha256=bDpU4gbkS4B8cXBkiN8kBQNmruxiJzDRF-eIqzyeYnM,5716
|
118
118
|
brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
|
119
119
|
brainstate/util/_struct.py,sha256=KMMHcshOM20gYhSahNzWLxsTt-Rt3AeX3Uz26-rP9vI,17619
|
120
|
-
brainstate-0.1.0.
|
121
|
-
brainstate-0.1.0.
|
122
|
-
brainstate-0.1.0.
|
123
|
-
brainstate-0.1.0.
|
124
|
-
brainstate-0.1.0.
|
120
|
+
brainstate-0.1.0.post20250210.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
121
|
+
brainstate-0.1.0.post20250210.dist-info/METADATA,sha256=__N9QGz8FFW6rXXG3_Y5YTKFd9iWM_MVddxJP74hZcI,3585
|
122
|
+
brainstate-0.1.0.post20250210.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
123
|
+
brainstate-0.1.0.post20250210.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
124
|
+
brainstate-0.1.0.post20250210.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{brainstate-0.1.0.post20250208.dist-info → brainstate-0.1.0.post20250210.dist-info}/top_level.txt
RENAMED
File without changes
|