brainstate 0.1.4__py2.py3-none-any.whl → 0.1.5__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_state.py +1 -0
- brainstate/augment/_mapping.py +9 -9
- brainstate/augment/_mapping_test.py +162 -0
- brainstate/nn/_common.py +7 -19
- brainstate/nn/_exp_euler.py +13 -16
- brainstate/nn/_inputs.py +1 -1
- brainstate/random/_rand_state.py +13 -7
- {brainstate-0.1.4.dist-info → brainstate-0.1.5.dist-info}/METADATA +1 -1
- {brainstate-0.1.4.dist-info → brainstate-0.1.5.dist-info}/RECORD +13 -13
- {brainstate-0.1.4.dist-info → brainstate-0.1.5.dist-info}/LICENSE +0 -0
- {brainstate-0.1.4.dist-info → brainstate-0.1.5.dist-info}/WHEEL +0 -0
- {brainstate-0.1.4.dist-info → brainstate-0.1.5.dist-info}/top_level.txt +0 -0
brainstate/__init__.py
CHANGED
brainstate/_state.py
CHANGED
brainstate/augment/_mapping.py
CHANGED
@@ -185,10 +185,10 @@ def _compile_stateful_function(
     if isinstance(in_axes, int):
         args = jax.tree.map(lambda x: _remove_axis(x, in_axes), args)
     elif isinstance(in_axes, tuple):
-        args = tuple(
-
-
-        )
+        args = tuple([
+            arg if in_axis is None else _remove_axis(arg, in_axis)
+            for arg, in_axis in zip(args, in_axes)
+        ])
     stateful_fn.make_jaxpr(state_vals, args)
     return stateful_fn.get_arg_cache_key(state_vals, args)
 
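For context, here is a standalone sketch (not brainstate internals; `strip_mapped_axes` and `_remove_axis_demo` are hypothetical names) of the per-argument `in_axes` handling this hunk introduces: entries that are `None` pass the argument through unchanged, while integer entries strip the mapped axis before the abstract signature is built.

```python
import jax.numpy as jnp

def _remove_axis_demo(x, axis):
    # hypothetical stand-in for brainstate's _remove_axis: drop the mapped axis
    return jnp.take(x, 0, axis=axis)

def strip_mapped_axes(args, in_axes):
    if isinstance(in_axes, int):
        return tuple(_remove_axis_demo(a, in_axes) for a in args)
    elif isinstance(in_axes, tuple):
        return tuple([
            arg if in_axis is None else _remove_axis_demo(arg, in_axis)
            for arg, in_axis in zip(args, in_axes)
        ])
    return args

x = jnp.ones((4, 3))   # mapped along axis 0
w = jnp.ones((3, 2))   # not mapped (in_axis=None)
print([a.shape for a in strip_mapped_axes((x, w), (0, None))])  # [(3,), (3, 2)]
```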
@@ -383,10 +383,7 @@ def _vmap_transform(
     stateful_fn.axis_env = axis_env
 
     # stateful function
-    stateful_fn = StatefulFunction(
-        _vmap_fn_for_compilation,
-        name='vmap',
-    )
+    stateful_fn = StatefulFunction(_vmap_fn_for_compilation, name='vmap')
 
     @functools.wraps(f)
     def new_fn_for_vmap(
@@ -460,7 +457,10 @@ def _vmap_transform(
     # analyze vmapping axis error
     for state in state_trace.get_write_states():
         leaves = jax.tree.leaves(state.value)
-        if
+        if (
+            any([isinstance(leaf, BatchTracer) and (leaf.batch_dim is not None) for leaf in leaves])
+            and state not in out_state_to_axis
+        ):
             if isinstance(state, RandomState) and state in rng_sets:
                 continue
             state.raise_error_with_source_info(
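A standalone sketch of the tracer check this hunk tightens: under `jax.vmap`, a value that is actually mapped shows up as a `BatchTracer` with a concrete `batch_dim`, while unmapped values do not. This illustrates the detection only; `is_batched` is a hypothetical helper, not brainstate's error path.

```python
import jax
import jax.numpy as jnp
from jax.interpreters.batching import BatchTracer

def is_batched(value):
    leaves = jax.tree_util.tree_leaves(value)
    return any(isinstance(leaf, BatchTracer) and leaf.batch_dim is not None for leaf in leaves)

def f(x):
    const = jnp.zeros(())
    # At trace time: x is batched, the fresh constant is not.
    print("x batched:", is_batched(x), "| const batched:", is_batched(const))
    return x * 2.0

jax.vmap(f)(jnp.arange(3.0))
```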
brainstate/augment/_mapping_test.py
CHANGED
@@ -19,6 +19,8 @@ import unittest
 import jax
 import jax.numpy as jnp
 import numpy as np
+from jax import vmap
+from jax.lax import psum, pmean, pmax
 
 import brainstate
 import brainstate.augment
@@ -433,3 +435,163 @@ class TestVMAPNewStatesEdgeCases(unittest.TestCase):
             foo.c = brainstate.State(jnp.arange(3))  # Original expected shape is (4,)
 
         faulty_init()
+
+
+class TestAxisName:
+    def test1(self):
+        def compute_stats_with_axis_name(x):
+            """Compute statistics using named axis operations"""
+            # Sum across the named axis 'batch'
+            total_sum = psum(x, axis_name='batch')
+
+            # Mean across the named axis 'batch'
+            mean_val = pmean(x, axis_name='batch')
+
+            # Max across the named axis 'batch'
+            max_val = pmax(x, axis_name='batch')
+
+            return {
+                'sum': total_sum,
+                'mean': mean_val,
+                'max': max_val,
+                'original': x
+            }
+
+        batch_data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
+        print("Input batch data:", batch_data)
+
+        # vmap with axis name 'batch'
+        vectorized_stats_jax = jax.jit(vmap(compute_stats_with_axis_name, axis_name='batch'))
+        result_jax = vectorized_stats_jax(batch_data)
+
+        # vmap with axis name 'batch'
+        vectorized_stats = brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
+        result = vectorized_stats(batch_data)
+
+        # vmap with axis name 'batch'
+        vectorized_stats_v2 = brainstate.transform.jit(
+            brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
+        )
+        result_v2 = vectorized_stats_v2(batch_data)
+
+        for key in result_jax.keys():
+            print(f" {key}: {result_jax[key]}")
+            assert jnp.allclose(result_jax[key], result[key]), f"Mismatch in {key}"
+            assert jnp.allclose(result_jax[key], result_v2[key]), f"Mismatch in {key}"
+
+    def test_nested_vmap(self):
+        def nested_computation(x):
+            """Computation with multiple named axes"""
+            # Sum over 'inner' axis, then mean over 'outer' axis
+            inner_sum = psum(x, axis_name='inner')
+            outer_mean = pmean(inner_sum, axis_name='outer')
+            return outer_mean
+
+        # Create 2D batch data
+        data_2d = jnp.arange(12.0).reshape(3, 4)  # Shape: [outer_batch=3, inner_batch=4]
+        print("Input 2D data shape:", data_2d.shape)
+        print("Input 2D data:\n", data_2d)
+
+        # Nested vmap: first over inner dimension, then outer dimension
+        inner_vmap = vmap(nested_computation, axis_name='inner')
+        nested_vmap = vmap(inner_vmap, axis_name='outer')
+
+        result_2d = nested_vmap(data_2d)
+        print("Result after nested vmap:", result_2d)
+
+        inner_vmap_bst = brainstate.transform.vmap(nested_computation, axis_name='inner')
+        nested_vmap_bst = brainstate.transform.vmap(inner_vmap_bst, axis_name='outer')
+        result_2d_bst = nested_vmap_bst(data_2d)
+        print("Result after nested vmap:", result_2d_bst)
+
+        assert jnp.allclose(result_2d, result_2d_bst)
+
+    def _gradient_averaging_simulation_bst(self):
+        def loss_function(params, x, y):
+            """Simple quadratic loss"""
+            pred = params * x
+            return (pred - y) ** 2
+
+        def compute_gradients_with_averaging(params, batch_x, batch_y):
+            """Compute gradients and average them across the batch"""
+            # Compute per-sample gradients
+            grad_fn = jax.grad(loss_function, argnums=0)
+            per_sample_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+
+            # Average gradients across batch using named axis
+            def average_grads(grads):
+                return pmean(grads, axis_name='batch')
+
+            # Apply averaging with named axis
+            averaged_grads = vmap(average_grads, axis_name='batch')(per_sample_grads)
+            return averaged_grads
+
+        # Example data
+        params = 2.0
+        batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
+        batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
+
+        print("Parameters:", params)
+        print("Batch X:", batch_x)
+        print("Batch Y:", batch_y)
+
+        # Compute individual gradients first
+        grad_fn = jax.grad(loss_function, argnums=0)
+        individual_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+        print("Individual gradients:", individual_grads)
+
+        # Now compute averaged gradients using axis names
+        averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
+        print("Averaged gradients:", averaged_grads)
+
+        return individual_grads, averaged_grads
+
+    def _gradient_averaging_simulation_jax(self):
+        def loss_function(params, x, y):
+            """Simple quadratic loss"""
+            pred = params * x
+            return (pred - y) ** 2
+
+        def compute_gradients_with_averaging(params, batch_x, batch_y):
+            """Compute gradients and average them across the batch"""
+            # Compute per-sample gradients
+            grad_fn = jax.grad(loss_function, argnums=0)
+            per_sample_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+
+            # Average gradients across batch using named axis
+            def average_grads(grads):
+                return pmean(grads, axis_name='batch')
+
+            # Apply averaging with named axis
+            averaged_grads = brainstate.transform.vmap(average_grads, axis_name='batch')(per_sample_grads)
+            return averaged_grads
+
+        # Example data
+        params = 2.0
+        batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
+        batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
+
+        print("Parameters:", params)
+        print("Batch X:", batch_x)
+        print("Batch Y:", batch_y)
+
+        # Compute individual gradients first
+        grad_fn = jax.grad(loss_function, argnums=0)
+        individual_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+        print("Individual gradients:", individual_grads)
+
+        # Now compute averaged gradients using axis names
+        averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
+        print("Averaged gradients:", averaged_grads)
+
+        return individual_grads, averaged_grads
+
+    def test_gradient_averaging_simulation(self):
+        individual_grads, averaged_grads = self._gradient_averaging_simulation_bst()
+        individual_grads_jax, averaged_grads_jax = self._gradient_averaging_simulation_jax()
+        assert jnp.allclose(individual_grads, individual_grads_jax)
+        assert jnp.allclose(averaged_grads, averaged_grads_jax)
+
+
+
+
brainstate/nn/_common.py
CHANGED
@@ -118,14 +118,14 @@ class Vmap(Module):
     This class wraps a module and applies vectorized mapping to its execution,
     allowing for efficient parallel processing across specified axes.
 
-
+    Args:
         module (Module): The module to be vmapped.
-        in_axes (int | None | Sequence[Any]): Specifies how to map over inputs.
-        out_axes (Any): Specifies how to map over outputs.
-        vmap_states (Filter | Dict[Filter, int]): Specifies which states to vmap and on which axes.
-        vmap_out_states (Filter | Dict[Filter, int]): Specifies which output states to vmap and on which axes.
-        axis_name (AxisName | None): Name of the axis being mapped over.
-        axis_size (int | None): Size of the axis being mapped over.
+        in_axes (int | None | Sequence[Any], optional): Specifies how to map over inputs. Defaults to 0.
+        out_axes (Any, optional): Specifies how to map over outputs. Defaults to 0.
+        vmap_states (Filter | Dict[Filter, int], optional): Specifies which states to vmap and on which axes. Defaults to None.
+        vmap_out_states (Filter | Dict[Filter, int], optional): Specifies which output states to vmap and on which axes. Defaults to None.
+        axis_name (AxisName | None, optional): Name of the axis being mapped over. Defaults to None.
+        axis_size (int | None, optional): Size of the axis being mapped over. Defaults to None.
     """
 
     def __init__(
@@ -138,18 +138,6 @@ class Vmap(Module):
         axis_name: AxisName | None = None,
         axis_size: int | None = None,
     ):
-        """
-        Initialize the Vmap instance.
-
-        Args:
-            module (Module): The module to be vmapped.
-            in_axes (int | None | Sequence[Any], optional): Specifies how to map over inputs. Defaults to 0.
-            out_axes (Any, optional): Specifies how to map over outputs. Defaults to 0.
-            vmap_states (Filter | Dict[Filter, int], optional): Specifies which states to vmap and on which axes. Defaults to None.
-            vmap_out_states (Filter | Dict[Filter, int], optional): Specifies which output states to vmap and on which axes. Defaults to None.
-            axis_name (AxisName | None, optional): Name of the axis being mapped over. Defaults to None.
-            axis_size (int | None, optional): Size of the axis being mapped over. Defaults to None.
-        """
         super().__init__()
 
         # parameters
brainstate/nn/_exp_euler.py
CHANGED
@@ -69,27 +69,24 @@ def exp_euler_step(
             f'The input data type should be float64, float32, float16, or bfloat16 '
             f'when using Exponential Euler method. But we got {args[0].dtype}.'
         )
+
+    # drift
     dt = environ.get('dt')
     linear, derivative = vector_grad(fn, argnums=0, return_value=True)(*args, **kwargs)
     linear = u.Quantity(u.get_mantissa(linear), u.get_unit(derivative) / u.get_unit(linear))
     phi = u.math.exprel(dt * linear)
     x_next = args[0] + dt * phi * derivative
 
+    # diffusion
     if diffusion is not None:
-
-
-
-
-
-
-
-
-
-
-        # "Drift unit is {drift}, diffusion unit is {diffusion}, ",
-        # drift=drift_unit, diffusion=diffusion_unit * time_unit ** 0.5
-        # )
-
-        # diffusion
-        x_next += diffusion * u.math.sqrt(dt) * random.randn_like(args[0])
+        diffusion_part = diffusion(*args, **kwargs) * u.math.sqrt(dt) * random.randn_like(args[0])
+        if u.get_dim(x_next) != u.get_dim(diffusion_part):
+            drift_unit = u.get_unit(x_next)
+            time_unit = u.get_unit(dt)
+            raise ValueError(
+                f"Drift unit is {drift_unit}, "
+                f"expected diffusion unit is {drift_unit / time_unit ** 0.5}, "
+                f"but we got {u.get_unit(diffusion_part)}."
+            )
+        x_next += diffusion_part
     return x_next
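For readers unfamiliar with the scheme, here is a minimal, unit-free sketch of the exponential Euler update used by `exp_euler_step`: x_next = x + dt * exprel(dt * f'(x)) * f(x), where exprel(z) = (exp(z) - 1) / z. The stochastic branch above additionally adds a `diffusion * sqrt(dt) * noise` (Euler-Maruyama) term. The names `exprel` and `exp_euler` here are local to the example, not library calls.

```python
import jax
import jax.numpy as jnp

def f(x):
    # toy drift: dx/dt = -0.5 * x + 1.0, fixed point at x* = 2.0
    return -0.5 * x + 1.0

def exprel(z):
    # (exp(z) - 1) / z, handling the removable singularity at z == 0
    safe_z = jnp.where(z == 0, 1.0, z)
    return jnp.where(z == 0, 1.0, jnp.expm1(safe_z) / safe_z)

def exp_euler(x, dt):
    # x_{n+1} = x_n + dt * exprel(dt * f'(x_n)) * f(x_n)
    derivative = f(x)
    linear = jax.grad(f)(x)
    return x + dt * exprel(dt * linear) * derivative

x = jnp.float32(0.0)
for _ in range(5):
    x = exp_euler(x, jnp.float32(0.1))
print(x)  # relaxes toward the fixed point x* = 2.0
```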
brainstate/nn/_inputs.py
CHANGED
brainstate/random/_rand_state.py
CHANGED
@@ -384,7 +384,10 @@ class RandomState(State):
         loc = _check_py_seq(loc)
         scale = _check_py_seq(scale)
         if size is None:
-            size = lax.broadcast_shapes(
+            size = lax.broadcast_shapes(
+                jnp.shape(loc) if loc is not None else (),
+                jnp.shape(scale) if scale is not None else ()
+            )
         key = self.split_key() if key is None else _formalize_key(key)
         dtype = dtype or environ.dftype()
         r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size), dtype=dtype))
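A standalone illustration of the `None`-guarded shape broadcasting added here (the same guard recurs in the hunks below for the other samplers); `infer_size` is a hypothetical helper, not part of brainstate's API.

```python
import jax.numpy as jnp
from jax import lax

def infer_size(loc=None, scale=None):
    # Broadcast only the shapes of the parameters that were actually provided.
    return lax.broadcast_shapes(
        jnp.shape(loc) if loc is not None else (),
        jnp.shape(scale) if scale is not None else (),
    )

print(infer_size(loc=jnp.zeros((3, 1)), scale=jnp.ones(4)))  # (3, 4)
print(infer_size(scale=2.0))                                 # ()
print(infer_size())                                          # ()
```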
@@ -399,7 +402,10 @@ class RandomState(State):
         loc = _check_py_seq(loc)
         scale = _check_py_seq(scale)
         if size is None:
-            size = lax.broadcast_shapes(
+            size = lax.broadcast_shapes(
+                jnp.shape(scale) if scale is not None else (),
+                jnp.shape(loc) if loc is not None else ()
+            )
         key = self.split_key() if key is None else _formalize_key(key)
         dtype = dtype or environ.dftype()
         r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size), dtype=dtype))
@@ -456,7 +462,7 @@ class RandomState(State):
               dtype: DTypeLike = None):
         shape = _check_py_seq(shape)
         if size is None:
-            size = jnp.shape(shape)
+            size = jnp.shape(shape) if shape is not None else ()
         key = self.split_key() if key is None else _formalize_key(key)
         dtype = dtype or environ.dftype()
         r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
@@ -477,7 +483,7 @@ class RandomState(State):
               dtype: DTypeLike = None):
         df = _check_py_seq(df)
         if size is None:
-            size = jnp.shape(size)
+            size = jnp.shape(size) if size is not None else ()
         key = self.split_key() if key is None else _formalize_key(key)
         dtype = dtype or environ.dftype()
         r = jr.t(key, df=df, shape=_size2shape(size), dtype=dtype)
@@ -606,8 +612,8 @@ class RandomState(State):
 
         if size is None:
             size = jnp.broadcast_shapes(
-                jnp.shape(mean),
-                jnp.shape(sigma)
+                jnp.shape(mean) if mean is not None else (),
+                jnp.shape(sigma) if sigma is not None else ()
             )
         key = self.split_key() if key is None else _formalize_key(key)
         dtype = dtype or environ.dftype()
@@ -822,7 +828,7 @@ class RandomState(State):
         a = _check_py_seq(a)
         scale = _check_py_seq(scale)
         if size is None:
-            size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale))
+            size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale) if scale is not None else ())
         else:
             if jnp.size(a) > 1:
                 raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
{brainstate-0.1.4.dist-info → brainstate-0.1.5.dist-info}/RECORD
CHANGED
@@ -1,6 +1,6 @@
-brainstate/__init__.py,sha256=
+brainstate/__init__.py,sha256=24AzLahGzsgxErIIbDuFSzGFn-btHTh9dyczmpOkOvA,1496
 brainstate/_compatible_import.py,sha256=LUSZlA0APWozxM8Kf9pZrM2YbwY7X3jzVVHZInaBL7Y,4630
-brainstate/_state.py,sha256=
+brainstate/_state.py,sha256=OBQIu1eLH4szHw2SGduPxx24gKoU7jXgCL-TPNdI_iw,60735
 brainstate/_state_test.py,sha256=b6uvZdVRyC4n6-fYzmHNry1b-gJ6zE_kRSxGinqiHaw,1638
 brainstate/_utils.py,sha256=j-b239RHfC5BnvhGbSExQpdY21LrMEyWMSHBdNGThOI,1657
 brainstate/environ.py,sha256=VgtG0S_aR1g_1gplRWg_v2ZrcS-F6LZk35BPCBgsIIA,17660
@@ -15,8 +15,8 @@ brainstate/augment/_autograd.py,sha256=zMMLZOidQq2p96wzEOgR-MynV2JH1l1AnoJ28eVwD
 brainstate/augment/_autograd_test.py,sha256=UhNd41luca_Kj9a8byL3Dq64Ta55WU-je-LCiK_z0Vc,45060
 brainstate/augment/_eval_shape.py,sha256=mYRB5clSRHdXX0c4c5_WDIIBtRl4a9-xl8Tg60mgLBI,3854
 brainstate/augment/_eval_shape_test.py,sha256=WXrmZKmnykmYPveRfZbrDF0sFm2_PTr982yl4JM7Ebg,1390
-brainstate/augment/_mapping.py,sha256=
-brainstate/augment/_mapping_test.py,sha256=
+brainstate/augment/_mapping.py,sha256=M-xl7Q-wGrftI0fJ4jO-SzCEpB7XZ_UjQfkJs_2TbCQ,43557
+brainstate/augment/_mapping_test.py,sha256=tEyXioA4_v4WGRFrpgUInyKyBjAVNLonWdHKVacmphU,21946
 brainstate/augment/_random.py,sha256=bkngsIk6Wwi3e46I7YSbBjLCGAz0Q3WuadUH4mqTjbY,5348
 brainstate/compile/__init__.py,sha256=fQtG316MLkeeu1Ssp54Kghw1PwbGK5gNq9yRVJu0wjA,1474
 brainstate/compile/_ad_checkpoint.py,sha256=HM6L90HU0N84S30uJpX8wqTO0ZcDnctqmwf8qFlkDYo,9358
@@ -57,7 +57,7 @@ brainstate/init/_regular_inits_test.py,sha256=bvv0AOfLOEP0BIQIBLztKw3EPyEp7n2fHW
 brainstate/nn/__init__.py,sha256=nDIBDlzHaBy-vk4cb7FiPO6zoNe2tI5qvHIK8O7yrOU,3721
 brainstate/nn/_collective_ops.py,sha256=rWcjqaP0rW6sXdhE0fjtDi4twmq5VvhNEg8AoiQ-tDU,17390
 brainstate/nn/_collective_ops_test.py,sha256=bwq0DApcsk0_2xpxMl0_e2cGKT63g5rSngpigCm07ps,1409
-brainstate/nn/_common.py,sha256=
+brainstate/nn/_common.py,sha256=qHAOID_eeKiPUXk_ION65sbYXyF-ddH5w5BayvH8Thg,6431
 brainstate/nn/_conv.py,sha256=Zk-yj34n6CkjntcM9xpMGLTxKNfWdIWsTsoGbtdL0yU,18448
 brainstate/nn/_conv_test.py,sha256=2lcUTG7twkyhuyKwuBux-NgU8NU_W4Cp1-G8EyDJ_uk,8862
 brainstate/nn/_delay.py,sha256=l36FBgNhfL64tM3VGOsJNTtKr44HjxxtBWMFFCm3Pks,17361
@@ -68,11 +68,11 @@ brainstate/nn/_dynamics_test.py,sha256=w7AV57LdhbBNYprdFpKq8MFSCbXKVkGgp_NbL3ANX
 brainstate/nn/_elementwise.py,sha256=4czeJWGQopV49iZo8DuN_WzAbXoMC1gtqaGjlON6e7c,33291
 brainstate/nn/_elementwise_test.py,sha256=_dd9eX2ZJ7p24ahuoapCaRTZ0g1boufXMyqHFx1d4WY,5688
 brainstate/nn/_embedding.py,sha256=SaAJbgXmuJ8XlCOX9ob4yvmgh9Fk627wMguRzJMJ1H8,2138
-brainstate/nn/_exp_euler.py,sha256=
+brainstate/nn/_exp_euler.py,sha256=WTpZm-XQmsdMLNazY7wIu8eeO6pK0kRzt2lJnhEgMIk,3293
 brainstate/nn/_exp_euler_test.py,sha256=XD--qMbGHrHa3WtcPMmJKk59giDcEhSqZuBOmTNYUr8,1227
 brainstate/nn/_fixedprob.py,sha256=KGXohiU0wZnFIQDuwiRUTFsbsr8R0p8zgi5UZDuv1Bk,10004
 brainstate/nn/_fixedprob_test.py,sha256=qbRBh-MpMtEOsg492gFu2w9-FOP9z_bXapm-Q0gLLYM,3929
-brainstate/nn/_inputs.py,sha256=
+brainstate/nn/_inputs.py,sha256=hMPkx9qDBpJWPshZXLF4H1QiYK1-46wntHUIlG7cT7c,20603
 brainstate/nn/_linear.py,sha256=5WuhcqU-uBUC91vnwezQYMHPKmlZPDgIJ5UpffxoX1I,14472
 brainstate/nn/_linear_mv.py,sha256=6hDXx4yPqRSa7uIsW9f9eJuy23dcXN9Mp2_lSvw8BDA,2635
 brainstate/nn/_linear_mv_test.py,sha256=ZCM1Zy6mImQfCfdZOGnTwkiLLPXK5yalv1Ts9sWZuPA,3864
@@ -111,7 +111,7 @@ brainstate/random/_rand_funs.py,sha256=c4xiY2NeMizSslxbWOa-QJZ3h-LDfsgH4fbvBMSNL
 brainstate/random/_rand_funs_test.py,sha256=G8BuxDjBSeE-Mh7KuwVk6mPQhHE0m1R5HDFljEOIdzg,20669
 brainstate/random/_rand_seed.py,sha256=1ZdfFZWyOhpd72EDdEDmpkp3yoLVwdv-sGI9BwiZfzI,5949
 brainstate/random/_rand_seed_test.py,sha256=waXXfch57X1XE1zDnCRokT6ziZOK0g-lYE80o6epDYM,1536
-brainstate/random/_rand_state.py,sha256=
+brainstate/random/_rand_state.py,sha256=hQ_govRxfIgsmtFe_V2B-jftqiMWRWrd9wkviPFOziY,55523
 brainstate/random/_random_for_unit.py,sha256=kGp4EUX19MXJ9Govoivbg8N0bddqOldKEI2h_TbdONY,2057
 brainstate/util/__init__.py,sha256=6efwr63osmqviNU_6_Nufag19PwxRDFvDAZrq6sH5yo,1555
 brainstate/util/_pretty_pytree_test.py,sha256=Dn0TdjX6wLBXaTD4jfYTu6cKfFHwKSxi4_3bX7kB_IA,5621
@@ -124,8 +124,8 @@ brainstate/util/pretty_repr.py,sha256=7Xp7IFNUeP7cGlpvwwJyBslbQVnXEqC1I6neV1Jx1S
 brainstate/util/pretty_table.py,sha256=uJVaamFGQ4nKP8TkEGPWXHpzjMecDo2q1Ah6XtRjdPY,108117
 brainstate/util/scaling.py,sha256=U6DM-afPrLejiGqo1Nla7z4YbTBVicctsBEweurr_mk,7524
 brainstate/util/struct.py,sha256=7HbbQNrZ3zxYw93MU1bUZ9ZPBKftYOVKuXEochLSErw,17479
-brainstate-0.1.
-brainstate-0.1.
-brainstate-0.1.
-brainstate-0.1.
-brainstate-0.1.
+brainstate-0.1.5.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
+brainstate-0.1.5.dist-info/METADATA,sha256=0rd8rcVbLHdAkaLX9T2-BA2aRfq4D2ToDcmt0pp0wwg,4135
+brainstate-0.1.5.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
+brainstate-0.1.5.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
+brainstate-0.1.5.dist-info/RECORD,,
{brainstate-0.1.4.dist-info → brainstate-0.1.5.dist-info}/LICENSE
File without changes
{brainstate-0.1.4.dist-info → brainstate-0.1.5.dist-info}/WHEEL
File without changes
{brainstate-0.1.4.dist-info → brainstate-0.1.5.dist-info}/top_level.txt
File without changes