brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/{util/error.py → _error.py} CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
16
16
 
17
17
  __all__ = [
18
18
  'BrainStateError',
19
- 'TraceContextError',
19
+ 'BatchAxisError',
20
20
  ]
21
21
 
22
22
 
@@ -27,29 +27,19 @@ class BrainStateError(Exception):
27
27
  This exception is raised when a BrainState-specific error occurs during
28
28
  the execution of the program. It serves as a base class for more specific
29
29
  BrainState exceptions.
30
-
31
- Attributes:
32
- Inherits all attributes from the built-in Exception class.
33
-
34
- Usage::
35
-
36
- raise BrainStateError("A BrainState-specific error occurred.")
37
30
  """
38
31
  pass
39
32
 
40
33
 
41
- class TraceContextError(BrainStateError):
34
+ class BatchAxisError(BrainStateError):
42
35
  """
43
- A custom exception class for trace context-related errors in BrainState.
36
+ Exception raised for errors related to batch axis operations.
44
37
 
45
- This exception is raised when an error occurs specifically related to
46
- trace context operations or manipulations within the BrainState framework.
38
+ This custom exception is used to indicate errors that occur during
39
+ batch processing or vectorization operations, particularly in the
40
+ context of state management in the BrainState framework.
47
41
 
48
- Attributes:
49
- Inherits all attributes from the BrainStateError class.
50
-
51
- Usage::
52
-
53
- raise TraceContextError("An error occurred while handling trace context.")
42
+ Inherits from:
43
+ BrainStateError: The base error class for BrainState-related exceptions.
54
44
  """
55
- pass
45
+ __module__ = 'brainstate.transform'
brainstate/_state.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -112,9 +112,12 @@ def check_state_value_tree(val: bool = True) -> Generator[None, None, None]:
112
112
  If you want to check the tree structure of the value once the new value is assigned,
113
113
  you can use this context manager.
114
114
 
115
- Example::
115
+ Examples
116
+ --------
117
+
118
+ .. code-block:: python
116
119
 
117
- >>> import brainstate as brainstate
120
+ >>> import brainstate
118
121
  >>> import jax.numpy as jnp
119
122
  >>> state = brainstate.ShortTermState(jnp.zeros((2, 3)))
120
123
  >>> with brainstate.check_state_value_tree():
@@ -158,10 +161,13 @@ def check_state_jax_tracer(val: bool = True) -> Generator[None, None, None]:
158
161
  """
159
162
  The context manager to check whether the state is valid to trace.
160
163
 
161
- Example::
164
+ Example
165
+ -------
166
+
167
+ .. code-block:: python
162
168
 
163
169
  >>> import jax
164
- >>> import brainstate as brainstate
170
+ >>> import brainstate
165
171
  >>> import jax.numpy as jnp
166
172
  >>>
167
173
  >>> a = brainstate.ShortTermState(jnp.zeros((2, 3)))
@@ -213,7 +219,11 @@ class State(Generic[A], PrettyObject):
213
219
  name (Optional[str]): An optional name for the state.
214
220
  **metadata: Additional metadata to be stored with the state.
215
221
 
216
- Example:
222
+ Example
223
+ -------
224
+
225
+ .. code-block:: python
226
+
217
227
  >>> class MyState(State):
218
228
  ... pass
219
229
  >>> state = MyState(jnp.zeros((3, 3)), name="my_matrix")
@@ -910,24 +920,6 @@ class StateTraceStack(Generic[A]):
910
920
  The class is generic over type A, allowing for type-safe usage with
911
921
  different types of State objects.
912
922
 
913
- Attributes:
914
- states (List[State]): A list of all State objects encountered during tracing.
915
- been_writen (List[bool]): A parallel list to states, indicating whether each state has been written to.
916
- _state_id_index (dict): A dictionary mapping state ids to their index in the states list.
917
- _original_state_values (List): A list of the original values of all states when first encountered.
918
- _jax_trace_new_arg (Callable): A function used to transform state values during tracing.
919
-
920
- Methods:
921
- __enter__: Enters a new tracing context.
922
- __exit__: Exits the current tracing context.
923
- read_its_value: Records a read operation on a state.
924
- write_its_value: Records a write operation on a state.
925
- get_state_values: Retrieves the current values of all traced states.
926
- recovery_original_values: Restores all states to their original values.
927
- merge: Merges multiple ``StateTraceStack`` instances.
928
- get_read_states: Retrieves states that were read during tracing.
929
- get_read_state_values: Retrieves values of states that were read during tracing.
930
-
931
923
  The ``StateTraceStack`` is a crucial component in implementing state-based
932
924
  computations and is particularly useful in scenarios involving automatic
933
925
  differentiation or other forms of program transformation.
@@ -992,7 +984,7 @@ class StateTraceStack(Generic[A]):
992
984
  """
993
985
  if self._jax_trace_new_arg is not None:
994
986
  # internal use
995
- state._value = jax.tree.map(self._jax_trace_new_arg, state._value)
987
+ state._value = self._jax_trace_new_arg(state)
996
988
 
997
989
  def __enter__(self) -> 'StateTraceStack':
998
990
  TRACE_CONTEXT.state_stack.append(self)
@@ -1266,6 +1258,28 @@ class StateTraceStack(Generic[A]):
1266
1258
  """
1267
1259
  return StateTraceStack().merge(self, other)
1268
1260
 
1261
+ def state_subset(self, state_type: type) -> List:
1262
+ """
1263
+ Get a subset of states of a specific type from the ``StateTraceStack``.
1264
+
1265
+ This method filters the states in the ``StateTraceStack`` and returns only
1266
+ those that match the specified state type.
1267
+
1268
+ Args:
1269
+ state_type (type): The type of state to filter by. This should be
1270
+ a subclass of State or State itself.
1271
+
1272
+ Returns:
1273
+ List[State]: A list containing all states in the ``StateTraceStack``
1274
+ that are instances of the specified state_type.
1275
+
1276
+ Example:
1277
+ >>> stack = StateTraceStack()
1278
+ >>> # Assume stack has been populated with various state types
1279
+ >>> short_term_states = stack.state_subset(ShortTermState)
1280
+ """
1281
+ return [st for st in self.states if isinstance(st, state_type)]
1282
+
1269
1283
  def assign_state_vals(self, state_vals: Sequence[PyTree]) -> None:
1270
1284
  """
1271
1285
  Assign new values to the states tracked by this ``StateTraceStack``.
@@ -1292,35 +1306,68 @@ class StateTraceStack(Generic[A]):
1292
1306
  ``StateTraceStack``'s states list.
1293
1307
  """
1294
1308
  if len(state_vals) != len(self.states):
1295
- raise ValueError('The length of the state values must be equal to the states. '
1296
- f'Bug got {len(state_vals)} and {len(self.states)}')
1309
+ raise ValueError(
1310
+ 'The length of the state values must be equal to the states. '
1311
+ f'Bug got {len(state_vals)} and {len(self.states)}'
1312
+ )
1297
1313
  for st, written, val in zip(self.states, self.been_writen, state_vals):
1298
1314
  if written:
1299
1315
  st.value = val
1300
1316
  else:
1301
1317
  st.restore_value(val)
1302
1318
 
1303
- def state_subset(self, state_type: type) -> List:
1304
- """
1305
- Get a subset of states of a specific type from the ``StateTraceStack``.
1306
-
1307
- This method filters the states in the ``StateTraceStack`` and returns only
1308
- those that match the specified state type.
1309
-
1310
- Args:
1311
- state_type (type): The type of state to filter by. This should be
1312
- a subclass of State or State itself.
1313
-
1314
- Returns:
1315
- List[State]: A list containing all states in the ``StateTraceStack``
1316
- that are instances of the specified state_type.
1317
-
1318
- Example:
1319
- >>> stack = StateTraceStack()
1320
- >>> # Assume stack has been populated with various state types
1321
- >>> short_term_states = stack.state_subset(ShortTermState)
1319
+ def assign_state_vals_v2(
1320
+ self: StateTraceStack,
1321
+ read_state_vals: Sequence[PyTree],
1322
+ write_state_vals: Sequence[PyTree],
1323
+ ):
1322
1324
  """
1323
- return [st for st in self.states if isinstance(st, state_type)]
1325
+ Write back state values to their corresponding states after computation.
1326
+
1327
+ This function updates the state values based on whether they were written to
1328
+ during the computation. If a state was written to, it gets the new written value.
1329
+ If not, it restores its original read value.
1330
+
1331
+ Parameters
1332
+ ----------
1333
+ read_state_vals : sequence of PyTree
1334
+ The original state values that were read at the beginning.
1335
+ write_state_vals : sequence of PyTree
1336
+ The new state values that were written during computation.
1337
+
1338
+ Examples
1339
+ --------
1340
+ Basic usage in a compilation context:
1341
+
1342
+ .. code-block:: python
1343
+
1344
+ >>> import brainstate
1345
+ >>> import jax.numpy as jnp
1346
+ >>>
1347
+ >>> # Create states
1348
+ >>> state1 = brainstate.State(jnp.array([1.0, 2.0]))
1349
+ >>> state2 = brainstate.State(jnp.array([3.0, 4.0]))
1350
+ >>>
1351
+ >>> def f(x):
1352
+ ... state1.value += x # This state will be written
1353
+ ... return state1.value + state2.value # state2 is only read
1354
+ >>>
1355
+ >>> # During compilation, state values are collected and managed
1356
+ >>> # write_back_state_values ensures proper state management
1357
+ """
1358
+ if len(self.states) != len(self.been_writen):
1359
+ raise ValueError('The length of the state values must be equal to the states. ')
1360
+ if len(read_state_vals) != len(self.states):
1361
+ raise ValueError('The length of the read state values must be equal to the states. ')
1362
+ if len(write_state_vals) != len(self.states):
1363
+ raise ValueError('The length of the write state values must be equal to the states. ')
1364
+ for st, write, val_r, val_w in zip(
1365
+ self.states, self.been_writen, read_state_vals, write_state_vals
1366
+ ):
1367
+ if write:
1368
+ st.value = val_w
1369
+ else:
1370
+ st.restore_value(val_r)
1324
1371
 
1325
1372
 
1326
1373
  class TreefyState(Generic[A], PrettyObject):
brainstate/_state_test.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
brainstate/_utils.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.