brainstate 0.1.0.post20250105__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as published to the supported registries. It is provided for informational purposes only and reflects the changes between those versions.
Files changed (46)
  1. brainstate/_state.py +77 -44
  2. brainstate/_state_test.py +0 -17
  3. brainstate/augment/_eval_shape.py +9 -10
  4. brainstate/augment/_eval_shape_test.py +1 -1
  5. brainstate/augment/_mapping.py +265 -277
  6. brainstate/augment/_mapping_test.py +147 -175
  7. brainstate/compile/_ad_checkpoint.py +6 -4
  8. brainstate/compile/_jit.py +37 -28
  9. brainstate/compile/_loop_collect_return.py +6 -3
  10. brainstate/compile/_loop_no_collection.py +2 -0
  11. brainstate/compile/_make_jaxpr.py +7 -3
  12. brainstate/compile/_progress_bar.py +68 -40
  13. brainstate/compile/_unvmap.py +6 -3
  14. brainstate/event/__init__.py +0 -2
  15. brainstate/event/_csr.py +266 -23
  16. brainstate/event/_csr_test.py +187 -0
  17. brainstate/event/_xla_custom_op.py +7 -3
  18. brainstate/graph/__init__.py +8 -12
  19. brainstate/graph/_graph_node.py +1 -23
  20. brainstate/graph/_graph_operation.py +1 -1
  21. brainstate/graph/_graph_operation_test.py +0 -159
  22. brainstate/nn/_dyn_impl/_inputs.py +124 -39
  23. brainstate/nn/_interaction/_conv.py +4 -2
  24. brainstate/nn/_interaction/_linear.py +84 -10
  25. brainstate/random/_rand_funs.py +9 -2
  26. brainstate/random/_rand_seed.py +12 -2
  27. brainstate/random/_rand_state.py +50 -179
  28. brainstate/surrogate.py +5 -1
  29. brainstate/util/__init__.py +0 -4
  30. brainstate/util/_caller.py +1 -1
  31. brainstate/util/_dict.py +4 -1
  32. brainstate/util/_filter.py +1 -1
  33. brainstate/util/_pretty_repr.py +1 -1
  34. brainstate/util/_struct.py +1 -1
  35. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
  36. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +40 -46
  37. brainstate/event/_csr_mv_test.py +0 -118
  38. brainstate/graph/_graph_context.py +0 -443
  39. brainstate/graph/_graph_context_test.py +0 -65
  40. brainstate/graph/_graph_convert.py +0 -246
  41. brainstate/util/_tracers.py +0 -68
  42. brainstate/util/_visualization.py +0 -47
  43. /brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
  44. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
  45. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
  46. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
--- a/brainstate/nn/_dyn_impl/_inputs.py
+++ b/brainstate/nn/_dyn_impl/_inputs.py
@@ -22,8 +22,8 @@ import numpy as np
 
 from brainstate import environ, init, random
 from brainstate._state import ShortTermState
-from brainstate._state import State
-from brainstate.compile import while_loop, cond
+from brainstate._state import State, maybe_state
+from brainstate.compile import while_loop
 from brainstate.nn._dynamics._dynamics_base import Dynamics, Prefetch
 from brainstate.nn._module import Module
 from brainstate.typing import ArrayLike, Size, DTypeLike
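The import swap above removes `cond` (no longer used by `PoissonInput.update`) and brings in `maybe_state`, which lets `poisson_input` accept either plain values or `State` wrappers for `freq` and `weight`. A minimal sketch of the assumed semantics (the real helper lives in `brainstate._state`; this is not the library source):

```python
from brainstate import State

def maybe_state_sketch(x):
    # Assumed behaviour of `maybe_state`: unwrap a State to its current
    # value, pass any other object through unchanged.
    return x.value if isinstance(x, State) else x
```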
@@ -198,55 +198,97 @@ class PoissonInput(Module):
         self.weight = weight
 
     def update(self):
-        p = self.freq * environ.get_dt()
-        a = self.num_input * p
-        b = self.num_input * (1 - p)
-
-        target = self.target()
         target_state = getattr(self.target.module, self.target.item)
 
         # generate Poisson input
-        inp = cond(
-            u.math.logical_and(a > 5, b > 5),
-            lambda: random.normal(a, b * p, self.indices.shape),
-            lambda: random.binomial(self.num_input, p, self.indices.shape).astype(float)
+        poisson_input(
+            self.freq,
+            self.num_input,
+            self.weight,
+            target_state,
+            self.indices,
         )
 
-        # update target variable
-        target_state.value = target.at[self.indices].add(inp * self.weight)
-
 
 def poisson_input(
-    freq: ArrayLike,
+    freq: u.Quantity[u.Hz],
     num_input: int,
-    weight: ArrayLike,
+    weight: u.Quantity,
     target: State,
     indices: Optional[Union[np.ndarray, jax.Array]] = None,
 ):
     """
     Poisson Input to the given :py:class:`brainstate.State`.
     """
+    freq = maybe_state(freq)
+    weight = maybe_state(weight)
+
     assert isinstance(target, State), 'The target must be a State.'
-    p = freq * environ.get_dt()
+    p = (freq * environ.get_dt()).to_decimal()
     a = num_input * p
     b = num_input * (1 - p)
     tar_val = target.value
+    cond = u.math.logical_and(a > 5, b > 5)
+
     if indices is None:
         # generate Poisson input
-        inp = cond(
-            u.math.logical_and(a > 5, b > 5),
-            lambda: jax.tree.map(
-                lambda tar: random.normal(a, b * p, tar.shape),
-                tar_val,
-                is_leaf=u.math.is_quantity
+        branch1 = jax.tree.map(
+            lambda tar: random.normal(
+                a,
+                b * p,
+                tar.shape,
+                dtype=tar.dtype
             ),
-            lambda: jax.tree.map(
-                lambda tar: random.binomial(num_input, p, tar.shape).astype(float),
-                tar_val,
-                is_leaf=u.math.is_quantity
-            )
+            tar_val,
+            is_leaf=u.math.is_quantity
+        )
+        branch2 = jax.tree.map(
+            lambda tar: random.binomial(
+                num_input,
+                p,
+                tar.shape,
+                check_valid=False,
+                dtype=tar.dtype
+            ),
+            tar_val,
+            is_leaf=u.math.is_quantity,
+        )
+
+        inp = jax.tree.map(
+            lambda b1, b2: u.math.where(cond, b1, b2),
+            branch1,
+            branch2,
+            is_leaf=u.math.is_quantity,
         )
 
+        # inp = jax.lax.cond(
+        #     cond,
+        #     lambda rand_key: jax.tree.map(
+        #         lambda tar: random.normal(
+        #             a,
+        #             b * p,
+        #             tar.shape,
+        #             key=rand_key,
+        #             dtype=tar.dtype
+        #         ),
+        #         tar_val,
+        #         is_leaf=u.math.is_quantity
+        #     ),
+        #     lambda rand_key: jax.tree.map(
+        #         lambda tar: random.binomial(
+        #             num_input,
+        #             p,
+        #             tar.shape,
+        #             key=rand_key,
+        #             check_valid=False,
+        #             dtype=tar.dtype
+        #         ),
+        #         tar_val,
+        #         is_leaf=u.math.is_quantity,
+        #     ),
+        #     random.split_key()
+        # )
+
         # update target variable
         target.value = jax.tree.map(
             lambda x: x * weight,
@@ -256,19 +298,62 @@ def poisson_input(
 
     else:
         # generate Poisson input
-        inp = cond(
-            u.math.logical_and(a > 5, b > 5),
-            lambda: jax.tree.map(
-                lambda tar: random.normal(a, b * p, tar[indices].shape),
-                tar_val,
-                is_leaf=u.math.is_quantity
+        branch1 = jax.tree.map(
+            lambda tar: random.normal(
+                a,
+                b * p,
+                tar[indices].shape,
+                dtype=tar.dtype
             ),
-            lambda: jax.tree.map(
-                lambda tar: random.binomial(num_input, p, tar[indices].shape).astype(float),
-                tar_val,
-                is_leaf=u.math.is_quantity
-            )
+            tar_val,
+            is_leaf=u.math.is_quantity
         )
+        branch2 = jax.tree.map(
+            lambda tar: random.binomial(
+                num_input,
+                p,
+                tar[indices].shape,
+                # check_valid=False,
+                dtype=tar.dtype
+            ),
+            tar_val,
+            is_leaf=u.math.is_quantity
+        )
+
+        inp = jax.tree.map(
+            lambda b1, b2: u.math.where(cond, b1, b2),
+            branch1,
+            branch2,
+            is_leaf=u.math.is_quantity,
+        )
+
+        # inp = jax.lax.cond(
+        #     cond,
+        #     lambda rand_key: jax.tree.map(
+        #         lambda tar: random.normal(
+        #             a,
+        #             b * p,
+        #             tar[indices].shape,
+        #             key=rand_key,
+        #             dtype=tar.dtype
+        #         ),
+        #         tar_val,
+        #         is_leaf=u.math.is_quantity
+        #     ),
+        #     lambda rand_key: jax.tree.map(
+        #         lambda tar: random.binomial(
+        #             num_input,
+        #             p,
+        #             tar[indices].shape,
+        #             key=rand_key,
+        #             check_valid=False,
+        #             dtype=tar.dtype
+        #         ),
+        #         tar_val,
+        #         is_leaf=u.math.is_quantity
+        #     ),
+        #     random.split_key()
+        # )
 
         # update target variable
         target.value = jax.tree.map(
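In both branches of `poisson_input`, the switch between the normal approximation (used when `a > 5` and `b > 5`) and the exact binomial draw is now made by computing both candidates and selecting elementwise with `u.math.where`, instead of calling `cond`; a `jax.lax.cond` variant with explicit key passing is kept commented out. A hedged usage sketch of the updated signature (shapes, `dt`, and units are illustrative, and it assumes `poisson_input` is re-exported as `brainstate.nn.poisson_input`):

```python
import brainstate as bst
import brainunit as u
import jax.numpy as jnp

bst.environ.set(dt=0.1 * u.ms)                  # integration step read by environ.get_dt()

v = bst.ShortTermState(jnp.zeros(1000) * u.mV)  # target state, updated in place
bst.nn.poisson_input(
    freq=50. * u.Hz,     # rate of each presynaptic Poisson source
    num_input=200,       # number of independent sources
    weight=0.05 * u.mV,  # contribution of a single spike
    target=v,            # any brainstate.State; indices=None updates all elements
)
```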
--- a/brainstate/nn/_interaction/_conv.py
+++ b/brainstate/nn/_interaction/_conv.py
@@ -191,6 +191,7 @@ class _Conv(_BaseConv):
         b_init: Optional[Union[Callable, ArrayLike]] = None,
         w_mask: Optional[Union[ArrayLike, Callable]] = None,
         name: str = None,
+        param_type: type = ParamState,
     ):
         super().__init__(in_size=in_size,
                          out_channels=out_channels,
@@ -215,7 +216,7 @@ class _Conv(_BaseConv):
             params['bias'] = bias
 
         # The weight operation
-        self.weight = ParamState(params)
+        self.weight = param_type(params)
 
         # Evaluate the output shape
         abstract_y = jax.eval_shape(
@@ -346,6 +347,7 @@ class _ScaledWSConv(_BaseConv):
         b_init: Optional[Union[Callable, ArrayLike]] = None,
         w_mask: Optional[Union[ArrayLike, Callable]] = None,
         name: str = None,
+        param_type: type = ParamState,
     ):
         super().__init__(in_size=in_size,
                          out_channels=out_channels,
@@ -379,7 +381,7 @@ class _ScaledWSConv(_BaseConv):
         self.eps = eps
 
         # The weight operation
-        self.weight = ParamState(params)
+        self.weight = param_type(params)
 
         # Evaluate the output shape
         abstract_y = jax.eval_shape(
--- a/brainstate/nn/_interaction/_linear.py
+++ b/brainstate/nn/_interaction/_linear.py
@@ -34,6 +34,7 @@ __all__ = [
     'SparseLinear',
    'AllToAll',
    'OneToOne',
+    'LoRA',
 ]
 
 
@@ -51,6 +52,7 @@ class Linear(Module):
         b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
         w_mask: Optional[Union[ArrayLike, Callable]] = None,
         name: Optional[str] = None,
+        param_type: type = ParamState,
     ):
         super().__init__(name=name)
 
@@ -67,7 +69,7 @@ class Linear(Module):
         params = dict(weight=init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False))
         if b_init is not None:
             params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
-        self.weight = ParamState(params)
+        self.weight = param_type(params)
 
     def update(self, x):
         params = self.weight.value
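Every layer constructor in `_conv.py` and `_linear.py` now takes a `param_type` argument (defaulting to `ParamState`), so the weight dictionary can be registered under a user-chosen state class. A hypothetical sketch (the subclass name `FrozenParam` is illustrative only, not part of the library):

```python
import brainstate as bst
import jax.numpy as jnp

class FrozenParam(bst.ParamState):
    # illustrative marker subclass, e.g. for filtering or freezing these
    # parameters separately from ordinary ParamState instances
    pass

layer = bst.nn.Linear(3, 4, param_type=FrozenParam)
assert isinstance(layer.weight, FrozenParam)
y = layer(jnp.ones((8, 3)))   # forward pass is unchanged; y.shape == (8, 4)
```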
@@ -93,7 +95,7 @@ class SignedWLinear(Module):
         w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
         w_sign: Optional[ArrayLike] = None,
         name: Optional[str] = None,
-
+        param_type: type = ParamState,
     ):
         super().__init__(name=name)
 
@@ -108,7 +110,7 @@ class SignedWLinear(Module):
 
         # weights
         weight = init.param(w_init, self.in_size + self.out_size, allow_none=False)
-        self.weight = ParamState(weight)
+        self.weight = param_type(weight)
 
     def update(self, x):
         w = self.weight.value
@@ -156,6 +158,7 @@ class ScaledWSLinear(Module):
         ws_gain: bool = True,
         eps: float = 1e-4,
         name: str = None,
+        param_type: type = ParamState,
     ):
         super().__init__(name=name)
 
@@ -179,7 +182,7 @@ class ScaledWSLinear(Module):
         if ws_gain:
             s = params['weight'].shape
             params['gain'] = jnp.ones((1,) * (len(s) - 1) + (s[-1],), dtype=params['weight'].dtype)
-        self.weight = ParamState(params)
+        self.weight = param_type(params)
 
     def update(self, x):
         params = self.weight.value
@@ -211,6 +214,7 @@ class SparseLinear(Module):
         b_init: Optional[Union[Callable, ArrayLike]] = None,
         in_size: Size = None,
         name: Optional[str] = None,
+        param_type: type = ParamState,
     ):
         super().__init__(name=name)
 
@@ -230,7 +234,7 @@ class SparseLinear(Module):
         params = dict(weight=spar_mat.data)
         if b_init is not None:
             params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
-        self.weight = ParamState(params)
+        self.weight = param_type(params)
 
     def update(self, x):
         data = self.weight.value['weight']
@@ -260,6 +264,7 @@ class AllToAll(Module):
         b_init: Optional[Union[Callable, ArrayLike]] = None,
         include_self: bool = True,
         name: Optional[str] = None,
+        param_type: type = ParamState,
     ):
         super().__init__(name=name)
 
@@ -277,7 +282,7 @@ class AllToAll(Module):
         params = dict(weight=weight)
         if b_init is not None:
             params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
-        self.weight = ParamState(params)
+        self.weight = param_type(params)
 
     def update(self, pre_val):
         params = self.weight.value
@@ -332,6 +337,7 @@ class OneToOne(Module):
         w_init: Union[Callable, ArrayLike] = init.Normal(),
         b_init: Optional[Union[Callable, ArrayLike]] = None,
         name: Optional[str] = None,
+        param_type: type = ParamState,
     ):
         super().__init__(name=name)
 
@@ -343,13 +349,81 @@ class OneToOne(Module):
         param = dict(weight=init.param(w_init, self.in_size, allow_none=False))
         if b_init is not None:
             param['bias'] = init.param(b_init, self.out_size, allow_none=False)
-        self.weight = param
+        self.weight = param_type(param)
 
     def update(self, pre_val):
         pre_val, pre_unit = u.get_mantissa(pre_val), u.get_unit(pre_val)
-        w_val, w_unit = u.get_mantissa(self.weight['weight']), u.get_unit(self.weight['weight'])
+        w_val, w_unit = u.get_mantissa(self.weight.value['weight']), u.get_unit(self.weight.value['weight'])
         post_val = pre_val * w_val
         post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
-        if 'bias' in self.weight:
-            post_val = post_val + self.weight['bias']
+        if 'bias' in self.weight.value:
+            post_val = post_val + self.weight.value['bias']
         return post_val
+
+
+class LoRA(Module):
+    """A standalone LoRA layer.
+
+    Example usage::
+
+        >>> import brainstate as bst
+        >>> import jax, jax.numpy as jnp
+        >>> layer = bst.nn.LoRA(3, 2, 4)
+        >>> layer.weight.value
+        {'lora_a': Array([[ 0.25141352, -0.09826107],
+                          [ 0.2328382 ,  0.38869813],
+                          [ 0.27069277,  0.7678282 ]], dtype=float32),
+         'lora_b': Array([[-0.8372317 ,  0.21012013, -0.52999765, -0.31939325],
+                          [ 0.64234126, -0.42980042,  1.2549229 , -0.47134295]], dtype=float32)}
+        >>> # Wrap around existing layer
+        >>> linear = bst.nn.Linear(3, 4)
+        >>> wrapper = bst.nn.LoRA(3, 2, 4, base_module=linear)
+        >>> assert wrapper.base_module == linear
+        >>> y = layer(jnp.ones((16, 3)))
+        >>> y.shape
+        (16, 4)
+
+    Args:
+        in_features: the number of input features.
+        lora_rank: the rank of the LoRA dimension.
+        out_features: the number of output features.
+        base_module: a base module to call and substitute, if possible.
+        kernel_init: initializer function for the weight matrices.
+        param_type: the type of the LoRA params.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        lora_rank: int,
+        out_features: int,
+        *,
+        base_module: Optional[Module] = None,
+        kernel_init: Union[Callable, ArrayLike] = init.LecunNormal(),
+        param_type: type = ParamState,
+    ):
+        super().__init__()
+
+        # input and output shape
+        self.in_size = in_features
+        self.out_size = out_features
+        self.in_features = in_features
+        self.out_features = out_features
+
+        # others
+        self.base_module = base_module
+
+        # weights
+        param = dict(
+            lora_a=kernel_init((in_features, lora_rank)),
+            lora_b=kernel_init((lora_rank, out_features))
+        )
+        self.weight = param_type(param)
+
+    def __call__(self, x: ArrayLike):
+        out = x @ self.weight.value['lora_a'] @ self.weight.value['lora_b']
+        if self.base_module is not None:
+            if not callable(self.base_module):
+                raise ValueError('`self.base_module` must be callable.')
+            out += self.base_module(x)
+        return out
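The new `LoRA` layer computes the low-rank update `x @ lora_a @ lora_b` (with `lora_a` of shape `(in_features, lora_rank)` and `lora_b` of shape `(lora_rank, out_features)`) and, when a `base_module` is supplied, adds the base module's output. A small hedged check of that relationship, reusing the classes shown above:

```python
import brainstate as bst
import jax.numpy as jnp

linear = bst.nn.Linear(3, 4)
wrapper = bst.nn.LoRA(3, 2, 4, base_module=linear)

x = jnp.ones((16, 3))
lora_a = wrapper.weight.value['lora_a']   # (3, 2)
lora_b = wrapper.weight.value['lora_b']   # (2, 4)

# Per __call__ above: output = x @ lora_a @ lora_b + base_module(x)
assert jnp.allclose(wrapper(x), x @ lora_a @ lora_b + linear(x))
```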
--- a/brainstate/random/_rand_funs.py
+++ b/brainstate/random/_rand_funs.py
@@ -1848,7 +1848,14 @@ def lognormal(mean=None, sigma=None, size: Optional[Size] = None,
     return DEFAULT.lognormal(mean, sigma, size, key=key, dtype=dtype)
 
 
-def binomial(n, p, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
+def binomial(
+    n,
+    p,
+    size: Optional[Size] = None,
+    key: Optional[SeedOrKey] = None,
+    dtype: DTypeLike = None,
+    check_valid: bool = True,
+):
     r"""
     Draw samples from a binomial distribution.
 
@@ -1933,7 +1940,7 @@ def binomial(n, p, size: Optional[Size] = None, key: Optional[SeedOrKey] = None,
     >>> sum(brainstate.random.binomial(9, 0.1, 20000) == 0)/20000.
     # answer = 0.38885, or 38%.
     """
-    return DEFAULT.binomial(n, p, size, key=key, dtype=dtype)
+    return DEFAULT.binomial(n, p, size, key=key, dtype=dtype, check_valid=check_valid)
 
 
 def chisquare(df, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
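`binomial` gains a `check_valid` flag that is simply forwarded to `RandomState.binomial`; `poisson_input` passes `check_valid=False` so the (possibly traced) probability `p` is not validated on the host. A hedged example of the public wrapper, based only on the signature shown above:

```python
import brainstate as bst

# With the default check_valid=True this behaves as before; disabling it is
# assumed to skip the n/p sanity checks, which is what poisson_input relies on.
samples = bst.random.binomial(9, 0.1, size=(20000,), check_valid=False)
print((samples == 0).mean())   # roughly 0.387, i.e. 0.9 ** 9
```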
--- a/brainstate/random/_rand_seed.py
+++ b/brainstate/random/_rand_seed.py
@@ -21,7 +21,7 @@ import jax
 import numpy as np
 
 from brainstate.typing import SeedOrKey
-from ._rand_state import RandomState, DEFAULT
+from ._rand_state import RandomState, DEFAULT, use_prng_key
 
 __all__ = [
     'seed', 'set_key', 'get_key', 'default_rng', 'split_key', 'split_keys', 'seed_context', 'restore_key',
@@ -123,7 +123,17 @@ def set_key(seed_or_key: SeedOrKey):
     seed_or_key: int
         The random key.
     """
-    DEFAULT.set_key(jax.random.PRNGKey(seed_or_key) if jax.numpy.shape(seed_or_key) == () else seed_or_key)
+    if isinstance(seed_or_key, int):
+        # key = jax.random.key(seed_or_key)
+        key = jax.random.PRNGKey(seed_or_key) if use_prng_key else jax.random.key(seed_or_key)
+    elif isinstance(seed_or_key, (jax.numpy.ndarray, np.ndarray)):
+        if jax.numpy.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
+            key = seed_or_key
+        elif seed_or_key.size == 2 and seed_or_key.dtype == jax.numpy.uint32:
+            key = seed_or_key
+        else:
+            raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.")
+    DEFAULT.set_key(key)
 
 
 def get_key():
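`set_key` now distinguishes three inputs: a plain Python int (converted with `jax.random.PRNGKey` or the new-style `jax.random.key`, depending on the module-level `use_prng_key` flag), a typed PRNG-key array, and a legacy `uint32[2]` key. A short sketch of the accepted forms, mirroring the branches above:

```python
import jax
import brainstate as bst

bst.random.set_key(42)                       # int seed -> PRNGKey/key per use_prng_key
bst.random.set_key(jax.random.PRNGKey(42))   # legacy uint32[2] key, passed through
bst.random.set_key(jax.random.key(42))       # new-style typed key, passed through
print(bst.random.get_key())
```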