brainstate 0.1.0.post20250210__py2.py3-none-any.whl → 0.1.0.post20250212__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +54 -73
- brainstate/compile/_make_jaxpr.py +1 -1
- brainstate/graph/_graph_node.py +2 -2
- brainstate/nn/_interaction/_conv.py +12 -10
- brainstate/nn/_module.py +2 -1
- brainstate/surrogate.py +180 -35
- brainstate/util/_dict.py +163 -6
- brainstate/util/_pretty_repr.py +133 -10
- {brainstate-0.1.0.post20250210.dist-info → brainstate-0.1.0.post20250212.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20250210.dist-info → brainstate-0.1.0.post20250212.dist-info}/RECORD +13 -13
- {brainstate-0.1.0.post20250210.dist-info → brainstate-0.1.0.post20250212.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250210.dist-info → brainstate-0.1.0.post20250212.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250210.dist-info → brainstate-0.1.0.post20250212.dist-info}/top_level.txt +0 -0
brainstate/_state.py
CHANGED
@@ -19,8 +19,10 @@ import contextlib
|
|
19
19
|
import dataclasses
|
20
20
|
import threading
|
21
21
|
from functools import wraps, partial
|
22
|
-
from typing import (
|
23
|
-
|
22
|
+
from typing import (
|
23
|
+
Any, Union, Callable, Generic, Mapping,
|
24
|
+
TypeVar, Optional, TYPE_CHECKING, Tuple, Dict, List, Sequence
|
25
|
+
)
|
24
26
|
|
25
27
|
import jax
|
26
28
|
import numpy as np
|
@@ -28,7 +30,7 @@ from jax.api_util import shaped_abstractify
|
|
28
30
|
from jax.extend import source_info_util
|
29
31
|
|
30
32
|
from brainstate.typing import ArrayLike, PyTree, Missing
|
31
|
-
from brainstate.util import DictManager,
|
33
|
+
from brainstate.util import DictManager, PrettyObject
|
32
34
|
|
33
35
|
__all__ = [
|
34
36
|
'State', 'ShortTermState', 'LongTermState', 'HiddenState', 'ParamState', 'TreefyState',
|
@@ -184,7 +186,7 @@ def _get_trace_stack_level() -> int:
|
|
184
186
|
return len(TRACE_CONTEXT.state_stack)
|
185
187
|
|
186
188
|
|
187
|
-
class State(Generic[A],
|
189
|
+
class State(Generic[A], PrettyObject):
|
188
190
|
"""
|
189
191
|
The pointer to specify the dynamical data.
|
190
192
|
|
@@ -434,45 +436,25 @@ class State(Generic[A], PrettyRepr):
|
|
434
436
|
del metadata['_value']
|
435
437
|
return TreefyState(type(self), self._value, **metadata)
|
436
438
|
|
437
|
-
def
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
if name == '_value':
|
457
|
-
name = 'value'
|
458
|
-
if name == '_name':
|
459
|
-
if value is None:
|
460
|
-
continue
|
461
|
-
else:
|
462
|
-
name = 'name'
|
463
|
-
if name == 'tag' and value is None:
|
464
|
-
continue
|
465
|
-
if name in ['_level', '_source_info', '_been_writen']:
|
466
|
-
continue
|
467
|
-
children[name] = value
|
468
|
-
|
469
|
-
import treescope # type: ignore[import-not-found,import-untyped]
|
470
|
-
return treescope.repr_lib.render_object_constructor(
|
471
|
-
object_type=type(self),
|
472
|
-
attributes=children,
|
473
|
-
path=path,
|
474
|
-
subtree_renderer=subtree_renderer,
|
475
|
-
)
|
439
|
+
def __pretty_repr_item__(self, k, v):
|
440
|
+
if k in ['_level', '_source_info', '_been_writen']:
|
441
|
+
return None, None
|
442
|
+
if k == '_value':
|
443
|
+
return 'value', v
|
444
|
+
|
445
|
+
if k == '_name':
|
446
|
+
if self.name is None:
|
447
|
+
return None, None
|
448
|
+
else:
|
449
|
+
return 'name', v
|
450
|
+
|
451
|
+
if k == 'tag':
|
452
|
+
if self.tag is None:
|
453
|
+
return None, None
|
454
|
+
else:
|
455
|
+
return 'tag', v
|
456
|
+
|
457
|
+
return k, v
|
476
458
|
|
477
459
|
def __eq__(self, other: object) -> bool:
|
478
460
|
return type(self) is type(other) and vars(other) == vars(self)
|
@@ -483,6 +465,25 @@ class State(Generic[A], PrettyRepr):
|
|
483
465
|
"""
|
484
466
|
return hash(id(self))
|
485
467
|
|
468
|
+
def numel(self) -> int:
|
469
|
+
"""
|
470
|
+
Calculate the total number of elements in the state value.
|
471
|
+
|
472
|
+
This method traverses the state's value, which may be a nested structure (PyTree),
|
473
|
+
and computes the sum of sizes of all leaf nodes.
|
474
|
+
|
475
|
+
Returns:
|
476
|
+
int: The total number of elements across all arrays in the state value.
|
477
|
+
For scalar values, this will be 1. For arrays or nested structures,
|
478
|
+
it will be the sum of the sizes of all contained arrays.
|
479
|
+
|
480
|
+
Note:
|
481
|
+
This method uses jax.tree.leaves to flatten any nested structure in the state value,
|
482
|
+
and jax.numpy.size to compute the size of each leaf node.
|
483
|
+
"""
|
484
|
+
sizes = [jax.numpy.size(val) for val in jax.tree.leaves(self._value)]
|
485
|
+
return sum(sizes)
|
486
|
+
|
486
487
|
|
487
488
|
def record_state_init(st: State[A]):
|
488
489
|
trace: Catcher
|
@@ -809,7 +810,7 @@ class StateTraceStack(Generic[A]):
|
|
809
810
|
return StateTraceStack().merge(self, other)
|
810
811
|
|
811
812
|
|
812
|
-
class TreefyState(Generic[A],
|
813
|
+
class TreefyState(Generic[A], PrettyObject):
|
813
814
|
"""
|
814
815
|
The state as a pytree.
|
815
816
|
"""
|
@@ -831,35 +832,15 @@ class TreefyState(Generic[A], PrettyRepr):
|
|
831
832
|
|
832
833
|
def __delattr__(self, name: str) -> None: ...
|
833
834
|
|
834
|
-
def
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
else:
|
844
|
-
name = 'name'
|
845
|
-
if name in ['_level', '_source_info', 'type']:
|
846
|
-
continue
|
847
|
-
yield PrettyAttr(name, repr(value))
|
848
|
-
|
849
|
-
def __treescope_repr__(self, path, subtree_renderer):
|
850
|
-
children = {'type': self.type}
|
851
|
-
for name, value in vars(self).items():
|
852
|
-
if name == 'type':
|
853
|
-
continue
|
854
|
-
children[name] = value
|
855
|
-
|
856
|
-
import treescope # type: ignore[import-not-found,import-untyped]
|
857
|
-
return treescope.repr_lib.render_object_constructor(
|
858
|
-
object_type=type(self),
|
859
|
-
attributes=children,
|
860
|
-
path=path,
|
861
|
-
subtree_renderer=subtree_renderer,
|
862
|
-
)
|
835
|
+
def __pretty_repr_item__(self, k, v):
|
836
|
+
if k in ['_level', '_source_info', '_been_writen']:
|
837
|
+
return None, None
|
838
|
+
if k == '_value':
|
839
|
+
return 'value', v
|
840
|
+
|
841
|
+
if k == '_name':
|
842
|
+
return (None, None) if v is None else ('name', v)
|
843
|
+
return k, v
|
863
844
|
|
864
845
|
def replace(self, value: B) -> TreefyState[B]:
|
865
846
|
"""
|
@@ -206,7 +206,7 @@ class StatefulFunction(object):
|
|
206
206
|
self._cached_state_trace: Dict[Any, StateTraceStack] = dict()
|
207
207
|
|
208
208
|
def __repr__(self) -> str:
|
209
|
-
return (f"{self.__class__.__name__}(
|
209
|
+
return (f"{self.__class__.__name__}("
|
210
210
|
f"static_argnums={self.static_argnums}, "
|
211
211
|
f"axis_env={self.axis_env}, "
|
212
212
|
f"abstracted_axes={self.abstracted_axes}, "
|
brainstate/graph/_graph_node.py
CHANGED
@@ -27,7 +27,7 @@ import numpy as np
|
|
27
27
|
|
28
28
|
from brainstate._state import State, TreefyState
|
29
29
|
from brainstate.typing import Key
|
30
|
-
from brainstate.util._pretty_repr import PrettyRepr,
|
30
|
+
from brainstate.util._pretty_repr import PrettyRepr, yield_unique_pretty_repr_items, PrettyType, PrettyAttr
|
31
31
|
from ._graph_operation import register_graph_node_type
|
32
32
|
|
33
33
|
__all__ = [
|
@@ -88,7 +88,7 @@ class Node(PrettyRepr, metaclass=GraphNodeMeta):
|
|
88
88
|
"""
|
89
89
|
Pretty repr for the object.
|
90
90
|
"""
|
91
|
-
yield from
|
91
|
+
yield from yield_unique_pretty_repr_items(self, _default_repr_object, _default_repr_attr)
|
92
92
|
|
93
93
|
def __treescope_repr__(self, path, subtree_renderer):
|
94
94
|
"""
|
@@ -193,16 +193,18 @@ class _Conv(_BaseConv):
|
|
193
193
|
name: str = None,
|
194
194
|
param_type: type = ParamState,
|
195
195
|
):
|
196
|
-
super().__init__(
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
196
|
+
super().__init__(
|
197
|
+
in_size=in_size,
|
198
|
+
out_channels=out_channels,
|
199
|
+
kernel_size=kernel_size,
|
200
|
+
stride=stride,
|
201
|
+
padding=padding,
|
202
|
+
lhs_dilation=lhs_dilation,
|
203
|
+
rhs_dilation=rhs_dilation,
|
204
|
+
groups=groups,
|
205
|
+
w_mask=w_mask,
|
206
|
+
name=name
|
207
|
+
)
|
206
208
|
|
207
209
|
self.w_initializer = w_init
|
208
210
|
self.b_initializer = b_init
|
brainstate/nn/_module.py
CHANGED
@@ -294,7 +294,8 @@ class Sequential(Module):
|
|
294
294
|
# the input and output shape
|
295
295
|
if first.in_size is not None:
|
296
296
|
self.in_size = first.in_size
|
297
|
-
|
297
|
+
if in_size is not None:
|
298
|
+
self.out_size = tuple(in_size)
|
298
299
|
|
299
300
|
def update(self, x):
|
300
301
|
"""Update function of a sequential model.
|
brainstate/surrogate.py
CHANGED
@@ -21,6 +21,8 @@ import jax.numpy as jnp
|
|
21
21
|
import jax.scipy as sci
|
22
22
|
from jax.interpreters import batching, ad, mlir
|
23
23
|
|
24
|
+
from brainstate.util._dict import PrettyObject
|
25
|
+
|
24
26
|
if jax.__version_info__ < (0, 4, 38):
|
25
27
|
from jax.core import Primitive
|
26
28
|
else:
|
@@ -77,7 +79,7 @@ def _heaviside_imp(x, dx):
|
|
77
79
|
|
78
80
|
|
79
81
|
def _heaviside_batching(args, axes):
|
80
|
-
return heaviside_p.bind(*args),
|
82
|
+
return heaviside_p.bind(*args), tuple(axes)
|
81
83
|
|
82
84
|
|
83
85
|
def _heaviside_jvp(primals, tangents):
|
@@ -97,7 +99,7 @@ ad.primitive_jvps[heaviside_p] = _heaviside_jvp
|
|
97
99
|
mlir.register_lowering(heaviside_p, mlir.lower_fun(_heaviside_imp, multiple_results=True))
|
98
100
|
|
99
101
|
|
100
|
-
class Surrogate(
|
102
|
+
class Surrogate(PrettyObject):
|
101
103
|
"""The base surrograte gradient function.
|
102
104
|
|
103
105
|
To customize a surrogate gradient function, you can inherit this class and
|
@@ -142,9 +144,20 @@ class Surrogate(object):
|
|
142
144
|
class Sigmoid(Surrogate):
|
143
145
|
"""Spike function with the sigmoid-shaped surrogate gradient.
|
144
146
|
|
147
|
+
This class implements a spiking neuron activation with a sigmoid-shaped
|
148
|
+
surrogate gradient for backpropagation. It can be used in spiking neural
|
149
|
+
networks to approximate the non-differentiable step function during training.
|
150
|
+
|
151
|
+
Parameters
|
152
|
+
----------
|
153
|
+
alpha : float, optional
|
154
|
+
A parameter controlling the steepness of the sigmoid curve in the
|
155
|
+
surrogate gradient. Higher values make the transition sharper.
|
156
|
+
Default is 4.0.
|
157
|
+
|
145
158
|
See Also
|
146
159
|
--------
|
147
|
-
sigmoid
|
160
|
+
sigmoid : Function version of this class.
|
148
161
|
|
149
162
|
"""
|
150
163
|
|
@@ -153,9 +166,33 @@ class Sigmoid(Surrogate):
|
|
153
166
|
self.alpha = alpha
|
154
167
|
|
155
168
|
def surrogate_fun(self, x):
|
169
|
+
"""Compute the surrogate function.
|
170
|
+
|
171
|
+
Parameters
|
172
|
+
----------
|
173
|
+
x : jax.Array
|
174
|
+
The input array.
|
175
|
+
|
176
|
+
Returns
|
177
|
+
-------
|
178
|
+
jax.Array
|
179
|
+
The output of the surrogate function.
|
180
|
+
"""
|
156
181
|
return sci.special.expit(self.alpha * x)
|
157
182
|
|
158
183
|
def surrogate_grad(self, x):
|
184
|
+
"""Compute the gradient of the surrogate function.
|
185
|
+
|
186
|
+
Parameters
|
187
|
+
----------
|
188
|
+
x : jax.Array
|
189
|
+
The input array.
|
190
|
+
|
191
|
+
Returns
|
192
|
+
-------
|
193
|
+
jax.Array
|
194
|
+
The gradient of the surrogate function.
|
195
|
+
"""
|
159
196
|
sgax = sci.special.expit(x * self.alpha)
|
160
197
|
dx = (1. - sgax) * sgax * self.alpha
|
161
198
|
return dx
|
@@ -171,7 +208,12 @@ def sigmoid(
|
|
171
208
|
x: jax.Array,
|
172
209
|
alpha: float = 4.,
|
173
210
|
):
|
174
|
-
r"""
|
211
|
+
r"""
|
212
|
+
Compute a spike function with a sigmoid-shaped surrogate gradient.
|
213
|
+
|
214
|
+
This function implements a spiking neuron activation with a sigmoid-shaped
|
215
|
+
surrogate gradient for backpropagation. It can be used in spiking neural
|
216
|
+
networks to approximate the non-differentiable step function during training.
|
175
217
|
|
176
218
|
If `origin=False`, return the forward function:
|
177
219
|
|
@@ -210,16 +252,28 @@ def sigmoid(
|
|
210
252
|
|
211
253
|
Parameters
|
212
254
|
----------
|
213
|
-
x: jax.Array
|
214
|
-
|
215
|
-
alpha: float
|
216
|
-
|
217
|
-
|
255
|
+
x : jax.Array
|
256
|
+
The input array representing the neuron's membrane potential.
|
257
|
+
alpha : float, optional
|
258
|
+
A parameter controlling the steepness of the sigmoid curve in the
|
259
|
+
surrogate gradient. Higher values make the transition sharper.
|
260
|
+
Default is 4.0.
|
218
261
|
|
219
262
|
Returns
|
220
263
|
-------
|
221
|
-
|
222
|
-
|
264
|
+
jax.Array
|
265
|
+
An array of the same shape as the input, containing binary values (0 or 1)
|
266
|
+
representing the spiking state of each neuron.
|
267
|
+
|
268
|
+
Notes
|
269
|
+
-----
|
270
|
+
The forward pass uses a step function (1 for x >= 0, 0 for x < 0),
|
271
|
+
while the backward pass uses a sigmoid-shaped surrogate gradient for
|
272
|
+
smooth optimization.
|
273
|
+
|
274
|
+
The surrogate gradient is defined as:
|
275
|
+
g'(x) = alpha * (1 - sigmoid(alpha * x)) * sigmoid(alpha * x)
|
276
|
+
|
223
277
|
"""
|
224
278
|
return Sigmoid(alpha=alpha)(x)
|
225
279
|
|
@@ -238,11 +292,15 @@ class PiecewiseQuadratic(Surrogate):
|
|
238
292
|
self.alpha = alpha
|
239
293
|
|
240
294
|
def surrogate_fun(self, x):
|
241
|
-
z = jnp.where(
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
295
|
+
z = jnp.where(
|
296
|
+
x < -1 / self.alpha,
|
297
|
+
0.,
|
298
|
+
jnp.where(
|
299
|
+
x > 1 / self.alpha,
|
300
|
+
1.,
|
301
|
+
(-self.alpha * jnp.abs(x) / 2 + 1) * self.alpha * x + 0.5
|
302
|
+
)
|
303
|
+
)
|
246
304
|
return z
|
247
305
|
|
248
306
|
def surrogate_grad(self, x):
|
@@ -260,7 +318,12 @@ def piecewise_quadratic(
|
|
260
318
|
x: jax.Array,
|
261
319
|
alpha: float = 1.,
|
262
320
|
):
|
263
|
-
r"""
|
321
|
+
r"""
|
322
|
+
Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_.
|
323
|
+
|
324
|
+
This function implements a surrogate gradient method for spiking neural networks
|
325
|
+
using a piecewise quadratic function. It provides a differentiable approximation
|
326
|
+
of the step function used in the forward pass of spiking neurons.
|
264
327
|
|
265
328
|
If `origin=False`, computes the forward function:
|
266
329
|
|
@@ -306,18 +369,29 @@ def piecewise_quadratic(
|
|
306
369
|
>>> plt.legend()
|
307
370
|
>>> plt.show()
|
308
371
|
|
309
|
-
|
372
|
+
Parameters
|
310
373
|
----------
|
311
|
-
x: jax.Array
|
312
|
-
|
313
|
-
alpha: float
|
314
|
-
|
315
|
-
|
374
|
+
x : jax.Array
|
375
|
+
The input array representing the neuron's membrane potential.
|
376
|
+
alpha : float, optional
|
377
|
+
A parameter controlling the steepness of the surrogate gradient.
|
378
|
+
Higher values result in a steeper gradient. Default is 1.0.
|
316
379
|
|
317
380
|
Returns
|
318
381
|
-------
|
319
|
-
|
320
|
-
|
382
|
+
jax.Array
|
383
|
+
An array of the same shape as the input, containing binary values (0 or 1)
|
384
|
+
representing the spiking state of each neuron.
|
385
|
+
|
386
|
+
Notes
|
387
|
+
-----
|
388
|
+
The function uses different computations for forward and backward passes:
|
389
|
+
- Forward: Step function (1 for x >= 0, 0 for x < 0)
|
390
|
+
- Backward: Piecewise quadratic function for smooth gradient
|
391
|
+
|
392
|
+
The surrogate gradient is defined as:
|
393
|
+
g'(x) = 0 if |x| > 1/alpha
|
394
|
+
-alpha^2|x| + alpha if |x| <= 1/alpha
|
321
395
|
|
322
396
|
References
|
323
397
|
----------
|
@@ -331,11 +405,22 @@ def piecewise_quadratic(
|
|
331
405
|
|
332
406
|
|
333
407
|
class PiecewiseExp(Surrogate):
|
334
|
-
"""
|
408
|
+
"""
|
409
|
+
Judge spiking state with a piecewise exponential function.
|
410
|
+
|
411
|
+
This class implements a surrogate gradient method for spiking neural networks
|
412
|
+
using a piecewise exponential function. It provides a differentiable approximation
|
413
|
+
of the step function used in the forward pass of spiking neurons.
|
414
|
+
|
415
|
+
Parameters
|
416
|
+
----------
|
417
|
+
alpha : float, optional
|
418
|
+
A parameter controlling the steepness of the surrogate gradient.
|
419
|
+
Higher values result in a steeper gradient. Default is 1.0.
|
335
420
|
|
336
421
|
See Also
|
337
422
|
--------
|
338
|
-
piecewise_exp
|
423
|
+
piecewise_exp : Function version of this class.
|
339
424
|
"""
|
340
425
|
|
341
426
|
def __init__(self, alpha: float = 1.):
|
@@ -343,16 +428,62 @@ class PiecewiseExp(Surrogate):
|
|
343
428
|
self.alpha = alpha
|
344
429
|
|
345
430
|
def surrogate_grad(self, x):
|
431
|
+
"""
|
432
|
+
Compute the surrogate gradient.
|
433
|
+
|
434
|
+
Parameters
|
435
|
+
----------
|
436
|
+
x : jax.Array
|
437
|
+
The input array.
|
438
|
+
|
439
|
+
Returns
|
440
|
+
-------
|
441
|
+
jax.Array
|
442
|
+
The surrogate gradient.
|
443
|
+
"""
|
346
444
|
dx = (self.alpha / 2) * jnp.exp(-self.alpha * jnp.abs(x))
|
347
445
|
return dx
|
348
446
|
|
349
447
|
def surrogate_fun(self, x):
|
350
|
-
|
448
|
+
"""
|
449
|
+
Compute the surrogate function.
|
450
|
+
|
451
|
+
Parameters
|
452
|
+
----------
|
453
|
+
x : jax.Array
|
454
|
+
The input array.
|
455
|
+
|
456
|
+
Returns
|
457
|
+
-------
|
458
|
+
jax.Array
|
459
|
+
The output of the surrogate function.
|
460
|
+
"""
|
461
|
+
return jnp.where(
|
462
|
+
x < 0,
|
463
|
+
jnp.exp(self.alpha * x) / 2,
|
464
|
+
1 - jnp.exp(-self.alpha * x) / 2
|
465
|
+
)
|
351
466
|
|
352
467
|
def __repr__(self):
|
468
|
+
"""
|
469
|
+
Return a string representation of the PiecewiseExp instance.
|
470
|
+
|
471
|
+
Returns
|
472
|
+
-------
|
473
|
+
str
|
474
|
+
A string representation of the instance.
|
475
|
+
"""
|
353
476
|
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
354
477
|
|
355
478
|
def __hash__(self):
|
479
|
+
"""
|
480
|
+
Compute a hash value for the PiecewiseExp instance.
|
481
|
+
|
482
|
+
Returns
|
483
|
+
-------
|
484
|
+
int
|
485
|
+
A hash value for the instance.
|
486
|
+
"""
|
356
487
|
return hash((self.__class__, self.alpha))
|
357
488
|
|
358
489
|
|
@@ -363,6 +494,10 @@ def piecewise_exp(
|
|
363
494
|
):
|
364
495
|
r"""Judge spiking state with a piecewise exponential function [1]_.
|
365
496
|
|
497
|
+
This function implements a surrogate gradient method for spiking neural networks
|
498
|
+
using a piecewise exponential function. It provides a differentiable approximation
|
499
|
+
of the step function used in the forward pass of spiking neurons.
|
500
|
+
|
366
501
|
If `origin=False`, computes the forward function:
|
367
502
|
|
368
503
|
.. math::
|
@@ -403,16 +538,26 @@ def piecewise_exp(
|
|
403
538
|
|
404
539
|
Parameters
|
405
540
|
----------
|
406
|
-
x: jax.Array
|
407
|
-
|
408
|
-
alpha: float
|
409
|
-
|
410
|
-
|
541
|
+
x : jax.Array
|
542
|
+
The input array representing the neuron's membrane potential.
|
543
|
+
alpha : float, optional
|
544
|
+
A parameter controlling the steepness of the surrogate gradient.
|
545
|
+
Higher values result in a steeper gradient. Default is 1.0.
|
411
546
|
|
412
547
|
Returns
|
413
548
|
-------
|
414
|
-
|
415
|
-
|
549
|
+
jax.Array
|
550
|
+
An array of the same shape as the input, containing binary values (0 or 1)
|
551
|
+
representing the spiking state of each neuron.
|
552
|
+
|
553
|
+
Notes
|
554
|
+
-----
|
555
|
+
The function uses different computations for forward and backward passes:
|
556
|
+
- Forward: Step function (1 for x >= 0, 0 for x < 0)
|
557
|
+
- Backward: Piecewise exponential function for smooth gradient
|
558
|
+
|
559
|
+
The surrogate gradient is defined as:
|
560
|
+
g'(x) = (alpha / 2) * exp(-alpha * |x|)
|
416
561
|
|
417
562
|
References
|
418
563
|
----------
|
brainstate/util/_dict.py
CHANGED
@@ -24,11 +24,17 @@ import jax
|
|
24
24
|
|
25
25
|
from brainstate.typing import Filter, PathParts
|
26
26
|
from ._filter import to_predicate
|
27
|
-
from ._pretty_repr import PrettyRepr, PrettyType, PrettyAttr,
|
27
|
+
from ._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, yield_unique_pretty_repr_items, pretty_repr_object
|
28
28
|
from ._struct import dataclass
|
29
29
|
|
30
30
|
__all__ = [
|
31
|
-
'
|
31
|
+
'PrettyDict',
|
32
|
+
'NestedDict',
|
33
|
+
'FlattedDict',
|
34
|
+
'flat_mapping',
|
35
|
+
'nest_mapping',
|
36
|
+
'PrettyList',
|
37
|
+
'PrettyObject',
|
32
38
|
]
|
33
39
|
|
34
40
|
A = TypeVar('A')
|
@@ -40,6 +46,119 @@ ExtractValueFn = abc.Callable[[Any], Any]
|
|
40
46
|
SetValueFn = abc.Callable[[V, Any], V]
|
41
47
|
|
42
48
|
|
49
|
+
def _repr_object_general(node: PrettyDict):
|
50
|
+
"""
|
51
|
+
Generate a general representation of a PrettyDict object.
|
52
|
+
|
53
|
+
This function is used to create a pretty representation of a PrettyDict
|
54
|
+
object, which includes the type of the object and its value separator.
|
55
|
+
|
56
|
+
Args:
|
57
|
+
node (PrettyDict): The PrettyDict object to be represented.
|
58
|
+
|
59
|
+
Yields:
|
60
|
+
PrettyType: A PrettyType object representing the type of the node,
|
61
|
+
with specified value separator, start, and end characters.
|
62
|
+
"""
|
63
|
+
yield PrettyType(type(node), value_sep='=', start='(', end=')')
|
64
|
+
|
65
|
+
|
66
|
+
def _repr_attribute_general(node):
|
67
|
+
"""
|
68
|
+
Generate a pretty representation of the attributes of a node.
|
69
|
+
|
70
|
+
This function iterates over the attributes of a given node and attempts
|
71
|
+
to generate a pretty representation for each attribute. It handles
|
72
|
+
conversion of lists and dictionaries to their pretty representation
|
73
|
+
counterparts and yields a PrettyAttr object for each attribute.
|
74
|
+
|
75
|
+
Args:
|
76
|
+
node: The object whose attributes are to be represented.
|
77
|
+
|
78
|
+
Yields:
|
79
|
+
PrettyAttr: A PrettyAttr object representing the key and value of
|
80
|
+
each attribute in a pretty format.
|
81
|
+
"""
|
82
|
+
for k, v in vars(node).items():
|
83
|
+
try:
|
84
|
+
res = node.__pretty_repr_item__(k, v)
|
85
|
+
if res is None:
|
86
|
+
continue
|
87
|
+
k, v = res
|
88
|
+
except AttributeError:
|
89
|
+
pass
|
90
|
+
|
91
|
+
if k is None:
|
92
|
+
continue
|
93
|
+
|
94
|
+
# convert list to PrettyList
|
95
|
+
if isinstance(v, list):
|
96
|
+
v = PrettyList(v)
|
97
|
+
|
98
|
+
# convert dict to PrettyDict
|
99
|
+
if isinstance(v, dict):
|
100
|
+
v = PrettyDict(v)
|
101
|
+
|
102
|
+
# convert PrettyDict to NestedStateRepr
|
103
|
+
if isinstance(v, PrettyDict):
|
104
|
+
v = NestedStateRepr(v)
|
105
|
+
|
106
|
+
yield PrettyAttr(k, v)
|
107
|
+
|
108
|
+
|
109
|
+
class PrettyObject(PrettyRepr):
|
110
|
+
"""
|
111
|
+
A class for generating a pretty representation of a tree-like structure.
|
112
|
+
|
113
|
+
This class extends the PrettyRepr class to provide a mechanism for
|
114
|
+
generating a human-readable, pretty representation of tree-like data
|
115
|
+
structures. It utilizes custom functions to represent the object and
|
116
|
+
its attributes in a structured and visually appealing format.
|
117
|
+
|
118
|
+
Methods
|
119
|
+
-------
|
120
|
+
__pretty_repr__: Generates a sequence of pretty representation items
|
121
|
+
for the object.
|
122
|
+
__pretty_repr_item__: Returns a tuple of the key and value for pretty
|
123
|
+
representation of an item in the data structure.
|
124
|
+
"""
|
125
|
+
|
126
|
+
def __pretty_repr__(self):
|
127
|
+
"""
|
128
|
+
Generates a pretty representation of the object.
|
129
|
+
|
130
|
+
This method yields a sequence of pretty representation items for the object,
|
131
|
+
using specified functions to represent the object and its attributes.
|
132
|
+
|
133
|
+
Yields:
|
134
|
+
Pretty representation items generated by `yield_unique_pretty_repr_items`.
|
135
|
+
"""
|
136
|
+
yield from yield_unique_pretty_repr_items(
|
137
|
+
self,
|
138
|
+
repr_object=_repr_object_general,
|
139
|
+
repr_attr=_repr_attribute_general,
|
140
|
+
)
|
141
|
+
|
142
|
+
def __pretty_repr_item__(self, k, v):
|
143
|
+
"""
|
144
|
+
Returns a tuple of the key and value for pretty representation.
|
145
|
+
|
146
|
+
This method is used to generate a pretty representation of an item
|
147
|
+
in a data structure, typically for debugging or logging purposes.
|
148
|
+
|
149
|
+
Args:
|
150
|
+
k: The key of the item.
|
151
|
+
v: The value of the item.
|
152
|
+
|
153
|
+
Returns:
|
154
|
+
A tuple containing the key and value.
|
155
|
+
"""
|
156
|
+
return k, v
|
157
|
+
|
158
|
+
|
159
|
+
PrettyReprTree = PrettyObject
|
160
|
+
|
161
|
+
|
43
162
|
# the empty node is a struct.dataclass to be compatible with JAX.
|
44
163
|
@dataclass
|
45
164
|
class _EmptyNode:
|
@@ -213,10 +332,10 @@ class PrettyDict(dict, PrettyRepr):
|
|
213
332
|
|
214
333
|
def __repr__(self) -> str:
|
215
334
|
# repr the individual object with the pretty representation
|
216
|
-
return
|
335
|
+
return pretty_repr_object(self)
|
217
336
|
|
218
337
|
def __pretty_repr__(self):
|
219
|
-
yield from
|
338
|
+
yield from yield_unique_pretty_repr_items(self, _default_repr_object, _default_repr_attr)
|
220
339
|
|
221
340
|
def split(self, *filters) -> Union[PrettyDict[K, V], Tuple[PrettyDict[K, V], ...]]:
|
222
341
|
raise NotImplementedError
|
@@ -237,15 +356,20 @@ class PrettyDict(dict, PrettyRepr):
|
|
237
356
|
|
238
357
|
|
239
358
|
def _default_repr_object(node: PrettyDict):
|
240
|
-
yield PrettyType(
|
359
|
+
yield PrettyType('', value_sep=': ', start='{', end='}')
|
241
360
|
|
242
361
|
|
243
|
-
def _default_repr_attr(node
|
362
|
+
def _default_repr_attr(node):
|
244
363
|
for k, v in node.items():
|
364
|
+
if isinstance(v, list):
|
365
|
+
v = PrettyList(v)
|
366
|
+
|
245
367
|
if isinstance(v, dict):
|
246
368
|
v = PrettyDict(v)
|
369
|
+
|
247
370
|
if isinstance(v, PrettyDict):
|
248
371
|
v = NestedStateRepr(v)
|
372
|
+
|
249
373
|
yield PrettyAttr(repr(k), v)
|
250
374
|
|
251
375
|
|
@@ -735,3 +859,36 @@ def _flat_unflatten(
|
|
735
859
|
jax.tree_util.register_pytree_with_keys(FlattedDict,
|
736
860
|
_nest_flatten_with_keys,
|
737
861
|
_flat_unflatten) # type: ignore[arg-type]
|
862
|
+
|
863
|
+
|
864
|
+
@jax.tree_util.register_pytree_node_class
|
865
|
+
class PrettyList(list, PrettyRepr):
|
866
|
+
__module__ = 'brainstate.util'
|
867
|
+
|
868
|
+
def __pretty_repr__(self):
|
869
|
+
yield from yield_unique_pretty_repr_items(self, _list_repr_object, _list_repr_attr)
|
870
|
+
|
871
|
+
def __repr__(self):
|
872
|
+
return pretty_repr_object(self)
|
873
|
+
|
874
|
+
def tree_flatten(self):
|
875
|
+
return list(self), ()
|
876
|
+
|
877
|
+
@classmethod
|
878
|
+
def tree_unflatten(cls, aux_data, children):
|
879
|
+
return cls(children)
|
880
|
+
|
881
|
+
|
882
|
+
def _list_repr_attr(node: PrettyList):
|
883
|
+
for v in node:
|
884
|
+
if isinstance(v, list):
|
885
|
+
v = PrettyList(v)
|
886
|
+
if isinstance(v, dict):
|
887
|
+
v = PrettyDict(v)
|
888
|
+
if isinstance(v, PrettyDict):
|
889
|
+
v = NestedStateRepr(v)
|
890
|
+
yield PrettyAttr('', v)
|
891
|
+
|
892
|
+
|
893
|
+
def _list_repr_object(node: PrettyDict):
|
894
|
+
yield PrettyType('', value_sep='', start='[', end=']')
|
brainstate/util/_pretty_repr.py
CHANGED
@@ -21,9 +21,10 @@ import dataclasses
|
|
21
21
|
import threading
|
22
22
|
from abc import ABC, abstractmethod
|
23
23
|
from functools import partial
|
24
|
-
from typing import Any, Iterator, Mapping, TypeVar, Union, Callable, Optional
|
24
|
+
from typing import Any, Iterator, Mapping, TypeVar, Union, Callable, Optional, Sequence
|
25
25
|
|
26
26
|
__all__ = [
|
27
|
+
'yield_unique_pretty_repr_items',
|
27
28
|
'PrettyType',
|
28
29
|
'PrettyAttr',
|
29
30
|
'PrettyRepr',
|
@@ -80,10 +81,37 @@ class PrettyRepr(ABC):
|
|
80
81
|
|
81
82
|
def __repr__(self) -> str:
|
82
83
|
# repr the individual object with the pretty representation
|
83
|
-
return
|
84
|
+
return pretty_repr_object(self)
|
84
85
|
|
85
86
|
|
86
|
-
def
|
87
|
+
def pretty_repr_elem(obj: PrettyType, elem: Any) -> str:
|
88
|
+
"""
|
89
|
+
Constructs a string representation of a single element within a pretty representation.
|
90
|
+
|
91
|
+
This function takes a `PrettyType` object and an element, which must be an instance
|
92
|
+
of `PrettyAttr`, and generates a formatted string that represents the element. The
|
93
|
+
formatting is based on the configuration provided by the `PrettyType` object.
|
94
|
+
|
95
|
+
Parameters
|
96
|
+
----------
|
97
|
+
obj : PrettyType
|
98
|
+
The configuration object that defines how the element should be formatted.
|
99
|
+
It includes details such as indentation, separators, and surrounding characters.
|
100
|
+
elem : Any
|
101
|
+
The element to be represented. It must be an instance of `PrettyAttr`, which
|
102
|
+
contains the key and value to be formatted.
|
103
|
+
|
104
|
+
Returns
|
105
|
+
-------
|
106
|
+
str
|
107
|
+
A string that represents the element in a formatted manner, adhering to the
|
108
|
+
configuration specified by the `PrettyType` object.
|
109
|
+
|
110
|
+
Raises
|
111
|
+
------
|
112
|
+
TypeError
|
113
|
+
If the provided element is not an instance of `PrettyAttr`.
|
114
|
+
"""
|
87
115
|
if not isinstance(elem, PrettyAttr):
|
88
116
|
raise TypeError(f'Item must be Elem, got {type(elem).__name__}')
|
89
117
|
|
@@ -93,9 +121,32 @@ def _repr_elem(obj: PrettyType, elem: Any) -> str:
|
|
93
121
|
return f'{obj.elem_indent}{elem.start}{elem.key}{obj.value_sep}{value}{elem.end}'
|
94
122
|
|
95
123
|
|
96
|
-
def
|
124
|
+
def pretty_repr_object(obj: PrettyRepr) -> str:
|
97
125
|
"""
|
98
|
-
|
126
|
+
Generates a pretty string representation of an object that implements the PrettyRepr interface.
|
127
|
+
|
128
|
+
This function utilizes the __pretty_repr__ method of the PrettyRepr interface to obtain
|
129
|
+
a structured representation of the object, which includes both the type and attributes
|
130
|
+
of the object in a human-readable format.
|
131
|
+
|
132
|
+
Parameters
|
133
|
+
----------
|
134
|
+
obj : PrettyRepr
|
135
|
+
The object for which the pretty representation is to be generated. The object must
|
136
|
+
implement the PrettyRepr interface.
|
137
|
+
|
138
|
+
Returns
|
139
|
+
-------
|
140
|
+
str
|
141
|
+
A string that represents the object in a pretty format, including its type and attributes.
|
142
|
+
The format is determined by the PrettyType and PrettyAttr instances yielded by the
|
143
|
+
__pretty_repr__ method of the object.
|
144
|
+
|
145
|
+
Raises
|
146
|
+
------
|
147
|
+
TypeError
|
148
|
+
If the provided object does not implement the PrettyRepr interface or if the first item
|
149
|
+
yielded by the __pretty_repr__ method is not an instance of PrettyType.
|
99
150
|
"""
|
100
151
|
if not isinstance(obj, PrettyRepr):
|
101
152
|
raise TypeError(f'Object {obj!r} is not representable')
|
@@ -108,7 +159,7 @@ def get_repr(obj: PrettyRepr) -> str:
|
|
108
159
|
raise TypeError(f'First item must be PrettyType, got {type(obj_repr).__name__}')
|
109
160
|
|
110
161
|
# repr attributes
|
111
|
-
elem_reprs = tuple(map(partial(
|
162
|
+
elem_reprs = tuple(map(partial(pretty_repr_elem, obj_repr), iterator))
|
112
163
|
elems = ',\n'.join(elem_reprs)
|
113
164
|
if elems:
|
114
165
|
elems = '\n' + elems + '\n'
|
@@ -140,9 +191,10 @@ class PrettyMapping(PrettyRepr):
|
|
140
191
|
Pretty representation of a mapping.
|
141
192
|
"""
|
142
193
|
mapping: Mapping
|
194
|
+
type_name: str = ''
|
143
195
|
|
144
196
|
def __pretty_repr__(self):
|
145
|
-
yield PrettyType(type=
|
197
|
+
yield PrettyType(type=self.type_name, value_sep=': ', start='{', end='}')
|
146
198
|
|
147
199
|
for key, value in self.mapping.items():
|
148
200
|
yield PrettyAttr(repr(key), value)
|
@@ -150,7 +202,20 @@ class PrettyMapping(PrettyRepr):
|
|
150
202
|
|
151
203
|
@dataclasses.dataclass
|
152
204
|
class PrettyReprContext(threading.local):
|
153
|
-
|
205
|
+
"""
|
206
|
+
A thread-local context for managing the state of pretty representation.
|
207
|
+
|
208
|
+
This class is used to keep track of objects that have been seen during
|
209
|
+
the generation of pretty representations, preventing infinite recursion
|
210
|
+
in cases of circular references.
|
211
|
+
|
212
|
+
Attributes
|
213
|
+
----------
|
214
|
+
seen_modules_repr : dict[int, Any] | None
|
215
|
+
A dictionary mapping object IDs to objects that have been seen
|
216
|
+
during the pretty representation process. This is used to avoid
|
217
|
+
representing the same object multiple times.
|
218
|
+
"""
|
154
219
|
seen_modules_repr: dict[int, Any] | None = None
|
155
220
|
|
156
221
|
|
@@ -158,23 +223,80 @@ CONTEXT = PrettyReprContext()
|
|
158
223
|
|
159
224
|
|
160
225
|
def _default_repr_object(node):
|
226
|
+
"""
|
227
|
+
Generates a default pretty representation for an object.
|
228
|
+
|
229
|
+
This function yields a `PrettyType` instance that represents the type
|
230
|
+
of the given object. It is used as a default method for representing
|
231
|
+
objects when no custom representation function is provided.
|
232
|
+
|
233
|
+
Parameters
|
234
|
+
----------
|
235
|
+
node : Any
|
236
|
+
The object for which the pretty representation is to be generated.
|
237
|
+
|
238
|
+
Yields
|
239
|
+
------
|
240
|
+
PrettyType
|
241
|
+
An instance of `PrettyType` that contains the type information of
|
242
|
+
the object.
|
243
|
+
"""
|
161
244
|
yield PrettyType(type=type(node))
|
162
245
|
|
163
246
|
|
164
247
|
def _default_repr_attr(node):
|
248
|
+
"""
|
249
|
+
Generates a default pretty representation for the attributes of an object.
|
250
|
+
|
251
|
+
This function iterates over the attributes of the given object and yields
|
252
|
+
a `PrettyAttr` instance for each attribute that does not start with an
|
253
|
+
underscore. The `PrettyAttr` instances contain the attribute name and its
|
254
|
+
string representation.
|
255
|
+
|
256
|
+
Parameters
|
257
|
+
----------
|
258
|
+
node : Any
|
259
|
+
The object whose attributes are to be represented.
|
260
|
+
|
261
|
+
Yields
|
262
|
+
------
|
263
|
+
PrettyAttr
|
264
|
+
An instance of `PrettyAttr` for each non-private attribute of the object,
|
265
|
+
containing the attribute name and its string representation.
|
266
|
+
"""
|
165
267
|
for name, value in vars(node).items():
|
166
268
|
if name.startswith('_'):
|
167
269
|
continue
|
168
270
|
yield PrettyAttr(name, repr(value))
|
169
271
|
|
170
272
|
|
171
|
-
def
|
273
|
+
def yield_unique_pretty_repr_items(
|
172
274
|
node,
|
173
275
|
repr_object: Optional[Callable] = None,
|
174
276
|
repr_attr: Optional[Callable] = None
|
175
277
|
):
|
176
278
|
"""
|
177
|
-
|
279
|
+
Generates a pretty representation of an object while avoiding duplicate representations.
|
280
|
+
|
281
|
+
This function is designed to yield a structured representation of an object,
|
282
|
+
using custom or default methods for representing the object itself and its attributes.
|
283
|
+
It ensures that each object is only represented once to prevent infinite recursion
|
284
|
+
in cases of circular references.
|
285
|
+
|
286
|
+
Parameters:
|
287
|
+
node : Any
|
288
|
+
The object to be represented.
|
289
|
+
repr_object : Optional[Callable], optional
|
290
|
+
A callable that yields the representation of the object itself.
|
291
|
+
If not provided, a default representation function is used.
|
292
|
+
repr_attr : Optional[Callable], optional
|
293
|
+
A callable that yields the representation of the object's attributes.
|
294
|
+
If not provided, a default attribute representation function is used.
|
295
|
+
|
296
|
+
Yields:
|
297
|
+
Union[PrettyType, PrettyAttr]
|
298
|
+
The pretty representation of the object and its attributes,
|
299
|
+
avoiding duplicates by tracking seen objects.
|
178
300
|
"""
|
179
301
|
if repr_object is None:
|
180
302
|
repr_object = _default_repr_object
|
@@ -206,3 +328,4 @@ def pretty_repr_avoid_duplicate(
|
|
206
328
|
finally:
|
207
329
|
if clear_seen:
|
208
330
|
CONTEXT.seen_modules_repr = None
|
331
|
+
|
{brainstate-0.1.0.post20250210.dist-info → brainstate-0.1.0.post20250212.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.
|
3
|
+
Version: 0.1.0.post20250212
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -1,12 +1,12 @@
|
|
1
1
|
brainstate/__init__.py,sha256=AkZyyFkn4fB8g2aT6Rc2MO1xICPpUZuDtdze-eUQNc0,1496
|
2
|
-
brainstate/_state.py,sha256=
|
2
|
+
brainstate/_state.py,sha256=aM2UTfFGvfXfM-pCLvufhgyuuLBGfogBYsz7ZCU8P7Q,28588
|
3
3
|
brainstate/_state_test.py,sha256=rJUFRSXEqrrl4qANRewY9mnDlzSbtHwBIGeZ0ku-8Dg,1650
|
4
4
|
brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
|
5
5
|
brainstate/environ.py,sha256=PZnVFWPioUBuWmwCO8wwCKrHQfP3BR-5lYPRl5i5GDA,17698
|
6
6
|
brainstate/environ_test.py,sha256=QD6sPCKNtqemVCGwkdImjMazatrvvLr6YeAVcfUnVVY,2045
|
7
7
|
brainstate/mixin.py,sha256=g7uVUwZphZWsNs9pb48ozG2cDGaj0hs0g3lq8tDk-Sg,11310
|
8
8
|
brainstate/mixin_test.py,sha256=Oq_0fwC9vpXDN4t4dTBhWzLdFDNlcYsrcip14F1yECI,3079
|
9
|
-
brainstate/surrogate.py,sha256=
|
9
|
+
brainstate/surrogate.py,sha256=xS4UG4LHKUJdHqwZ5-p-9Y2jWXMa-ssdZJCMiW9zi5k,53540
|
10
10
|
brainstate/transform.py,sha256=cxbymTlJ6uHvJWEEYXzFUkAySs_TbUTHakt0NQgWJ3s,808
|
11
11
|
brainstate/typing.py,sha256=Qh-LBzm6oG4rSXv4V5qB8SNYcoOR7bASoK_iQxnlafk,10467
|
12
12
|
brainstate/augment/__init__.py,sha256=zGPq1eTB_56GRCNC9TiPLKTw07PA2O0OCi7bgjYIrY4,1193
|
@@ -30,7 +30,7 @@ brainstate/compile/_loop_collect_return.py,sha256=TrKBZhtQecTtuiVz_HOeyepde-znzj
|
|
30
30
|
brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
|
31
31
|
brainstate/compile/_loop_no_collection.py,sha256=qto2__Zt2PJntkjB9AXEgraGLvNUJS483BhCXjJyqv0,7495
|
32
32
|
brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
|
33
|
-
brainstate/compile/_make_jaxpr.py,sha256=
|
33
|
+
brainstate/compile/_make_jaxpr.py,sha256=Rr36U0s8qow1A4KJYXkALX10Rm2pkSYF2j_1eiSuSGI,33292
|
34
34
|
brainstate/compile/_make_jaxpr_test.py,sha256=3gwdiutn_PJyiweu3oPEXumxEVHKaE2xDGvkwZy2GEo,4367
|
35
35
|
brainstate/compile/_progress_bar.py,sha256=5pCMCEmbTO5XmKtzRUJGA178tuBznWKuh9Kw00wAL1I,7524
|
36
36
|
brainstate/compile/_unvmap.py,sha256=CJA6D9lUcBfvdLrpFVvC2AdTJqe9uY0Ht6PltQJyr4U,4228
|
@@ -42,7 +42,7 @@ brainstate/functional/_normalization.py,sha256=i2EV7hSsqcNdcYRX2wAxjq8doHwyN9eNJ
|
|
42
42
|
brainstate/functional/_others.py,sha256=_u_Ys-LiLzDAP4zJggVwaVvirgoS3jvhXMREoS6JOkM,1737
|
43
43
|
brainstate/functional/_spikes.py,sha256=QY-2ayJkgkGELcq-bftPEaf_hJptVf_SP3fY36QvlZc,2678
|
44
44
|
brainstate/graph/__init__.py,sha256=noo4TjBg6iEhjjwk0sAGUhR7Ge-z8Vnc2rLYUvnqttw,1295
|
45
|
-
brainstate/graph/_graph_node.py,sha256=
|
45
|
+
brainstate/graph/_graph_node.py,sha256=XwzOuaZG9x4eZknQjzJoTnnYAy7wcKD5Vox1VkYr8GM,8345
|
46
46
|
brainstate/graph/_graph_node_test.py,sha256=BFGfdzZFDHI0XK7hHotSVWKt3em1taGvn8FHF9NCXx8,2702
|
47
47
|
brainstate/graph/_graph_operation.py,sha256=UtBNP7hvxa-5i99LQJStXbFhUbl3icdfTq1oF4MeH1g,64106
|
48
48
|
brainstate/graph/_graph_operation_test.py,sha256=zjvpKjQAFWtw8YZuqOk_jmlZNb_-E8oPyNx57dyc8jI,18556
|
@@ -57,7 +57,7 @@ brainstate/nn/__init__.py,sha256=rxURT8J1XfBn3Vh3Dx_WzVADWn9zVriIty5KZEG-x6o,162
|
|
57
57
|
brainstate/nn/_collective_ops.py,sha256=sSjIIs1MvZA30XFFmK7iL1D_sCeh7hFd3PanCH6kgZo,6779
|
58
58
|
brainstate/nn/_exp_euler.py,sha256=yjkfSllFxGWKEAlHo5AzBizzkFj6FEVDKmFV6E2g214,3521
|
59
59
|
brainstate/nn/_exp_euler_test.py,sha256=clwRD8QR71k1jn6NrACMDEUcFMh0J9RTosoPnlYWUkw,1242
|
60
|
-
brainstate/nn/_module.py,sha256=
|
60
|
+
brainstate/nn/_module.py,sha256=UzsnaTDh5F6Z8B7ou4RXmTdAWXbNkjf03bYP0kF3_fE,10872
|
61
61
|
brainstate/nn/_module_test.py,sha256=V4ZhiY_zYPvArkB2eeOTtZcgQrtlRyXKMbS1AJH4vC8,8893
|
62
62
|
brainstate/nn/metrics.py,sha256=iupHjSRTHYY-HmEPBC4tXWrZfF4zh1ek2NwSAA0gnwE,14738
|
63
63
|
brainstate/nn/_dyn_impl/__init__.py,sha256=Oazar7h89dp1WA2Vx4Tj7gCBhxJKH4LAUEABkBEG7vU,1462
|
@@ -84,7 +84,7 @@ brainstate/nn/_elementwise/_dropout_test.py,sha256=k6aB5v8RYMoV5w8UV9UNSFhaQTV7w
|
|
84
84
|
brainstate/nn/_elementwise/_elementwise.py,sha256=om-KpwDTk5yFG5KBYXXHquRLV7s28_FJjk-omvyMyvQ,33342
|
85
85
|
brainstate/nn/_elementwise/_elementwise_test.py,sha256=SZI9jB39sZ5SO1dpWGW-PhodthwN0GU9FY1nqf2fWcs,5341
|
86
86
|
brainstate/nn/_interaction/__init__.py,sha256=TTY_SeNrdx4VnUSw6vdyl02OHdS9Qs15cWBp6kjsyNQ,1289
|
87
|
-
brainstate/nn/_interaction/_conv.py,sha256=
|
87
|
+
brainstate/nn/_interaction/_conv.py,sha256=eKhABWtG3QlOy7TPY9yoQjP3liBh9bb8X5Wns3_YUUQ,18499
|
88
88
|
brainstate/nn/_interaction/_conv_test.py,sha256=fHXRFYnDghFiKre63RqMwIE_gbPKdK34UPhKOz-J3qU,8695
|
89
89
|
brainstate/nn/_interaction/_embedding.py,sha256=iK0I1ExKWFa_QzV9UDGj32Ljsmdr1g_LlAtMcusebxU,2187
|
90
90
|
brainstate/nn/_interaction/_linear.py,sha256=EnkOk1oE79rvRIjU6HBllxUpVOEcQQCj4vtavo9AJjI,14767
|
@@ -109,16 +109,16 @@ brainstate/random/_rand_state.py,sha256=nuoQ8GU1MfJPRNN-ZmRQsggVjoyPhaEdZmwM7_4-
|
|
109
109
|
brainstate/random/_random_for_unit.py,sha256=kGp4EUX19MXJ9Govoivbg8N0bddqOldKEI2h_TbdONY,2057
|
110
110
|
brainstate/util/__init__.py,sha256=-FWEuSKXG3mWxYphGFAy3UEuVe39lFs1GruluzdXDoI,1502
|
111
111
|
brainstate/util/_caller.py,sha256=T3bzu7-09r-6EOrU6Muca_aMXSQua_X2lXjEqb-w39w,2782
|
112
|
-
brainstate/util/_dict.py,sha256=
|
112
|
+
brainstate/util/_dict.py,sha256=qPUbqjRVHUvVHhSWBPojx_srsh6-iy1k5oPMn1DdrnQ,30880
|
113
113
|
brainstate/util/_dict_test.py,sha256=Dn0TdjX6wLBXaTD4jfYTu6cKfFHwKSxi4_3bX7kB_IA,5621
|
114
114
|
brainstate/util/_error.py,sha256=eyZ8PGFixqe2K5OEfjSDzI-2tU0ieYQoUpBP7yStlPQ,878
|
115
115
|
brainstate/util/_filter.py,sha256=1-bvFHdjeehvXeHTrCEp8xr25lopKe8d3XZGCNegq0s,4970
|
116
116
|
brainstate/util/_others.py,sha256=jsPZwP-v_5HRV-LB5F0NUsiqr04y8bmGIsu_JMyVcbQ,14762
|
117
|
-
brainstate/util/_pretty_repr.py,sha256
|
117
|
+
brainstate/util/_pretty_repr.py,sha256=-TZPIgfTLB-Eg7rgT7KAkV1r-HX0q6nCgKDKA7Qdsw4,10577
|
118
118
|
brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
|
119
119
|
brainstate/util/_struct.py,sha256=KMMHcshOM20gYhSahNzWLxsTt-Rt3AeX3Uz26-rP9vI,17619
|
120
|
-
brainstate-0.1.0.
|
121
|
-
brainstate-0.1.0.
|
122
|
-
brainstate-0.1.0.
|
123
|
-
brainstate-0.1.0.
|
124
|
-
brainstate-0.1.0.
|
120
|
+
brainstate-0.1.0.post20250212.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
121
|
+
brainstate-0.1.0.post20250212.dist-info/METADATA,sha256=OVPO4wr0e0j_Lvk_OQKTpTdNUbOGsFt_BW_qKakO8xE,3585
|
122
|
+
brainstate-0.1.0.post20250212.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
123
|
+
brainstate-0.1.0.post20250212.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
124
|
+
brainstate-0.1.0.post20250212.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{brainstate-0.1.0.post20250210.dist-info → brainstate-0.1.0.post20250212.dist-info}/top_level.txt
RENAMED
File without changes
|