brainstate 0.1.0.post20250209__py2.py3-none-any.whl → 0.1.0.post20250211__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
brainstate/_state.py CHANGED
@@ -263,9 +263,8 @@ class State(Generic[A], PrettyReprTree):
263
263
  """
264
264
  The data and its value.
265
265
  """
266
- self.check_if_deleted()
267
266
  record_state_value_read(self)
268
- return self._value
267
+ return self._read_value()
269
268
 
270
269
  @value.setter
271
270
  def value(self, v) -> None:
@@ -275,7 +274,14 @@ class State(Generic[A], PrettyReprTree):
275
274
  Args:
276
275
  v: The value.
277
276
  """
278
- self.write_value(v)
277
+ # NOTE: the following order is important
278
+
279
+ if isinstance(v, State): # value checking
280
+ raise ValueError('Cannot set value to a State, ' 'use `copy_from` method instead')
281
+ self._check_value_tree(v) # check the tree structure
282
+ record_state_value_write(self) # record the value by the stack (>= level)
283
+ self._been_writen = True # set the flag
284
+ self._write_value(v) # write the value
279
285
 
280
286
  @property
281
287
  def stack_level(self):
@@ -297,17 +303,18 @@ class State(Generic[A], PrettyReprTree):
297
303
  """
298
304
  self._level = level
299
305
 
300
- def write_value(self, v) -> None:
301
- # value checking
302
- if isinstance(v, State):
303
- raise ValueError('Cannot set value to a State, ' 'use `copy_from` method instead')
304
- self._check_value_tree(v)
305
- # write the value by the stack (>= level)
306
- record_state_value_write(self)
307
- # set the value
306
+ def _read_value(self) -> PyTree[ArrayLike]:
307
+ """
308
+ The interface to customize the value reading.
309
+ """
310
+ self.check_if_deleted()
311
+ return self._value
312
+
313
+ def _write_value(self, v) -> None:
314
+ """
315
+ The interface to customize the value writing.
316
+ """
308
317
  self._value = v
309
- # set flag
310
- self._been_writen = True
311
318
 
312
319
  def restore_value(self, v) -> None:
313
320
  """
@@ -813,11 +820,7 @@ class TreefyState(Generic[A], PrettyReprTree):
813
820
  return 'value', v
814
821
 
815
822
  if k == '_name':
816
- if self.name is None:
817
- return None, None
818
- else:
819
- return 'name', v
820
-
823
+ return (None, None) if v is None else ('name', v)
821
824
  return k, v
822
825
 
823
826
  def replace(self, value: B) -> TreefyState[B]:
@@ -82,7 +82,6 @@ def _get_jitted_fun(
82
82
  return fun.fun(*args, **params)
83
83
 
84
84
  # compile the function and get the state trace
85
- # print('Compiling ...')
86
85
  state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
87
86
  read_state_vals = state_trace.get_read_state_values(True)
88
87
 
@@ -499,7 +499,7 @@ class StatefulFunction(object):
499
499
  state_vals, out = self.jaxpr_call([st.value for st in state_trace.states], *args, **kwargs)
500
500
  for st, written, val in zip(state_trace.states, state_trace.been_writen, state_vals):
501
501
  if written:
502
- st.write_value(val)
502
+ st.value = val
503
503
  else:
504
504
  st.restore_value(val)
505
505
  return out
@@ -31,7 +31,7 @@ def write_back_state_values(
31
31
  assert len(state_trace.states) == len(state_trace.been_writen) == len(read_state_vals) == len(write_state_vals)
32
32
  for st, write, val_r, val_w in zip(state_trace.states, state_trace.been_writen, read_state_vals, write_state_vals):
33
33
  if write:
34
- st.write_value(val_w)
34
+ st.value = val_w
35
35
  else:
36
36
  st.restore_value(val_r)
37
37
 
@@ -609,7 +609,7 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
609
609
  variable.update_from_ref(value)
610
610
  elif isinstance(value, State):
611
611
  if value._been_writen:
612
- variable.write_value(value.value)
612
+ variable.value = value.value
613
613
  else:
614
614
  variable.restore_value(value.value)
615
615
  else:
@@ -216,9 +216,37 @@ def poisson_input(
216
216
  weight: u.Quantity,
217
217
  target: State,
218
218
  indices: Optional[Union[np.ndarray, jax.Array]] = None,
219
+ refractory: Optional[Union[jax.Array]] = None,
219
220
  ):
220
221
  """
221
- Poisson Input to the given :py:class:`brainstate.State`.
222
+ Generates Poisson-distributed input spikes to a target state variable.
223
+
224
+ This function simulates Poisson input to a given state, updating the target
225
+ variable with generated spikes based on the specified frequency, number of inputs,
226
+ and synaptic weight. The input can be applied to specific indices of the target
227
+ or to the entire target if indices are not provided.
228
+
229
+ Parameters
230
+ ----------
231
+ freq : u.Quantity[u.Hz]
232
+ The frequency of the Poisson input in Hertz.
233
+ num_input : int
234
+ The number of input channels or neurons generating the Poisson spikes.
235
+ weight : u.Quantity
236
+ The synaptic weight applied to each spike.
237
+ target : State
238
+ The target state variable to which the Poisson input is applied.
239
+ indices : Optional[Union[np.ndarray, jax.Array]], optional
240
+ Specific indices of the target to apply the input. If None, the input is applied
241
+ to the entire target.
242
+ refractory : Optional[Union[jax.Array]], optional
243
+ A boolean array indicating which parts of the target are in a refractory state
244
+ and should not be updated. Should be the same length as the target.
245
+
246
+ Returns
247
+ -------
248
+ None
249
+ The function updates the target state in place with the generated Poisson input.
222
250
  """
223
251
  freq = maybe_state(freq)
224
252
  weight = maybe_state(weight)
@@ -291,7 +319,7 @@ def poisson_input(
291
319
  # )
292
320
 
293
321
  # update target variable
294
- target.value = jax.tree.map(
322
+ data = jax.tree.map(
295
323
  lambda tar, x: tar + x * weight,
296
324
  target.value,
297
325
  inp,
@@ -358,9 +386,19 @@ def poisson_input(
358
386
  # )
359
387
 
360
388
  # update target variable
361
- target.value = jax.tree.map(
389
+ data = jax.tree.map(
362
390
  lambda x, tar: tar.at[indices].add(x * weight),
363
391
  inp,
364
392
  tar_val,
365
393
  is_leaf=u.math.is_quantity
366
394
  )
395
+
396
+ if refractory is not None:
397
+ target.value = jax.tree.map(
398
+ lambda x, tar: u.math.where(refractory, tar, x),
399
+ data,
400
+ tar_val,
401
+ is_leaf=u.math.is_quantity
402
+ )
403
+ else:
404
+ target.value = data
@@ -193,16 +193,18 @@ class _Conv(_BaseConv):
193
193
  name: str = None,
194
194
  param_type: type = ParamState,
195
195
  ):
196
- super().__init__(in_size=in_size,
197
- out_channels=out_channels,
198
- kernel_size=kernel_size,
199
- stride=stride,
200
- padding=padding,
201
- lhs_dilation=lhs_dilation,
202
- rhs_dilation=rhs_dilation,
203
- groups=groups,
204
- w_mask=w_mask,
205
- name=name)
196
+ super().__init__(
197
+ in_size=in_size,
198
+ out_channels=out_channels,
199
+ kernel_size=kernel_size,
200
+ stride=stride,
201
+ padding=padding,
202
+ lhs_dilation=lhs_dilation,
203
+ rhs_dilation=rhs_dilation,
204
+ groups=groups,
205
+ w_mask=w_mask,
206
+ name=name
207
+ )
206
208
 
207
209
  self.w_initializer = w_init
208
210
  self.b_initializer = b_init
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.1.0.post20250209
3
+ Version: 0.1.0.post20250211
4
4
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
5
  Home-page: https://github.com/chaobrain/brainstate
6
6
  Author: BrainState Developers
@@ -1,5 +1,5 @@
1
1
  brainstate/__init__.py,sha256=AkZyyFkn4fB8g2aT6Rc2MO1xICPpUZuDtdze-eUQNc0,1496
2
- brainstate/_state.py,sha256=W1Q_RAL01rUSLZuOARMuX9I-26tBuIl_VzNWAziz6A8,27518
2
+ brainstate/_state.py,sha256=TrlgaX_Hu3aSdkOUqw88HZMKUKM_PgnE9YGp595nfi0,27761
3
3
  brainstate/_state_test.py,sha256=rJUFRSXEqrrl4qANRewY9mnDlzSbtHwBIGeZ0ku-8Dg,1650
4
4
  brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
5
5
  brainstate/environ.py,sha256=PZnVFWPioUBuWmwCO8wwCKrHQfP3BR-5lYPRl5i5GDA,17698
@@ -24,17 +24,17 @@ brainstate/compile/_conditions.py,sha256=gApsHKGQrf1QBjoKXDVL7VsoeJ2zFtSc-hFz9nb
24
24
  brainstate/compile/_conditions_test.py,sha256=s9LF6h9LvigvgxUIugTqvgCHBIU8TXS1Ar1OlIxXfrw,8389
25
25
  brainstate/compile/_error_if.py,sha256=TFvhqITKkRO9m30GdlUP4eEjJvLWQUhjkujXO9zvrWs,2689
26
26
  brainstate/compile/_error_if_test.py,sha256=OdJG483IIdOrCHxtHd49OHfOxCSnSkk7GdAUOzSt8bE,2044
27
- brainstate/compile/_jit.py,sha256=itAWENKfJvnlaWl_uSy8lHTK8K1in89F_ZXXwp-EGRM,13944
27
+ brainstate/compile/_jit.py,sha256=-Y8fyy8gc7qQT2ti4-N-74hjP_6C-D8YC5h-1unEKuI,13910
28
28
  brainstate/compile/_jit_test.py,sha256=zD7kck9SQJGmUDolh9P4luKwQ21fBGje1Z4STTEXIuA,4135
29
29
  brainstate/compile/_loop_collect_return.py,sha256=TrKBZhtQecTtuiVz_HOeyepde-znzjlyk0Te53-AvOE,23492
30
30
  brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
31
31
  brainstate/compile/_loop_no_collection.py,sha256=qto2__Zt2PJntkjB9AXEgraGLvNUJS483BhCXjJyqv0,7495
32
32
  brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
33
- brainstate/compile/_make_jaxpr.py,sha256=MuAa9LjXi29DjYgDUrK0WaomkjbhHZk9mWW04XGcV98,33297
33
+ brainstate/compile/_make_jaxpr.py,sha256=Rr36U0s8qow1A4KJYXkALX10Rm2pkSYF2j_1eiSuSGI,33292
34
34
  brainstate/compile/_make_jaxpr_test.py,sha256=3gwdiutn_PJyiweu3oPEXumxEVHKaE2xDGvkwZy2GEo,4367
35
35
  brainstate/compile/_progress_bar.py,sha256=5pCMCEmbTO5XmKtzRUJGA178tuBznWKuh9Kw00wAL1I,7524
36
36
  brainstate/compile/_unvmap.py,sha256=CJA6D9lUcBfvdLrpFVvC2AdTJqe9uY0Ht6PltQJyr4U,4228
37
- brainstate/compile/_util.py,sha256=aCvkTV--g4NsqcodTdBAISt4EwgezCbKzNUV58n-Q_Y,6304
37
+ brainstate/compile/_util.py,sha256=iKk51BHAerBFj2BTxPNdjsk3MZQiXenzpCr7Ys0iYWg,6299
38
38
  brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
39
39
  brainstate/functional/_activations.py,sha256=S0Ok7sq5FTbmJWSejpOCHo1jpKX0gYOLy_TO2IUXM8s,21726
40
40
  brainstate/functional/_activations_test.py,sha256=T___RlSrIfXwlkw8dg5A9EZMTZGDzv3a2evUwq_nYFg,13034
@@ -44,7 +44,7 @@ brainstate/functional/_spikes.py,sha256=QY-2ayJkgkGELcq-bftPEaf_hJptVf_SP3fY36Qv
44
44
  brainstate/graph/__init__.py,sha256=noo4TjBg6iEhjjwk0sAGUhR7Ge-z8Vnc2rLYUvnqttw,1295
45
45
  brainstate/graph/_graph_node.py,sha256=XwzOuaZG9x4eZknQjzJoTnnYAy7wcKD5Vox1VkYr8GM,8345
46
46
  brainstate/graph/_graph_node_test.py,sha256=BFGfdzZFDHI0XK7hHotSVWKt3em1taGvn8FHF9NCXx8,2702
47
- brainstate/graph/_graph_operation.py,sha256=cIwGo3ICgtce2fmdn917r81evMFjJIKeW9doaQK4DD8,64111
47
+ brainstate/graph/_graph_operation.py,sha256=UtBNP7hvxa-5i99LQJStXbFhUbl3icdfTq1oF4MeH1g,64106
48
48
  brainstate/graph/_graph_operation_test.py,sha256=zjvpKjQAFWtw8YZuqOk_jmlZNb_-E8oPyNx57dyc8jI,18556
49
49
  brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
50
50
  brainstate/init/_base.py,sha256=B_NLS9aKNrvuj5NAlSgBbQTVev7IRvzcx8vH0J-Gq2w,1671
@@ -65,7 +65,7 @@ brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=cTbIn41EPYG0h3ICzKBXxpgB6wwA2
65
65
  brainstate/nn/_dyn_impl/_dynamics_neuron_test.py,sha256=Tfzrzu7udGrLJGnqItiLWe5WT0dgduvYOgzGCnaPJQg,6317
66
66
  brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=MsbPyaiDyjungyzuK2b3exRGaMpZgmsmmNHNLjgxQKw,15269
67
67
  brainstate/nn/_dyn_impl/_dynamics_synapse_test.py,sha256=t5i-HV0ii9sUNzWTEv04o26QVtQ-mCdMJcFq2MD755A,4981
68
- brainstate/nn/_dyn_impl/_inputs.py,sha256=QOUpAb2YJOE78uAvIS8Ep6MFcQHV-V6uRwmYvk5p9bk,11385
68
+ brainstate/nn/_dyn_impl/_inputs.py,sha256=x4bcp7fo5SI5TC4TmyARngK-PE7OvGHprJ17Levs0ls,12937
69
69
  brainstate/nn/_dyn_impl/_projection_alignpost.py,sha256=PNC1Tzx_SF2DHAHeJCufXzO_Q4qLoBpWABI45B3GRuc,876
70
70
  brainstate/nn/_dyn_impl/_rate_rnns.py,sha256=dz_yT_6hJVhKulcjIARbGtmMzZqISws96CtBc6o5GOo,14768
71
71
  brainstate/nn/_dyn_impl/_rate_rnns_test.py,sha256=gNgtr-a4ZiU1XF9wFG1HiJ9fLosfWchVR9Zn1x39xt4,2452
@@ -84,7 +84,7 @@ brainstate/nn/_elementwise/_dropout_test.py,sha256=k6aB5v8RYMoV5w8UV9UNSFhaQTV7w
84
84
  brainstate/nn/_elementwise/_elementwise.py,sha256=om-KpwDTk5yFG5KBYXXHquRLV7s28_FJjk-omvyMyvQ,33342
85
85
  brainstate/nn/_elementwise/_elementwise_test.py,sha256=SZI9jB39sZ5SO1dpWGW-PhodthwN0GU9FY1nqf2fWcs,5341
86
86
  brainstate/nn/_interaction/__init__.py,sha256=TTY_SeNrdx4VnUSw6vdyl02OHdS9Qs15cWBp6kjsyNQ,1289
87
- brainstate/nn/_interaction/_conv.py,sha256=lwyxTVsJVPiKlZcgB6iqE64aX7AOJzplDSj4y6-m18o,18592
87
+ brainstate/nn/_interaction/_conv.py,sha256=eKhABWtG3QlOy7TPY9yoQjP3liBh9bb8X5Wns3_YUUQ,18499
88
88
  brainstate/nn/_interaction/_conv_test.py,sha256=fHXRFYnDghFiKre63RqMwIE_gbPKdK34UPhKOz-J3qU,8695
89
89
  brainstate/nn/_interaction/_embedding.py,sha256=iK0I1ExKWFa_QzV9UDGj32Ljsmdr1g_LlAtMcusebxU,2187
90
90
  brainstate/nn/_interaction/_linear.py,sha256=EnkOk1oE79rvRIjU6HBllxUpVOEcQQCj4vtavo9AJjI,14767
@@ -117,8 +117,8 @@ brainstate/util/_others.py,sha256=jsPZwP-v_5HRV-LB5F0NUsiqr04y8bmGIsu_JMyVcbQ,14
117
117
  brainstate/util/_pretty_repr.py,sha256=vNwRlj4sI4QJ_koyIs7eKdUMeB_QWwzRYsE8PpAWN3g,5833
118
118
  brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
119
119
  brainstate/util/_struct.py,sha256=KMMHcshOM20gYhSahNzWLxsTt-Rt3AeX3Uz26-rP9vI,17619
120
- brainstate-0.1.0.post20250209.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
121
- brainstate-0.1.0.post20250209.dist-info/METADATA,sha256=vc9kKmrq5JM9Os6brL4zecy55nEpd9ASK9GZNJBQV9g,3585
122
- brainstate-0.1.0.post20250209.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
123
- brainstate-0.1.0.post20250209.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
124
- brainstate-0.1.0.post20250209.dist-info/RECORD,,
120
+ brainstate-0.1.0.post20250211.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
121
+ brainstate-0.1.0.post20250211.dist-info/METADATA,sha256=werv_oEsECW5xLgvO__Yjth5vLfz-YARh442Q6E6FIk,3585
122
+ brainstate-0.1.0.post20250211.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
123
+ brainstate-0.1.0.post20250211.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
124
+ brainstate-0.1.0.post20250211.dist-info/RECORD,,