brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/graph/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -14,16 +14,9 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
|
17
|
-
from .
|
18
|
-
from .
|
19
|
-
|
20
|
-
treefy_split, treefy_merge, iter_leaf, iter_node, clone, graphdef,
|
21
|
-
call, RefMap, GraphDef, NodeRef, NodeDef
|
22
|
-
)
|
17
|
+
from ._node import Node
|
18
|
+
from ._operation import *
|
19
|
+
from ._operation import __all__ as operation_all
|
23
20
|
|
24
|
-
__all__ = [
|
25
|
-
|
26
|
-
'pop_states', 'nodes', 'states', 'treefy_states', 'update_states', 'flatten', 'unflatten',
|
27
|
-
'treefy_split', 'treefy_merge', 'iter_leaf', 'iter_node', 'clone', 'graphdef',
|
28
|
-
'call', 'RefMap', 'GraphDef', 'NodeRef', 'NodeDef',
|
29
|
-
]
|
21
|
+
__all__ = ['Node'] + operation_all
|
22
|
+
del operation_all
|
@@ -0,0 +1,240 @@
|
|
1
|
+
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
|
+
# The credit should go to the Flax authors.
|
3
|
+
#
|
4
|
+
# Copyright 2024 The Flax Authors.
|
5
|
+
#
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
|
+
# you may not use this file except in compliance with the License.
|
8
|
+
# You may obtain a copy of the License at
|
9
|
+
#
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11
|
+
#
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15
|
+
# See the License for the specific language governing permissions and
|
16
|
+
# limitations under the License.
|
17
|
+
|
18
|
+
from abc import ABCMeta
|
19
|
+
from copy import deepcopy
|
20
|
+
from typing import Any, Type, TypeVar, Tuple, TYPE_CHECKING
|
21
|
+
|
22
|
+
from brainstate._state import State, TreefyState
|
23
|
+
from brainstate.typing import Key
|
24
|
+
from brainstate.util._pretty_pytree import PrettyObject
|
25
|
+
from ._operation import register_graph_node_type, treefy_split, treefy_merge
|
26
|
+
|
27
|
+
# Names exported by ``from brainstate.graph._node import *``.
__all__ = [
    'Node',
]

# Type variable bound to ``Node``; used to annotate methods that return the
# same node type they were called on (see ``Node.__deepcopy__``).
G = TypeVar('G', bound='Node')
# Unconstrained generic type variable.
A = TypeVar('A')
|
33
|
+
|
34
|
+
|
35
|
+
class GraphNodeMeta(ABCMeta):
    """
    Metaclass controlling how graph node classes are instantiated.

    At runtime the instance is built in two explicit steps: ``__new__``
    allocates the object and ``__init__`` is then invoked on it with the
    same arguments. The override is hidden from static type checkers
    (``TYPE_CHECKING``) so that they keep inferring constructor signatures
    from ``__init__``.
    """
    if not TYPE_CHECKING:
        def __call__(cls, *args, **kwargs) -> Any:
            # Allocate first, then initialise; ``__init__`` is always called
            # on whatever object ``__new__`` returned.
            instance = cls.__new__(cls, *args, **kwargs)
            instance.__init__(*args, **kwargs)
            return instance
|
41
|
+
|
42
|
+
|
43
|
+
class Node(PrettyObject, metaclass=GraphNodeMeta):
    """
    Base class for all graph nodes in the BrainState framework.

    This class serves as the foundation for creating computational graph nodes
    that can be used in neural network architectures and other graph-based
    computations.

    Attributes
    ----------
    graph_invisible_attrs : tuple
        Tuple of attribute names that should be excluded from graph
        serialization and flattening operations.

    Methods
    -------
    __deepcopy__(memo=None)
        Creates a deep copy of the node preserving its graph structure
        and state.

    Notes
    -----
    The class provides the following features:

    - Automatic registration of every subclass with the graph system via
      ``__init_subclass__``
    - Deep copy support for creating independent node instances
    - Pretty printing for better debugging and visualization
    - State management integration with TreefyState
    - Attribute visibility control via graph_invisible_attrs

    Examples
    --------
    .. code-block:: python

        >>> from copy import deepcopy
        >>> class MyNode(Node):
        ...     def __init__(self, value):
        ...         self.value = value
        >>> node = MyNode(10)
        >>> copied_node = deepcopy(node)
        >>> print(node.value)
        10
    """
    __module__ = 'brainstate.graph'

    graph_invisible_attrs = ()

    def __init_subclass__(cls) -> None:
        """Register each newly defined subclass as a graph node type."""
        super().__init_subclass__()

        register_graph_node_type(
            type=cls,
            flatten=_node_flatten,
            set_key=_node_set_key,
            pop_key=_node_pop_key,
            create_empty=_node_create_empty,
            clear=_node_clear,
        )

    def __deepcopy__(self: G, memo=None) -> G:
        """
        Return a deep copy of this node.

        The node is split into its graph definition and state, both parts
        are deep-copied, and a new node is rebuilt from the copies.

        Parameters
        ----------
        memo : dict, optional
            Memo dictionary supplied by :func:`copy.deepcopy`. It is
            forwarded to the nested copies so that objects already copied
            in the same ``deepcopy`` pass are reused instead of duplicated.

        Returns
        -------
        G
            A new node of the same type as ``self``.
        """
        graphdef, state = treefy_split(self)
        # Forward ``memo`` as required by the copy protocol; with
        # ``memo=None`` ``deepcopy`` simply starts a fresh memo.
        graphdef = deepcopy(graphdef, memo)
        state = deepcopy(state, memo)
        return treefy_merge(graphdef, state)
|
107
|
+
|
108
|
+
|
109
|
+
# -------------------------------
|
110
|
+
# Graph Definition
|
111
|
+
# -------------------------------
|
112
|
+
|
113
|
+
|
114
|
+
def _node_flatten(node: Node) -> Tuple[Tuple[Tuple[str, Any], ...], Tuple[Type]]:
|
115
|
+
"""
|
116
|
+
Flatten a node into its constituent parts for serialization.
|
117
|
+
|
118
|
+
Parameters
|
119
|
+
----------
|
120
|
+
node : Node
|
121
|
+
The Node instance to flatten.
|
122
|
+
|
123
|
+
Returns
|
124
|
+
-------
|
125
|
+
tuple
|
126
|
+
A tuple containing:
|
127
|
+
- Sorted list of (key, value) pairs for visible attributes
|
128
|
+
- Tuple containing the node's type
|
129
|
+
"""
|
130
|
+
graph_invisible_attrs = getattr(node, 'graph_invisible_attrs', ())
|
131
|
+
# graph_invisible_attrs = tuple(graph_invisible_attrs) + ('_trace_state',)
|
132
|
+
nodes = sorted(
|
133
|
+
(key, value) for key, value in vars(node).items()
|
134
|
+
if (key not in graph_invisible_attrs)
|
135
|
+
)
|
136
|
+
return nodes, (type(node),)
|
137
|
+
|
138
|
+
|
139
|
+
def _node_set_key(node: Node, key: Key, value: Any) -> None:
|
140
|
+
"""
|
141
|
+
Set an attribute on a node with special handling for State objects.
|
142
|
+
|
143
|
+
Parameters
|
144
|
+
----------
|
145
|
+
node : Node
|
146
|
+
The Node instance to modify.
|
147
|
+
key : Key
|
148
|
+
The attribute name to set.
|
149
|
+
value : Any
|
150
|
+
The value to set.
|
151
|
+
|
152
|
+
Raises
|
153
|
+
------
|
154
|
+
KeyError
|
155
|
+
If the key is not a string.
|
156
|
+
|
157
|
+
Notes
|
158
|
+
-----
|
159
|
+
If the attribute already exists as a State object and the new value
|
160
|
+
is a TreefyState, the state is updated via reference rather than
|
161
|
+
replaced.
|
162
|
+
"""
|
163
|
+
if not isinstance(key, str):
|
164
|
+
raise KeyError(f'Invalid key: {key!r}')
|
165
|
+
elif (
|
166
|
+
hasattr(node, key)
|
167
|
+
and isinstance(state := getattr(node, key), State)
|
168
|
+
and isinstance(value, TreefyState)
|
169
|
+
):
|
170
|
+
state.update_from_ref(value)
|
171
|
+
else:
|
172
|
+
setattr(node, key, value)
|
173
|
+
|
174
|
+
|
175
|
+
def _node_pop_key(node: Node, key: Key) -> Any:
|
176
|
+
"""
|
177
|
+
Remove and return an attribute from a node.
|
178
|
+
|
179
|
+
Parameters
|
180
|
+
----------
|
181
|
+
node : Node
|
182
|
+
The Node instance to modify.
|
183
|
+
key : Key
|
184
|
+
The attribute name to remove.
|
185
|
+
|
186
|
+
Returns
|
187
|
+
-------
|
188
|
+
Any
|
189
|
+
The value of the removed attribute.
|
190
|
+
|
191
|
+
Raises
|
192
|
+
------
|
193
|
+
KeyError
|
194
|
+
If the key is not a string.
|
195
|
+
"""
|
196
|
+
if not isinstance(key, str):
|
197
|
+
raise KeyError(f'Invalid key: {key!r}')
|
198
|
+
return vars(node).pop(key)
|
199
|
+
|
200
|
+
|
201
|
+
def _node_create_empty(static: tuple[Type[G], ...]) -> G:
|
202
|
+
"""
|
203
|
+
Create an empty node instance without calling __init__.
|
204
|
+
|
205
|
+
Parameters
|
206
|
+
----------
|
207
|
+
static : tuple[Type[G], ...]
|
208
|
+
Tuple containing the node type to instantiate.
|
209
|
+
|
210
|
+
Returns
|
211
|
+
-------
|
212
|
+
G
|
213
|
+
A new uninitialized instance of the node type.
|
214
|
+
|
215
|
+
Notes
|
216
|
+
-----
|
217
|
+
This function is used internally by the graph system to create
|
218
|
+
nodes without invoking their initialization logic.
|
219
|
+
"""
|
220
|
+
node_type, = static
|
221
|
+
node = object.__new__(node_type)
|
222
|
+
return node
|
223
|
+
|
224
|
+
|
225
|
+
def _node_clear(node: Node) -> None:
|
226
|
+
"""
|
227
|
+
Clear all attributes from a node.
|
228
|
+
|
229
|
+
Parameters
|
230
|
+
----------
|
231
|
+
node : Node
|
232
|
+
The Node instance to clear.
|
233
|
+
|
234
|
+
Notes
|
235
|
+
-----
|
236
|
+
This removes all attributes from the node's instance dictionary,
|
237
|
+
effectively resetting it to an empty state.
|
238
|
+
"""
|
239
|
+
module_vars = vars(node)
|
240
|
+
module_vars.clear()
|