brainstate 0.1.0.post20250209__py2.py3-none-any.whl → 0.1.0.post20250211__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +21 -18
- brainstate/compile/_jit.py +0 -1
- brainstate/compile/_make_jaxpr.py +1 -1
- brainstate/compile/_util.py +1 -1
- brainstate/graph/_graph_operation.py +1 -1
- brainstate/nn/_dyn_impl/_inputs.py +41 -3
- brainstate/nn/_interaction/_conv.py +12 -10
- {brainstate-0.1.0.post20250209.dist-info → brainstate-0.1.0.post20250211.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20250209.dist-info → brainstate-0.1.0.post20250211.dist-info}/RECORD +12 -12
- {brainstate-0.1.0.post20250209.dist-info → brainstate-0.1.0.post20250211.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250209.dist-info → brainstate-0.1.0.post20250211.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250209.dist-info → brainstate-0.1.0.post20250211.dist-info}/top_level.txt +0 -0
brainstate/_state.py
CHANGED
@@ -263,9 +263,8 @@ class State(Generic[A], PrettyReprTree):
|
|
263
263
|
"""
|
264
264
|
The data and its value.
|
265
265
|
"""
|
266
|
-
self.check_if_deleted()
|
267
266
|
record_state_value_read(self)
|
268
|
-
return self.
|
267
|
+
return self._read_value()
|
269
268
|
|
270
269
|
@value.setter
|
271
270
|
def value(self, v) -> None:
|
@@ -275,7 +274,14 @@ class State(Generic[A], PrettyReprTree):
|
|
275
274
|
Args:
|
276
275
|
v: The value.
|
277
276
|
"""
|
278
|
-
|
277
|
+
# NOTE: the following order is important
|
278
|
+
|
279
|
+
if isinstance(v, State): # value checking
|
280
|
+
raise ValueError('Cannot set value to a State, ' 'use `copy_from` method instead')
|
281
|
+
self._check_value_tree(v) # check the tree structure
|
282
|
+
record_state_value_write(self) # record the value by the stack (>= level)
|
283
|
+
self._been_writen = True # set the flag
|
284
|
+
self._write_value(v) # write the value
|
279
285
|
|
280
286
|
@property
|
281
287
|
def stack_level(self):
|
@@ -297,17 +303,18 @@ class State(Generic[A], PrettyReprTree):
|
|
297
303
|
"""
|
298
304
|
self._level = level
|
299
305
|
|
300
|
-
def
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
self.
|
305
|
-
|
306
|
-
|
307
|
-
|
306
|
+
def _read_value(self) -> PyTree[ArrayLike]:
|
307
|
+
"""
|
308
|
+
The interface to customize the value reading.
|
309
|
+
"""
|
310
|
+
self.check_if_deleted()
|
311
|
+
return self._value
|
312
|
+
|
313
|
+
def _write_value(self, v) -> None:
|
314
|
+
"""
|
315
|
+
The interface to customize the value writing.
|
316
|
+
"""
|
308
317
|
self._value = v
|
309
|
-
# set flag
|
310
|
-
self._been_writen = True
|
311
318
|
|
312
319
|
def restore_value(self, v) -> None:
|
313
320
|
"""
|
@@ -813,11 +820,7 @@ class TreefyState(Generic[A], PrettyReprTree):
|
|
813
820
|
return 'value', v
|
814
821
|
|
815
822
|
if k == '_name':
|
816
|
-
if
|
817
|
-
return None, None
|
818
|
-
else:
|
819
|
-
return 'name', v
|
820
|
-
|
823
|
+
return (None, None) if v is None else ('name', v)
|
821
824
|
return k, v
|
822
825
|
|
823
826
|
def replace(self, value: B) -> TreefyState[B]:
|
brainstate/compile/_jit.py
CHANGED
@@ -82,7 +82,6 @@ def _get_jitted_fun(
|
|
82
82
|
return fun.fun(*args, **params)
|
83
83
|
|
84
84
|
# compile the function and get the state trace
|
85
|
-
# print('Compiling ...')
|
86
85
|
state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
|
87
86
|
read_state_vals = state_trace.get_read_state_values(True)
|
88
87
|
|
@@ -499,7 +499,7 @@ class StatefulFunction(object):
|
|
499
499
|
state_vals, out = self.jaxpr_call([st.value for st in state_trace.states], *args, **kwargs)
|
500
500
|
for st, written, val in zip(state_trace.states, state_trace.been_writen, state_vals):
|
501
501
|
if written:
|
502
|
-
st.
|
502
|
+
st.value = val
|
503
503
|
else:
|
504
504
|
st.restore_value(val)
|
505
505
|
return out
|
brainstate/compile/_util.py
CHANGED
@@ -31,7 +31,7 @@ def write_back_state_values(
|
|
31
31
|
assert len(state_trace.states) == len(state_trace.been_writen) == len(read_state_vals) == len(write_state_vals)
|
32
32
|
for st, write, val_r, val_w in zip(state_trace.states, state_trace.been_writen, read_state_vals, write_state_vals):
|
33
33
|
if write:
|
34
|
-
st.
|
34
|
+
st.value = val_w
|
35
35
|
else:
|
36
36
|
st.restore_value(val_r)
|
37
37
|
|
@@ -609,7 +609,7 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
|
|
609
609
|
variable.update_from_ref(value)
|
610
610
|
elif isinstance(value, State):
|
611
611
|
if value._been_writen:
|
612
|
-
variable.
|
612
|
+
variable.value = value.value
|
613
613
|
else:
|
614
614
|
variable.restore_value(value.value)
|
615
615
|
else:
|
@@ -216,9 +216,37 @@ def poisson_input(
|
|
216
216
|
weight: u.Quantity,
|
217
217
|
target: State,
|
218
218
|
indices: Optional[Union[np.ndarray, jax.Array]] = None,
|
219
|
+
refractory: Optional[Union[jax.Array]] = None,
|
219
220
|
):
|
220
221
|
"""
|
221
|
-
Poisson
|
222
|
+
Generates Poisson-distributed input spikes to a target state variable.
|
223
|
+
|
224
|
+
This function simulates Poisson input to a given state, updating the target
|
225
|
+
variable with generated spikes based on the specified frequency, number of inputs,
|
226
|
+
and synaptic weight. The input can be applied to specific indices of the target
|
227
|
+
or to the entire target if indices are not provided.
|
228
|
+
|
229
|
+
Parameters
|
230
|
+
----------
|
231
|
+
freq : u.Quantity[u.Hz]
|
232
|
+
The frequency of the Poisson input in Hertz.
|
233
|
+
num_input : int
|
234
|
+
The number of input channels or neurons generating the Poisson spikes.
|
235
|
+
weight : u.Quantity
|
236
|
+
The synaptic weight applied to each spike.
|
237
|
+
target : State
|
238
|
+
The target state variable to which the Poisson input is applied.
|
239
|
+
indices : Optional[Union[np.ndarray, jax.Array]], optional
|
240
|
+
Specific indices of the target to apply the input. If None, the input is applied
|
241
|
+
to the entire target.
|
242
|
+
refractory : Optional[Union[jax.Array]], optional
|
243
|
+
A boolean array indicating which parts of the target are in a refractory state
|
244
|
+
and should not be updated. Should be the same length as the target.
|
245
|
+
|
246
|
+
Returns
|
247
|
+
-------
|
248
|
+
None
|
249
|
+
The function updates the target state in place with the generated Poisson input.
|
222
250
|
"""
|
223
251
|
freq = maybe_state(freq)
|
224
252
|
weight = maybe_state(weight)
|
@@ -291,7 +319,7 @@ def poisson_input(
|
|
291
319
|
# )
|
292
320
|
|
293
321
|
# update target variable
|
294
|
-
|
322
|
+
data = jax.tree.map(
|
295
323
|
lambda tar, x: tar + x * weight,
|
296
324
|
target.value,
|
297
325
|
inp,
|
@@ -358,9 +386,19 @@ def poisson_input(
|
|
358
386
|
# )
|
359
387
|
|
360
388
|
# update target variable
|
361
|
-
|
389
|
+
data = jax.tree.map(
|
362
390
|
lambda x, tar: tar.at[indices].add(x * weight),
|
363
391
|
inp,
|
364
392
|
tar_val,
|
365
393
|
is_leaf=u.math.is_quantity
|
366
394
|
)
|
395
|
+
|
396
|
+
if refractory is not None:
|
397
|
+
target.value = jax.tree.map(
|
398
|
+
lambda x, tar: u.math.where(refractory, tar, x),
|
399
|
+
data,
|
400
|
+
tar_val,
|
401
|
+
is_leaf=u.math.is_quantity
|
402
|
+
)
|
403
|
+
else:
|
404
|
+
target.value = data
|
@@ -193,16 +193,18 @@ class _Conv(_BaseConv):
|
|
193
193
|
name: str = None,
|
194
194
|
param_type: type = ParamState,
|
195
195
|
):
|
196
|
-
super().__init__(
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
196
|
+
super().__init__(
|
197
|
+
in_size=in_size,
|
198
|
+
out_channels=out_channels,
|
199
|
+
kernel_size=kernel_size,
|
200
|
+
stride=stride,
|
201
|
+
padding=padding,
|
202
|
+
lhs_dilation=lhs_dilation,
|
203
|
+
rhs_dilation=rhs_dilation,
|
204
|
+
groups=groups,
|
205
|
+
w_mask=w_mask,
|
206
|
+
name=name
|
207
|
+
)
|
206
208
|
|
207
209
|
self.w_initializer = w_init
|
208
210
|
self.b_initializer = b_init
|
{brainstate-0.1.0.post20250209.dist-info → brainstate-0.1.0.post20250211.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.
|
3
|
+
Version: 0.1.0.post20250211
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -1,5 +1,5 @@
|
|
1
1
|
brainstate/__init__.py,sha256=AkZyyFkn4fB8g2aT6Rc2MO1xICPpUZuDtdze-eUQNc0,1496
|
2
|
-
brainstate/_state.py,sha256=
|
2
|
+
brainstate/_state.py,sha256=TrlgaX_Hu3aSdkOUqw88HZMKUKM_PgnE9YGp595nfi0,27761
|
3
3
|
brainstate/_state_test.py,sha256=rJUFRSXEqrrl4qANRewY9mnDlzSbtHwBIGeZ0ku-8Dg,1650
|
4
4
|
brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
|
5
5
|
brainstate/environ.py,sha256=PZnVFWPioUBuWmwCO8wwCKrHQfP3BR-5lYPRl5i5GDA,17698
|
@@ -24,17 +24,17 @@ brainstate/compile/_conditions.py,sha256=gApsHKGQrf1QBjoKXDVL7VsoeJ2zFtSc-hFz9nb
|
|
24
24
|
brainstate/compile/_conditions_test.py,sha256=s9LF6h9LvigvgxUIugTqvgCHBIU8TXS1Ar1OlIxXfrw,8389
|
25
25
|
brainstate/compile/_error_if.py,sha256=TFvhqITKkRO9m30GdlUP4eEjJvLWQUhjkujXO9zvrWs,2689
|
26
26
|
brainstate/compile/_error_if_test.py,sha256=OdJG483IIdOrCHxtHd49OHfOxCSnSkk7GdAUOzSt8bE,2044
|
27
|
-
brainstate/compile/_jit.py,sha256
|
27
|
+
brainstate/compile/_jit.py,sha256=-Y8fyy8gc7qQT2ti4-N-74hjP_6C-D8YC5h-1unEKuI,13910
|
28
28
|
brainstate/compile/_jit_test.py,sha256=zD7kck9SQJGmUDolh9P4luKwQ21fBGje1Z4STTEXIuA,4135
|
29
29
|
brainstate/compile/_loop_collect_return.py,sha256=TrKBZhtQecTtuiVz_HOeyepde-znzjlyk0Te53-AvOE,23492
|
30
30
|
brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
|
31
31
|
brainstate/compile/_loop_no_collection.py,sha256=qto2__Zt2PJntkjB9AXEgraGLvNUJS483BhCXjJyqv0,7495
|
32
32
|
brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
|
33
|
-
brainstate/compile/_make_jaxpr.py,sha256=
|
33
|
+
brainstate/compile/_make_jaxpr.py,sha256=Rr36U0s8qow1A4KJYXkALX10Rm2pkSYF2j_1eiSuSGI,33292
|
34
34
|
brainstate/compile/_make_jaxpr_test.py,sha256=3gwdiutn_PJyiweu3oPEXumxEVHKaE2xDGvkwZy2GEo,4367
|
35
35
|
brainstate/compile/_progress_bar.py,sha256=5pCMCEmbTO5XmKtzRUJGA178tuBznWKuh9Kw00wAL1I,7524
|
36
36
|
brainstate/compile/_unvmap.py,sha256=CJA6D9lUcBfvdLrpFVvC2AdTJqe9uY0Ht6PltQJyr4U,4228
|
37
|
-
brainstate/compile/_util.py,sha256=
|
37
|
+
brainstate/compile/_util.py,sha256=iKk51BHAerBFj2BTxPNdjsk3MZQiXenzpCr7Ys0iYWg,6299
|
38
38
|
brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
|
39
39
|
brainstate/functional/_activations.py,sha256=S0Ok7sq5FTbmJWSejpOCHo1jpKX0gYOLy_TO2IUXM8s,21726
|
40
40
|
brainstate/functional/_activations_test.py,sha256=T___RlSrIfXwlkw8dg5A9EZMTZGDzv3a2evUwq_nYFg,13034
|
@@ -44,7 +44,7 @@ brainstate/functional/_spikes.py,sha256=QY-2ayJkgkGELcq-bftPEaf_hJptVf_SP3fY36Qv
|
|
44
44
|
brainstate/graph/__init__.py,sha256=noo4TjBg6iEhjjwk0sAGUhR7Ge-z8Vnc2rLYUvnqttw,1295
|
45
45
|
brainstate/graph/_graph_node.py,sha256=XwzOuaZG9x4eZknQjzJoTnnYAy7wcKD5Vox1VkYr8GM,8345
|
46
46
|
brainstate/graph/_graph_node_test.py,sha256=BFGfdzZFDHI0XK7hHotSVWKt3em1taGvn8FHF9NCXx8,2702
|
47
|
-
brainstate/graph/_graph_operation.py,sha256=
|
47
|
+
brainstate/graph/_graph_operation.py,sha256=UtBNP7hvxa-5i99LQJStXbFhUbl3icdfTq1oF4MeH1g,64106
|
48
48
|
brainstate/graph/_graph_operation_test.py,sha256=zjvpKjQAFWtw8YZuqOk_jmlZNb_-E8oPyNx57dyc8jI,18556
|
49
49
|
brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
|
50
50
|
brainstate/init/_base.py,sha256=B_NLS9aKNrvuj5NAlSgBbQTVev7IRvzcx8vH0J-Gq2w,1671
|
@@ -65,7 +65,7 @@ brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=cTbIn41EPYG0h3ICzKBXxpgB6wwA2
|
|
65
65
|
brainstate/nn/_dyn_impl/_dynamics_neuron_test.py,sha256=Tfzrzu7udGrLJGnqItiLWe5WT0dgduvYOgzGCnaPJQg,6317
|
66
66
|
brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=MsbPyaiDyjungyzuK2b3exRGaMpZgmsmmNHNLjgxQKw,15269
|
67
67
|
brainstate/nn/_dyn_impl/_dynamics_synapse_test.py,sha256=t5i-HV0ii9sUNzWTEv04o26QVtQ-mCdMJcFq2MD755A,4981
|
68
|
-
brainstate/nn/_dyn_impl/_inputs.py,sha256=
|
68
|
+
brainstate/nn/_dyn_impl/_inputs.py,sha256=x4bcp7fo5SI5TC4TmyARngK-PE7OvGHprJ17Levs0ls,12937
|
69
69
|
brainstate/nn/_dyn_impl/_projection_alignpost.py,sha256=PNC1Tzx_SF2DHAHeJCufXzO_Q4qLoBpWABI45B3GRuc,876
|
70
70
|
brainstate/nn/_dyn_impl/_rate_rnns.py,sha256=dz_yT_6hJVhKulcjIARbGtmMzZqISws96CtBc6o5GOo,14768
|
71
71
|
brainstate/nn/_dyn_impl/_rate_rnns_test.py,sha256=gNgtr-a4ZiU1XF9wFG1HiJ9fLosfWchVR9Zn1x39xt4,2452
|
@@ -84,7 +84,7 @@ brainstate/nn/_elementwise/_dropout_test.py,sha256=k6aB5v8RYMoV5w8UV9UNSFhaQTV7w
|
|
84
84
|
brainstate/nn/_elementwise/_elementwise.py,sha256=om-KpwDTk5yFG5KBYXXHquRLV7s28_FJjk-omvyMyvQ,33342
|
85
85
|
brainstate/nn/_elementwise/_elementwise_test.py,sha256=SZI9jB39sZ5SO1dpWGW-PhodthwN0GU9FY1nqf2fWcs,5341
|
86
86
|
brainstate/nn/_interaction/__init__.py,sha256=TTY_SeNrdx4VnUSw6vdyl02OHdS9Qs15cWBp6kjsyNQ,1289
|
87
|
-
brainstate/nn/_interaction/_conv.py,sha256=
|
87
|
+
brainstate/nn/_interaction/_conv.py,sha256=eKhABWtG3QlOy7TPY9yoQjP3liBh9bb8X5Wns3_YUUQ,18499
|
88
88
|
brainstate/nn/_interaction/_conv_test.py,sha256=fHXRFYnDghFiKre63RqMwIE_gbPKdK34UPhKOz-J3qU,8695
|
89
89
|
brainstate/nn/_interaction/_embedding.py,sha256=iK0I1ExKWFa_QzV9UDGj32Ljsmdr1g_LlAtMcusebxU,2187
|
90
90
|
brainstate/nn/_interaction/_linear.py,sha256=EnkOk1oE79rvRIjU6HBllxUpVOEcQQCj4vtavo9AJjI,14767
|
@@ -117,8 +117,8 @@ brainstate/util/_others.py,sha256=jsPZwP-v_5HRV-LB5F0NUsiqr04y8bmGIsu_JMyVcbQ,14
|
|
117
117
|
brainstate/util/_pretty_repr.py,sha256=vNwRlj4sI4QJ_koyIs7eKdUMeB_QWwzRYsE8PpAWN3g,5833
|
118
118
|
brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
|
119
119
|
brainstate/util/_struct.py,sha256=KMMHcshOM20gYhSahNzWLxsTt-Rt3AeX3Uz26-rP9vI,17619
|
120
|
-
brainstate-0.1.0.
|
121
|
-
brainstate-0.1.0.
|
122
|
-
brainstate-0.1.0.
|
123
|
-
brainstate-0.1.0.
|
124
|
-
brainstate-0.1.0.
|
120
|
+
brainstate-0.1.0.post20250211.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
121
|
+
brainstate-0.1.0.post20250211.dist-info/METADATA,sha256=werv_oEsECW5xLgvO__Yjth5vLfz-YARh442Q6E6FIk,3585
|
122
|
+
brainstate-0.1.0.post20250211.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
123
|
+
brainstate-0.1.0.post20250211.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
124
|
+
brainstate-0.1.0.post20250211.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{brainstate-0.1.0.post20250209.dist-info → brainstate-0.1.0.post20250211.dist-info}/top_level.txt
RENAMED
File without changes
|