brainstate 0.1.4__py2.py3-none-any.whl → 0.1.5__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
brainstate/__init__.py CHANGED
@@ -17,7 +17,7 @@
17
17
  A ``State``-based Transformation System for Program Compilation and Augmentation
18
18
  """
19
19
 
20
- __version__ = "0.1.4"
20
+ __version__ = "0.1.5"
21
21
 
22
22
  from . import augment
23
23
  from . import compile
brainstate/_state.py CHANGED
@@ -50,6 +50,7 @@ __all__ = [
50
50
  'LongTermState',
51
51
  'HiddenState',
52
52
  'ParamState',
53
+ 'BatchState',
53
54
  'TreefyState',
54
55
  'FakeState',
55
56
 
@@ -185,10 +185,10 @@ def _compile_stateful_function(
185
185
  if isinstance(in_axes, int):
186
186
  args = jax.tree.map(lambda x: _remove_axis(x, in_axes), args)
187
187
  elif isinstance(in_axes, tuple):
188
- args = tuple(
189
- [arg if in_axis is None else _remove_axis(arg, in_axis)
190
- for arg, in_axis in zip(args, in_axes)]
191
- )
188
+ args = tuple([
189
+ arg if in_axis is None else _remove_axis(arg, in_axis)
190
+ for arg, in_axis in zip(args, in_axes)
191
+ ])
192
192
  stateful_fn.make_jaxpr(state_vals, args)
193
193
  return stateful_fn.get_arg_cache_key(state_vals, args)
194
194
 
@@ -383,10 +383,7 @@ def _vmap_transform(
383
383
  stateful_fn.axis_env = axis_env
384
384
 
385
385
  # stateful function
386
- stateful_fn = StatefulFunction(
387
- _vmap_fn_for_compilation,
388
- name='vmap',
389
- )
386
+ stateful_fn = StatefulFunction(_vmap_fn_for_compilation, name='vmap')
390
387
 
391
388
  @functools.wraps(f)
392
389
  def new_fn_for_vmap(
@@ -460,7 +457,10 @@ def _vmap_transform(
460
457
  # analyze vmapping axis error
461
458
  for state in state_trace.get_write_states():
462
459
  leaves = jax.tree.leaves(state.value)
463
- if any([isinstance(leaf, BatchTracer) for leaf in leaves]) and state not in out_state_to_axis:
460
+ if (
461
+ any([isinstance(leaf, BatchTracer) and (leaf.batch_dim is not None) for leaf in leaves])
462
+ and state not in out_state_to_axis
463
+ ):
464
464
  if isinstance(state, RandomState) and state in rng_sets:
465
465
  continue
466
466
  state.raise_error_with_source_info(
@@ -19,6 +19,8 @@ import unittest
19
19
  import jax
20
20
  import jax.numpy as jnp
21
21
  import numpy as np
22
+ from jax import vmap
23
+ from jax.lax import psum, pmean, pmax
22
24
 
23
25
  import brainstate
24
26
  import brainstate.augment
@@ -433,3 +435,163 @@ class TestVMAPNewStatesEdgeCases(unittest.TestCase):
433
435
  foo.c = brainstate.State(jnp.arange(3)) # Original expected shape is (4,)
434
436
 
435
437
  faulty_init()
438
+
439
+
440
+ class TestAxisName:
441
+ def test1(self):
442
+ def compute_stats_with_axis_name(x):
443
+ """Compute statistics using named axis operations"""
444
+ # Sum across the named axis 'batch'
445
+ total_sum = psum(x, axis_name='batch')
446
+
447
+ # Mean across the named axis 'batch'
448
+ mean_val = pmean(x, axis_name='batch')
449
+
450
+ # Max across the named axis 'batch'
451
+ max_val = pmax(x, axis_name='batch')
452
+
453
+ return {
454
+ 'sum': total_sum,
455
+ 'mean': mean_val,
456
+ 'max': max_val,
457
+ 'original': x
458
+ }
459
+
460
+ batch_data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
461
+ print("Input batch data:", batch_data)
462
+
463
+ # vmap with axis name 'batch'
464
+ vectorized_stats_jax = jax.jit(vmap(compute_stats_with_axis_name, axis_name='batch'))
465
+ result_jax = vectorized_stats_jax(batch_data)
466
+
467
+ # vmap with axis name 'batch'
468
+ vectorized_stats = brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
469
+ result = vectorized_stats(batch_data)
470
+
471
+ # vmap with axis name 'batch'
472
+ vectorized_stats_v2 = brainstate.transform.jit(
473
+ brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
474
+ )
475
+ result_v2 = vectorized_stats_v2(batch_data)
476
+
477
+ for key in result_jax.keys():
478
+ print(f" {key}: {result_jax[key]}")
479
+ assert jnp.allclose(result_jax[key], result[key]), f"Mismatch in {key}"
480
+ assert jnp.allclose(result_jax[key], result_v2[key]), f"Mismatch in {key}"
481
+
482
+ def test_nested_vmap(self):
483
+ def nested_computation(x):
484
+ """Computation with multiple named axes"""
485
+ # Sum over 'inner' axis, then mean over 'outer' axis
486
+ inner_sum = psum(x, axis_name='inner')
487
+ outer_mean = pmean(inner_sum, axis_name='outer')
488
+ return outer_mean
489
+
490
+ # Create 2D batch data
491
+ data_2d = jnp.arange(12.0).reshape(3, 4) # Shape: [outer_batch=3, inner_batch=4]
492
+ print("Input 2D data shape:", data_2d.shape)
493
+ print("Input 2D data:\n", data_2d)
494
+
495
+ # Nested vmap: first over inner dimension, then outer dimension
496
+ inner_vmap = vmap(nested_computation, axis_name='inner')
497
+ nested_vmap = vmap(inner_vmap, axis_name='outer')
498
+
499
+ result_2d = nested_vmap(data_2d)
500
+ print("Result after nested vmap:", result_2d)
501
+
502
+ inner_vmap_bst = brainstate.transform.vmap(nested_computation, axis_name='inner')
503
+ nested_vmap_bst = brainstate.transform.vmap(inner_vmap_bst, axis_name='outer')
504
+ result_2d_bst = nested_vmap_bst(data_2d)
505
+ print("Result after nested vmap:", result_2d_bst)
506
+
507
+ assert jnp.allclose(result_2d, result_2d_bst)
508
+
509
+ def _gradient_averaging_simulation_bst(self):
510
+ def loss_function(params, x, y):
511
+ """Simple quadratic loss"""
512
+ pred = params * x
513
+ return (pred - y) ** 2
514
+
515
+ def compute_gradients_with_averaging(params, batch_x, batch_y):
516
+ """Compute gradients and average them across the batch"""
517
+ # Compute per-sample gradients
518
+ grad_fn = jax.grad(loss_function, argnums=0)
519
+ per_sample_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
520
+
521
+ # Average gradients across batch using named axis
522
+ def average_grads(grads):
523
+ return pmean(grads, axis_name='batch')
524
+
525
+ # Apply averaging with named axis
526
+ averaged_grads = vmap(average_grads, axis_name='batch')(per_sample_grads)
527
+ return averaged_grads
528
+
529
+ # Example data
530
+ params = 2.0
531
+ batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
532
+ batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
533
+
534
+ print("Parameters:", params)
535
+ print("Batch X:", batch_x)
536
+ print("Batch Y:", batch_y)
537
+
538
+ # Compute individual gradients first
539
+ grad_fn = jax.grad(loss_function, argnums=0)
540
+ individual_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
541
+ print("Individual gradients:", individual_grads)
542
+
543
+ # Now compute averaged gradients using axis names
544
+ averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
545
+ print("Averaged gradients:", averaged_grads)
546
+
547
+ return individual_grads, averaged_grads
548
+
549
+ def _gradient_averaging_simulation_jax(self):
550
+ def loss_function(params, x, y):
551
+ """Simple quadratic loss"""
552
+ pred = params * x
553
+ return (pred - y) ** 2
554
+
555
+ def compute_gradients_with_averaging(params, batch_x, batch_y):
556
+ """Compute gradients and average them across the batch"""
557
+ # Compute per-sample gradients
558
+ grad_fn = jax.grad(loss_function, argnums=0)
559
+ per_sample_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
560
+
561
+ # Average gradients across batch using named axis
562
+ def average_grads(grads):
563
+ return pmean(grads, axis_name='batch')
564
+
565
+ # Apply averaging with named axis
566
+ averaged_grads = brainstate.transform.vmap(average_grads, axis_name='batch')(per_sample_grads)
567
+ return averaged_grads
568
+
569
+ # Example data
570
+ params = 2.0
571
+ batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
572
+ batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
573
+
574
+ print("Parameters:", params)
575
+ print("Batch X:", batch_x)
576
+ print("Batch Y:", batch_y)
577
+
578
+ # Compute individual gradients first
579
+ grad_fn = jax.grad(loss_function, argnums=0)
580
+ individual_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
581
+ print("Individual gradients:", individual_grads)
582
+
583
+ # Now compute averaged gradients using axis names
584
+ averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
585
+ print("Averaged gradients:", averaged_grads)
586
+
587
+ return individual_grads, averaged_grads
588
+
589
+ def test_gradient_averaging_simulation(self):
590
+ individual_grads, averaged_grads = self._gradient_averaging_simulation_bst()
591
+ individual_grads_jax, averaged_grads_jax = self._gradient_averaging_simulation_jax()
592
+ assert jnp.allclose(individual_grads, individual_grads_jax)
593
+ assert jnp.allclose(averaged_grads, averaged_grads_jax)
594
+
595
+
596
+
597
+
brainstate/nn/_common.py CHANGED
@@ -118,14 +118,14 @@ class Vmap(Module):
118
118
  This class wraps a module and applies vectorized mapping to its execution,
119
119
  allowing for efficient parallel processing across specified axes.
120
120
 
121
- Attributes:
121
+ Args:
122
122
  module (Module): The module to be vmapped.
123
- in_axes (int | None | Sequence[Any]): Specifies how to map over inputs.
124
- out_axes (Any): Specifies how to map over outputs.
125
- vmap_states (Filter | Dict[Filter, int]): Specifies which states to vmap and on which axes.
126
- vmap_out_states (Filter | Dict[Filter, int]): Specifies which output states to vmap and on which axes.
127
- axis_name (AxisName | None): Name of the axis being mapped over.
128
- axis_size (int | None): Size of the axis being mapped over.
123
+ in_axes (int | None | Sequence[Any], optional): Specifies how to map over inputs. Defaults to 0.
124
+ out_axes (Any, optional): Specifies how to map over outputs. Defaults to 0.
125
+ vmap_states (Filter | Dict[Filter, int], optional): Specifies which states to vmap and on which axes. Defaults to None.
126
+ vmap_out_states (Filter | Dict[Filter, int], optional): Specifies which output states to vmap and on which axes. Defaults to None.
127
+ axis_name (AxisName | None, optional): Name of the axis being mapped over. Defaults to None.
128
+ axis_size (int | None, optional): Size of the axis being mapped over. Defaults to None.
129
129
  """
130
130
 
131
131
  def __init__(
@@ -138,18 +138,6 @@ class Vmap(Module):
138
138
  axis_name: AxisName | None = None,
139
139
  axis_size: int | None = None,
140
140
  ):
141
- """
142
- Initialize the Vmap instance.
143
-
144
- Args:
145
- module (Module): The module to be vmapped.
146
- in_axes (int | None | Sequence[Any], optional): Specifies how to map over inputs. Defaults to 0.
147
- out_axes (Any, optional): Specifies how to map over outputs. Defaults to 0.
148
- vmap_states (Filter | Dict[Filter, int], optional): Specifies which states to vmap and on which axes. Defaults to None.
149
- vmap_out_states (Filter | Dict[Filter, int], optional): Specifies which output states to vmap and on which axes. Defaults to None.
150
- axis_name (AxisName | None, optional): Name of the axis being mapped over. Defaults to None.
151
- axis_size (int | None, optional): Size of the axis being mapped over. Defaults to None.
152
- """
153
141
  super().__init__()
154
142
 
155
143
  # parameters
@@ -69,27 +69,24 @@ def exp_euler_step(
69
69
  f'The input data type should be float64, float32, float16, or bfloat16 '
70
70
  f'when using Exponential Euler method. But we got {args[0].dtype}.'
71
71
  )
72
+
73
+ # drift
72
74
  dt = environ.get('dt')
73
75
  linear, derivative = vector_grad(fn, argnums=0, return_value=True)(*args, **kwargs)
74
76
  linear = u.Quantity(u.get_mantissa(linear), u.get_unit(derivative) / u.get_unit(linear))
75
77
  phi = u.math.exprel(dt * linear)
76
78
  x_next = args[0] + dt * phi * derivative
77
79
 
80
+ # diffusion
78
81
  if diffusion is not None:
79
- # unit checking
80
- diffusion = diffusion(*args, **kwargs)
81
- time_unit = u.get_unit(dt)
82
- drift_unit = u.get_unit(derivative)
83
- diffusion_unit = u.get_unit(diffusion)
84
- # if drift_unit.is_unitless:
85
- # assert diffusion_unit.is_unitless, 'The diffusion term should be unitless when the drift term is unitless.'
86
- # else:
87
- # u.fail_for_dimension_mismatch(
88
- # drift_unit, diffusion_unit * time_unit ** 0.5,
89
- # "Drift unit is {drift}, diffusion unit is {diffusion}, ",
90
- # drift=drift_unit, diffusion=diffusion_unit * time_unit ** 0.5
91
- # )
92
-
93
- # diffusion
94
- x_next += diffusion * u.math.sqrt(dt) * random.randn_like(args[0])
82
+ diffusion_part = diffusion(*args, **kwargs) * u.math.sqrt(dt) * random.randn_like(args[0])
83
+ if u.get_dim(x_next) != u.get_dim(diffusion_part):
84
+ drift_unit = u.get_unit(x_next)
85
+ time_unit = u.get_unit(dt)
86
+ raise ValueError(
87
+ f"Drift unit is {drift_unit}, "
88
+ f"expected diffusion unit is {drift_unit / time_unit ** 0.5}, "
89
+ f"but we got {u.get_unit(diffusion_part)}."
90
+ )
91
+ x_next += diffusion_part
95
92
  return x_next
brainstate/nn/_inputs.py CHANGED
@@ -547,7 +547,7 @@ def poisson_input(
547
547
  num_input,
548
548
  p,
549
549
  tar[indices].shape,
550
- # check_valid=False,
550
+ check_valid=False,
551
551
  dtype=tar.dtype
552
552
  ),
553
553
  tar_val,
@@ -384,7 +384,10 @@ class RandomState(State):
384
384
  loc = _check_py_seq(loc)
385
385
  scale = _check_py_seq(scale)
386
386
  if size is None:
387
- size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
387
+ size = lax.broadcast_shapes(
388
+ jnp.shape(loc) if loc is not None else (),
389
+ jnp.shape(scale) if scale is not None else ()
390
+ )
388
391
  key = self.split_key() if key is None else _formalize_key(key)
389
392
  dtype = dtype or environ.dftype()
390
393
  r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size), dtype=dtype))
@@ -399,7 +402,10 @@ class RandomState(State):
399
402
  loc = _check_py_seq(loc)
400
403
  scale = _check_py_seq(scale)
401
404
  if size is None:
402
- size = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(loc))
405
+ size = lax.broadcast_shapes(
406
+ jnp.shape(scale) if scale is not None else (),
407
+ jnp.shape(loc) if loc is not None else ()
408
+ )
403
409
  key = self.split_key() if key is None else _formalize_key(key)
404
410
  dtype = dtype or environ.dftype()
405
411
  r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size), dtype=dtype))
@@ -456,7 +462,7 @@ class RandomState(State):
456
462
  dtype: DTypeLike = None):
457
463
  shape = _check_py_seq(shape)
458
464
  if size is None:
459
- size = jnp.shape(shape)
465
+ size = jnp.shape(shape) if shape is not None else ()
460
466
  key = self.split_key() if key is None else _formalize_key(key)
461
467
  dtype = dtype or environ.dftype()
462
468
  r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
@@ -477,7 +483,7 @@ class RandomState(State):
477
483
  dtype: DTypeLike = None):
478
484
  df = _check_py_seq(df)
479
485
  if size is None:
480
- size = jnp.shape(size)
486
+ size = jnp.shape(size) if size is not None else ()
481
487
  key = self.split_key() if key is None else _formalize_key(key)
482
488
  dtype = dtype or environ.dftype()
483
489
  r = jr.t(key, df=df, shape=_size2shape(size), dtype=dtype)
@@ -606,8 +612,8 @@ class RandomState(State):
606
612
 
607
613
  if size is None:
608
614
  size = jnp.broadcast_shapes(
609
- jnp.shape(mean),
610
- jnp.shape(sigma)
615
+ jnp.shape(mean) if mean is not None else (),
616
+ jnp.shape(sigma) if sigma is not None else ()
611
617
  )
612
618
  key = self.split_key() if key is None else _formalize_key(key)
613
619
  dtype = dtype or environ.dftype()
@@ -822,7 +828,7 @@ class RandomState(State):
822
828
  a = _check_py_seq(a)
823
829
  scale = _check_py_seq(scale)
824
830
  if size is None:
825
- size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale))
831
+ size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale) if scale is not None else ())
826
832
  else:
827
833
  if jnp.size(a) > 1:
828
834
  raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.1.4
3
+ Version: 0.1.5
4
4
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
5
  Home-page: https://github.com/chaobrain/brainstate
6
6
  Author: BrainState Developers
@@ -1,6 +1,6 @@
1
- brainstate/__init__.py,sha256=A_OD4cJiVu3xpthNGJh6fhRjCKsI7_Mxsow3Al4m2-w,1496
1
+ brainstate/__init__.py,sha256=24AzLahGzsgxErIIbDuFSzGFn-btHTh9dyczmpOkOvA,1496
2
2
  brainstate/_compatible_import.py,sha256=LUSZlA0APWozxM8Kf9pZrM2YbwY7X3jzVVHZInaBL7Y,4630
3
- brainstate/_state.py,sha256=o5Kk4bGwVz6Dfj9dlmZqdh6zUXcz6Tvc6WOjH9ajlIU,60716
3
+ brainstate/_state.py,sha256=OBQIu1eLH4szHw2SGduPxx24gKoU7jXgCL-TPNdI_iw,60735
4
4
  brainstate/_state_test.py,sha256=b6uvZdVRyC4n6-fYzmHNry1b-gJ6zE_kRSxGinqiHaw,1638
5
5
  brainstate/_utils.py,sha256=j-b239RHfC5BnvhGbSExQpdY21LrMEyWMSHBdNGThOI,1657
6
6
  brainstate/environ.py,sha256=VgtG0S_aR1g_1gplRWg_v2ZrcS-F6LZk35BPCBgsIIA,17660
@@ -15,8 +15,8 @@ brainstate/augment/_autograd.py,sha256=zMMLZOidQq2p96wzEOgR-MynV2JH1l1AnoJ28eVwD
15
15
  brainstate/augment/_autograd_test.py,sha256=UhNd41luca_Kj9a8byL3Dq64Ta55WU-je-LCiK_z0Vc,45060
16
16
  brainstate/augment/_eval_shape.py,sha256=mYRB5clSRHdXX0c4c5_WDIIBtRl4a9-xl8Tg60mgLBI,3854
17
17
  brainstate/augment/_eval_shape_test.py,sha256=WXrmZKmnykmYPveRfZbrDF0sFm2_PTr982yl4JM7Ebg,1390
18
- brainstate/augment/_mapping.py,sha256=AvGt8qhOdBgRJyWZn7J5JnOvxZc7XfZ_JKXaECOQjIw,43500
19
- brainstate/augment/_mapping_test.py,sha256=An5iSqrAXq_OVTvY3GFTjsB9AvFZ4j6AWNZQ_WWfra4,15319
18
+ brainstate/augment/_mapping.py,sha256=M-xl7Q-wGrftI0fJ4jO-SzCEpB7XZ_UjQfkJs_2TbCQ,43557
19
+ brainstate/augment/_mapping_test.py,sha256=tEyXioA4_v4WGRFrpgUInyKyBjAVNLonWdHKVacmphU,21946
20
20
  brainstate/augment/_random.py,sha256=bkngsIk6Wwi3e46I7YSbBjLCGAz0Q3WuadUH4mqTjbY,5348
21
21
  brainstate/compile/__init__.py,sha256=fQtG316MLkeeu1Ssp54Kghw1PwbGK5gNq9yRVJu0wjA,1474
22
22
  brainstate/compile/_ad_checkpoint.py,sha256=HM6L90HU0N84S30uJpX8wqTO0ZcDnctqmwf8qFlkDYo,9358
@@ -57,7 +57,7 @@ brainstate/init/_regular_inits_test.py,sha256=bvv0AOfLOEP0BIQIBLztKw3EPyEp7n2fHW
57
57
  brainstate/nn/__init__.py,sha256=nDIBDlzHaBy-vk4cb7FiPO6zoNe2tI5qvHIK8O7yrOU,3721
58
58
  brainstate/nn/_collective_ops.py,sha256=rWcjqaP0rW6sXdhE0fjtDi4twmq5VvhNEg8AoiQ-tDU,17390
59
59
  brainstate/nn/_collective_ops_test.py,sha256=bwq0DApcsk0_2xpxMl0_e2cGKT63g5rSngpigCm07ps,1409
60
- brainstate/nn/_common.py,sha256=Pt8P1LgE0qW3QnfX8CQQGH3yZ78RL7NkIqqYEw_z8xs,7096
60
+ brainstate/nn/_common.py,sha256=qHAOID_eeKiPUXk_ION65sbYXyF-ddH5w5BayvH8Thg,6431
61
61
  brainstate/nn/_conv.py,sha256=Zk-yj34n6CkjntcM9xpMGLTxKNfWdIWsTsoGbtdL0yU,18448
62
62
  brainstate/nn/_conv_test.py,sha256=2lcUTG7twkyhuyKwuBux-NgU8NU_W4Cp1-G8EyDJ_uk,8862
63
63
  brainstate/nn/_delay.py,sha256=l36FBgNhfL64tM3VGOsJNTtKr44HjxxtBWMFFCm3Pks,17361
@@ -68,11 +68,11 @@ brainstate/nn/_dynamics_test.py,sha256=w7AV57LdhbBNYprdFpKq8MFSCbXKVkGgp_NbL3ANX
68
68
  brainstate/nn/_elementwise.py,sha256=4czeJWGQopV49iZo8DuN_WzAbXoMC1gtqaGjlON6e7c,33291
69
69
  brainstate/nn/_elementwise_test.py,sha256=_dd9eX2ZJ7p24ahuoapCaRTZ0g1boufXMyqHFx1d4WY,5688
70
70
  brainstate/nn/_embedding.py,sha256=SaAJbgXmuJ8XlCOX9ob4yvmgh9Fk627wMguRzJMJ1H8,2138
71
- brainstate/nn/_exp_euler.py,sha256=ndDB43PM4jsZKu_zdLTZ2-ojnuNrg55LZap23oBTtdA,3493
71
+ brainstate/nn/_exp_euler.py,sha256=WTpZm-XQmsdMLNazY7wIu8eeO6pK0kRzt2lJnhEgMIk,3293
72
72
  brainstate/nn/_exp_euler_test.py,sha256=XD--qMbGHrHa3WtcPMmJKk59giDcEhSqZuBOmTNYUr8,1227
73
73
  brainstate/nn/_fixedprob.py,sha256=KGXohiU0wZnFIQDuwiRUTFsbsr8R0p8zgi5UZDuv1Bk,10004
74
74
  brainstate/nn/_fixedprob_test.py,sha256=qbRBh-MpMtEOsg492gFu2w9-FOP9z_bXapm-Q0gLLYM,3929
75
- brainstate/nn/_inputs.py,sha256=wPOfPE4IesNoDmxZJxqR0siBlJioEX-_1IZ2cltAIpM,20605
75
+ brainstate/nn/_inputs.py,sha256=hMPkx9qDBpJWPshZXLF4H1QiYK1-46wntHUIlG7cT7c,20603
76
76
  brainstate/nn/_linear.py,sha256=5WuhcqU-uBUC91vnwezQYMHPKmlZPDgIJ5UpffxoX1I,14472
77
77
  brainstate/nn/_linear_mv.py,sha256=6hDXx4yPqRSa7uIsW9f9eJuy23dcXN9Mp2_lSvw8BDA,2635
78
78
  brainstate/nn/_linear_mv_test.py,sha256=ZCM1Zy6mImQfCfdZOGnTwkiLLPXK5yalv1Ts9sWZuPA,3864
@@ -111,7 +111,7 @@ brainstate/random/_rand_funs.py,sha256=c4xiY2NeMizSslxbWOa-QJZ3h-LDfsgH4fbvBMSNL
111
111
  brainstate/random/_rand_funs_test.py,sha256=G8BuxDjBSeE-Mh7KuwVk6mPQhHE0m1R5HDFljEOIdzg,20669
112
112
  brainstate/random/_rand_seed.py,sha256=1ZdfFZWyOhpd72EDdEDmpkp3yoLVwdv-sGI9BwiZfzI,5949
113
113
  brainstate/random/_rand_seed_test.py,sha256=waXXfch57X1XE1zDnCRokT6ziZOK0g-lYE80o6epDYM,1536
114
- brainstate/random/_rand_state.py,sha256=dTO7wqmuTYRdPy7ItsrK-7aNt5QQXOsZ4XwiH7mzmy8,55170
114
+ brainstate/random/_rand_state.py,sha256=hQ_govRxfIgsmtFe_V2B-jftqiMWRWrd9wkviPFOziY,55523
115
115
  brainstate/random/_random_for_unit.py,sha256=kGp4EUX19MXJ9Govoivbg8N0bddqOldKEI2h_TbdONY,2057
116
116
  brainstate/util/__init__.py,sha256=6efwr63osmqviNU_6_Nufag19PwxRDFvDAZrq6sH5yo,1555
117
117
  brainstate/util/_pretty_pytree_test.py,sha256=Dn0TdjX6wLBXaTD4jfYTu6cKfFHwKSxi4_3bX7kB_IA,5621
@@ -124,8 +124,8 @@ brainstate/util/pretty_repr.py,sha256=7Xp7IFNUeP7cGlpvwwJyBslbQVnXEqC1I6neV1Jx1S
124
124
  brainstate/util/pretty_table.py,sha256=uJVaamFGQ4nKP8TkEGPWXHpzjMecDo2q1Ah6XtRjdPY,108117
125
125
  brainstate/util/scaling.py,sha256=U6DM-afPrLejiGqo1Nla7z4YbTBVicctsBEweurr_mk,7524
126
126
  brainstate/util/struct.py,sha256=7HbbQNrZ3zxYw93MU1bUZ9ZPBKftYOVKuXEochLSErw,17479
127
- brainstate-0.1.4.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
128
- brainstate-0.1.4.dist-info/METADATA,sha256=EWarKpuYFIN0ITFn8lwsUy0h9-KjrVPDwcvjQMHFRto,4135
129
- brainstate-0.1.4.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
130
- brainstate-0.1.4.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
131
- brainstate-0.1.4.dist-info/RECORD,,
127
+ brainstate-0.1.5.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
128
+ brainstate-0.1.5.dist-info/METADATA,sha256=0rd8rcVbLHdAkaLX9T2-BA2aRfq4D2ToDcmt0pp0wwg,4135
129
+ brainstate-0.1.5.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
130
+ brainstate-0.1.5.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
131
+ brainstate-0.1.5.dist-info/RECORD,,