brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -16,7 +16,7 @@
|
|
16
16
|
|
17
17
|
__all__ = [
|
18
18
|
'BrainStateError',
|
19
|
-
'
|
19
|
+
'BatchAxisError',
|
20
20
|
]
|
21
21
|
|
22
22
|
|
@@ -27,29 +27,19 @@ class BrainStateError(Exception):
|
|
27
27
|
This exception is raised when a BrainState-specific error occurs during
|
28
28
|
the execution of the program. It serves as a base class for more specific
|
29
29
|
BrainState exceptions.
|
30
|
-
|
31
|
-
Attributes:
|
32
|
-
Inherits all attributes from the built-in Exception class.
|
33
|
-
|
34
|
-
Usage::
|
35
|
-
|
36
|
-
raise BrainStateError("A BrainState-specific error occurred.")
|
37
30
|
"""
|
38
31
|
pass
|
39
32
|
|
40
33
|
|
41
|
-
class
|
34
|
+
class BatchAxisError(BrainStateError):
|
42
35
|
"""
|
43
|
-
|
36
|
+
Exception raised for errors related to batch axis operations.
|
44
37
|
|
45
|
-
This exception is
|
46
|
-
|
38
|
+
This custom exception is used to indicate errors that occur during
|
39
|
+
batch processing or vectorization operations, particularly in the
|
40
|
+
context of state management in the BrainState framework.
|
47
41
|
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
Usage::
|
52
|
-
|
53
|
-
raise TraceContextError("An error occurred while handling trace context.")
|
42
|
+
Inherits from:
|
43
|
+
BrainStateError: The base error class for BrainState-related exceptions.
|
54
44
|
"""
|
55
|
-
|
45
|
+
__module__ = 'brainstate.transform'
|
brainstate/_state.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -112,9 +112,12 @@ def check_state_value_tree(val: bool = True) -> Generator[None, None, None]:
|
|
112
112
|
If you want to check the tree structure of the value once the new value is assigned,
|
113
113
|
you can use this context manager.
|
114
114
|
|
115
|
-
|
115
|
+
Examples
|
116
|
+
--------
|
117
|
+
|
118
|
+
.. code-block:: python
|
116
119
|
|
117
|
-
>>> import brainstate
|
120
|
+
>>> import brainstate
|
118
121
|
>>> import jax.numpy as jnp
|
119
122
|
>>> state = brainstate.ShortTermState(jnp.zeros((2, 3)))
|
120
123
|
>>> with brainstate.check_state_value_tree():
|
@@ -158,10 +161,13 @@ def check_state_jax_tracer(val: bool = True) -> Generator[None, None, None]:
|
|
158
161
|
"""
|
159
162
|
The context manager to check whether the state is valid to trace.
|
160
163
|
|
161
|
-
Example
|
164
|
+
Example
|
165
|
+
-------
|
166
|
+
|
167
|
+
.. code-block:: python
|
162
168
|
|
163
169
|
>>> import jax
|
164
|
-
>>> import brainstate
|
170
|
+
>>> import brainstate
|
165
171
|
>>> import jax.numpy as jnp
|
166
172
|
>>>
|
167
173
|
>>> a = brainstate.ShortTermState(jnp.zeros((2, 3)))
|
@@ -213,7 +219,11 @@ class State(Generic[A], PrettyObject):
|
|
213
219
|
name (Optional[str]): An optional name for the state.
|
214
220
|
**metadata: Additional metadata to be stored with the state.
|
215
221
|
|
216
|
-
Example
|
222
|
+
Example
|
223
|
+
-------
|
224
|
+
|
225
|
+
.. code-block:: python
|
226
|
+
|
217
227
|
>>> class MyState(State):
|
218
228
|
... pass
|
219
229
|
>>> state = MyState(jnp.zeros((3, 3)), name="my_matrix")
|
@@ -910,24 +920,6 @@ class StateTraceStack(Generic[A]):
|
|
910
920
|
The class is generic over type A, allowing for type-safe usage with
|
911
921
|
different types of State objects.
|
912
922
|
|
913
|
-
Attributes:
|
914
|
-
states (List[State]): A list of all State objects encountered during tracing.
|
915
|
-
been_writen (List[bool]): A parallel list to states, indicating whether each state has been written to.
|
916
|
-
_state_id_index (dict): A dictionary mapping state ids to their index in the states list.
|
917
|
-
_original_state_values (List): A list of the original values of all states when first encountered.
|
918
|
-
_jax_trace_new_arg (Callable): A function used to transform state values during tracing.
|
919
|
-
|
920
|
-
Methods:
|
921
|
-
__enter__: Enters a new tracing context.
|
922
|
-
__exit__: Exits the current tracing context.
|
923
|
-
read_its_value: Records a read operation on a state.
|
924
|
-
write_its_value: Records a write operation on a state.
|
925
|
-
get_state_values: Retrieves the current values of all traced states.
|
926
|
-
recovery_original_values: Restores all states to their original values.
|
927
|
-
merge: Merges multiple ``StateTraceStack`` instances.
|
928
|
-
get_read_states: Retrieves states that were read during tracing.
|
929
|
-
get_read_state_values: Retrieves values of states that were read during tracing.
|
930
|
-
|
931
923
|
The ``StateTraceStack`` is a crucial component in implementing state-based
|
932
924
|
computations and is particularly useful in scenarios involving automatic
|
933
925
|
differentiation or other forms of program transformation.
|
@@ -992,7 +984,7 @@ class StateTraceStack(Generic[A]):
|
|
992
984
|
"""
|
993
985
|
if self._jax_trace_new_arg is not None:
|
994
986
|
# internal use
|
995
|
-
state._value =
|
987
|
+
state._value = self._jax_trace_new_arg(state)
|
996
988
|
|
997
989
|
def __enter__(self) -> 'StateTraceStack':
|
998
990
|
TRACE_CONTEXT.state_stack.append(self)
|
@@ -1266,6 +1258,28 @@ class StateTraceStack(Generic[A]):
|
|
1266
1258
|
"""
|
1267
1259
|
return StateTraceStack().merge(self, other)
|
1268
1260
|
|
1261
|
+
def state_subset(self, state_type: type) -> List:
|
1262
|
+
"""
|
1263
|
+
Get a subset of states of a specific type from the ``StateTraceStack``.
|
1264
|
+
|
1265
|
+
This method filters the states in the ``StateTraceStack`` and returns only
|
1266
|
+
those that match the specified state type.
|
1267
|
+
|
1268
|
+
Args:
|
1269
|
+
state_type (type): The type of state to filter by. This should be
|
1270
|
+
a subclass of State or State itself.
|
1271
|
+
|
1272
|
+
Returns:
|
1273
|
+
List[State]: A list containing all states in the ``StateTraceStack``
|
1274
|
+
that are instances of the specified state_type.
|
1275
|
+
|
1276
|
+
Example:
|
1277
|
+
>>> stack = StateTraceStack()
|
1278
|
+
>>> # Assume stack has been populated with various state types
|
1279
|
+
>>> short_term_states = stack.state_subset(ShortTermState)
|
1280
|
+
"""
|
1281
|
+
return [st for st in self.states if isinstance(st, state_type)]
|
1282
|
+
|
1269
1283
|
def assign_state_vals(self, state_vals: Sequence[PyTree]) -> None:
|
1270
1284
|
"""
|
1271
1285
|
Assign new values to the states tracked by this ``StateTraceStack``.
|
@@ -1292,35 +1306,68 @@ class StateTraceStack(Generic[A]):
|
|
1292
1306
|
``StateTraceStack``'s states list.
|
1293
1307
|
"""
|
1294
1308
|
if len(state_vals) != len(self.states):
|
1295
|
-
raise ValueError(
|
1296
|
-
|
1309
|
+
raise ValueError(
|
1310
|
+
'The length of the state values must be equal to the states. '
|
1311
|
+
f'Bug got {len(state_vals)} and {len(self.states)}'
|
1312
|
+
)
|
1297
1313
|
for st, written, val in zip(self.states, self.been_writen, state_vals):
|
1298
1314
|
if written:
|
1299
1315
|
st.value = val
|
1300
1316
|
else:
|
1301
1317
|
st.restore_value(val)
|
1302
1318
|
|
1303
|
-
def
|
1304
|
-
|
1305
|
-
|
1306
|
-
|
1307
|
-
|
1308
|
-
those that match the specified state type.
|
1309
|
-
|
1310
|
-
Args:
|
1311
|
-
state_type (type): The type of state to filter by. This should be
|
1312
|
-
a subclass of State or State itself.
|
1313
|
-
|
1314
|
-
Returns:
|
1315
|
-
List[State]: A list containing all states in the ``StateTraceStack``
|
1316
|
-
that are instances of the specified state_type.
|
1317
|
-
|
1318
|
-
Example:
|
1319
|
-
>>> stack = StateTraceStack()
|
1320
|
-
>>> # Assume stack has been populated with various state types
|
1321
|
-
>>> short_term_states = stack.state_subset(ShortTermState)
|
1319
|
+
def assign_state_vals_v2(
|
1320
|
+
self: StateTraceStack,
|
1321
|
+
read_state_vals: Sequence[PyTree],
|
1322
|
+
write_state_vals: Sequence[PyTree],
|
1323
|
+
):
|
1322
1324
|
"""
|
1323
|
-
|
1325
|
+
Write back state values to their corresponding states after computation.
|
1326
|
+
|
1327
|
+
This function updates the state values based on whether they were written to
|
1328
|
+
during the computation. If a state was written to, it gets the new written value.
|
1329
|
+
If not, it restores its original read value.
|
1330
|
+
|
1331
|
+
Parameters
|
1332
|
+
----------
|
1333
|
+
read_state_vals : sequence of PyTree
|
1334
|
+
The original state values that were read at the beginning.
|
1335
|
+
write_state_vals : sequence of PyTree
|
1336
|
+
The new state values that were written during computation.
|
1337
|
+
|
1338
|
+
Examples
|
1339
|
+
--------
|
1340
|
+
Basic usage in a compilation context:
|
1341
|
+
|
1342
|
+
.. code-block:: python
|
1343
|
+
|
1344
|
+
>>> import brainstate
|
1345
|
+
>>> import jax.numpy as jnp
|
1346
|
+
>>>
|
1347
|
+
>>> # Create states
|
1348
|
+
>>> state1 = brainstate.State(jnp.array([1.0, 2.0]))
|
1349
|
+
>>> state2 = brainstate.State(jnp.array([3.0, 4.0]))
|
1350
|
+
>>>
|
1351
|
+
>>> def f(x):
|
1352
|
+
... state1.value += x # This state will be written
|
1353
|
+
... return state1.value + state2.value # state2 is only read
|
1354
|
+
>>>
|
1355
|
+
>>> # During compilation, state values are collected and managed
|
1356
|
+
>>> # write_back_state_values ensures proper state management
|
1357
|
+
"""
|
1358
|
+
if len(self.states) != len(self.been_writen):
|
1359
|
+
raise ValueError('The length of the state values must be equal to the states. ')
|
1360
|
+
if len(read_state_vals) != len(self.states):
|
1361
|
+
raise ValueError('The length of the read state values must be equal to the states. ')
|
1362
|
+
if len(write_state_vals) != len(self.states):
|
1363
|
+
raise ValueError('The length of the write state values must be equal to the states. ')
|
1364
|
+
for st, write, val_r, val_w in zip(
|
1365
|
+
self.states, self.been_writen, read_state_vals, write_state_vals
|
1366
|
+
):
|
1367
|
+
if write:
|
1368
|
+
st.value = val_w
|
1369
|
+
else:
|
1370
|
+
st.restore_value(val_r)
|
1324
1371
|
|
1325
1372
|
|
1326
1373
|
class TreefyState(Generic[A], PrettyObject):
|
brainstate/_state_test.py
CHANGED
brainstate/_utils.py
CHANGED