brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -15,6 +15,13 @@
15
15
  # See the License for the specific language governing permissions and
16
16
  # limitations under the License.
17
17
 
18
+ """
19
+ Pretty representation utilities for creating human-readable string representations.
20
+
21
+ This module provides utilities for creating customizable pretty representations of
22
+ objects, with support for nested structures and circular reference detection.
23
+ """
24
+
18
25
  import dataclasses
19
26
  import threading
20
27
  from abc import ABC, abstractmethod
@@ -38,6 +45,21 @@ B = TypeVar('B')
38
45
  class PrettyType:
39
46
  """
40
47
  Configuration for pretty representation of objects.
48
+
49
+ Attributes
50
+ ----------
51
+ type : Union[str, type]
52
+ The type name or type object to display.
53
+ start : str, default='('
54
+ The opening delimiter for the representation.
55
+ end : str, default=')'
56
+ The closing delimiter for the representation.
57
+ value_sep : str, default='='
58
+ The separator between keys and values.
59
+ elem_indent : str, default=' '
60
+ The indentation for nested elements.
61
+ empty_repr : str, default=''
62
+ The representation for empty objects.
41
63
  """
42
64
  type: Union[str, type]
43
65
  start: str = '('
@@ -51,6 +73,17 @@ class PrettyType:
51
73
  class PrettyAttr:
52
74
  """
53
75
  Configuration for pretty representation of attributes.
76
+
77
+ Attributes
78
+ ----------
79
+ key : str
80
+ The attribute name or key.
81
+ value : Union[str, Any]
82
+ The attribute value.
83
+ start : str, default=''
84
+ Optional prefix for the attribute.
85
+ end : str, default=''
86
+ Optional suffix for the attribute.
54
87
  """
55
88
  key: str
56
89
  value: Union[str, Any]
@@ -62,23 +95,54 @@ class PrettyRepr(ABC):
62
95
  """
63
96
  Interface for pretty representation of objects.
64
97
 
65
- Example::
66
-
67
- >>> class MyObject(PrettyRepr):
68
- >>> def __pretty_repr__(self):
69
- >>> yield PrettyType(type='MyObject', start='{', end='}')
70
- >>> yield PrettyAttr('key', self.key)
71
- >>> yield PrettyAttr('value', self.value)
72
-
98
+ This abstract base class provides a framework for creating custom
99
+ pretty representations of objects by yielding PrettyType and PrettyAttr
100
+ instances.
101
+
102
+ Examples
103
+ --------
104
+ .. code-block:: python
105
+
106
+ >>> class MyObject(PrettyRepr):
107
+ ... def __init__(self, key, value):
108
+ ... self.key = key
109
+ ... self.value = value
110
+ ...
111
+ ... def __pretty_repr__(self):
112
+ ... yield PrettyType(type='MyObject', start='{', end='}')
113
+ ... yield PrettyAttr('key', self.key)
114
+ ... yield PrettyAttr('value', self.value)
115
+ ...
116
+ >>> obj = MyObject('foo', 42)
117
+ >>> print(obj)
118
+ MyObject{
119
+ key=foo,
120
+ value=42
121
+ }
73
122
  """
74
123
  __slots__ = ()
75
124
 
76
125
  @abstractmethod
77
126
  def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
127
+ """
128
+ Generate the pretty representation of the object.
129
+
130
+ Yields
131
+ ------
132
+ Union[PrettyType, PrettyAttr]
133
+ First yield should be PrettyType, followed by PrettyAttr instances.
134
+ """
78
135
  raise NotImplementedError
79
136
 
80
137
  def __repr__(self) -> str:
81
- # repr the individual object with the pretty representation
138
+ """
139
+ Generate string representation using pretty representation.
140
+
141
+ Returns
142
+ -------
143
+ str
144
+ The formatted string representation of the object.
145
+ """
82
146
  return pretty_repr_object(self)
83
147
 
84
148
 
@@ -174,9 +238,33 @@ def pretty_repr_object(obj: PrettyRepr) -> str:
174
238
  class MappingReprMixin(Mapping[A, B]):
175
239
  """
176
240
  Mapping mixin for pretty representation.
241
+
242
+ This mixin provides a default pretty representation for mapping-like objects.
243
+
244
+ Examples
245
+ --------
246
+ .. code-block:: python
247
+
248
+ >>> class MyMapping(dict, MappingReprMixin):
249
+ ... pass
250
+ ...
251
+ >>> m = MyMapping({'a': 1, 'b': 2})
252
+ >>> print(m)
253
+ {
254
+ 'a': 1,
255
+ 'b': 2
256
+ }
177
257
  """
178
258
 
179
- def __pretty_repr__(self):
259
+ def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
260
+ """
261
+ Generate pretty representation for mapping.
262
+
263
+ Yields
264
+ ------
265
+ Union[PrettyType, PrettyAttr]
266
+ PrettyType followed by PrettyAttr for each key-value pair.
267
+ """
180
268
  yield PrettyType(type='', value_sep=': ', start='{', end='}')
181
269
 
182
270
  for key, value in self.items():
@@ -187,11 +275,37 @@ class MappingReprMixin(Mapping[A, B]):
187
275
  class PrettyMapping(PrettyRepr):
188
276
  """
189
277
  Pretty representation of a mapping.
278
+
279
+ Attributes
280
+ ----------
281
+ mapping : Mapping
282
+ The mapping to represent.
283
+ type_name : str, default=''
284
+ Optional type name to display.
285
+
286
+ Examples
287
+ --------
288
+ .. code-block:: python
289
+
290
+ >>> m = PrettyMapping({'a': 1, 'b': 2}, type_name='MyDict')
291
+ >>> print(m)
292
+ MyDict{
293
+ 'a': 1,
294
+ 'b': 2
295
+ }
190
296
  """
191
297
  mapping: Mapping
192
298
  type_name: str = ''
193
299
 
194
- def __pretty_repr__(self):
300
+ def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
301
+ """
302
+ Generate pretty representation for the mapping.
303
+
304
+ Yields
305
+ ------
306
+ Union[PrettyType, PrettyAttr]
307
+ PrettyType followed by PrettyAttr for each key-value pair.
308
+ """
195
309
  yield PrettyType(type=self.type_name, value_sep=': ', start='{', end='}')
196
310
 
197
311
  for key, value in self.mapping.items():
@@ -220,9 +334,9 @@ class PrettyReprContext(threading.local):
220
334
  CONTEXT = PrettyReprContext()
221
335
 
222
336
 
223
- def _default_repr_object(node):
337
+ def _default_repr_object(node: Any) -> Iterator[PrettyType]:
224
338
  """
225
- Generates a default pretty representation for an object.
339
+ Generate a default pretty representation for an object.
226
340
 
227
341
  This function yields a `PrettyType` instance that represents the type
228
342
  of the given object. It is used as a default method for representing
@@ -242,9 +356,9 @@ def _default_repr_object(node):
242
356
  yield PrettyType(type=type(node))
243
357
 
244
358
 
245
- def _default_repr_attr(node):
359
+ def _default_repr_attr(node: Any) -> Iterator[PrettyAttr]:
246
360
  """
247
- Generates a default pretty representation for the attributes of an object.
361
+ Generate a default pretty representation for the attributes of an object.
248
362
 
249
363
  This function iterates over the attributes of the given object and yields
250
364
  a `PrettyAttr` instance for each attribute that does not start with an
@@ -269,32 +383,52 @@ def _default_repr_attr(node):
269
383
 
270
384
 
271
385
  def yield_unique_pretty_repr_items(
272
- node,
386
+ node: Any,
273
387
  repr_object: Optional[Callable] = None,
274
388
  repr_attr: Optional[Callable] = None
275
- ):
389
+ ) -> Iterator[Union[PrettyType, PrettyAttr]]:
276
390
  """
277
- Generates a pretty representation of an object while avoiding duplicate representations.
391
+ Generate a pretty representation of an object while avoiding duplicate representations.
278
392
 
279
- This function is designed to yield a structured representation of an object,
280
- using custom or default methods for representing the object itself and its attributes.
281
- It ensures that each object is only represented once to prevent infinite recursion
282
- in cases of circular references.
393
+ This function yields a structured representation of an object, using custom or default
394
+ methods for representing the object itself and its attributes. It ensures that each
395
+ object is only represented once to prevent infinite recursion in cases of circular
396
+ references.
283
397
 
284
- Parameters:
398
+ Parameters
399
+ ----------
285
400
  node : Any
286
401
  The object to be represented.
287
402
  repr_object : Optional[Callable], optional
288
- A callable that yields the representation of the object itself.
403
+ A callable that yields the representation of the object itself.
289
404
  If not provided, a default representation function is used.
290
405
  repr_attr : Optional[Callable], optional
291
- A callable that yields the representation of the object's attributes.
406
+ A callable that yields the representation of the object's attributes.
292
407
  If not provided, a default attribute representation function is used.
293
408
 
294
- Yields:
409
+ Yields
410
+ ------
295
411
  Union[PrettyType, PrettyAttr]
296
- The pretty representation of the object and its attributes,
412
+ The pretty representation of the object and its attributes,
297
413
  avoiding duplicates by tracking seen objects.
414
+
415
+ Examples
416
+ --------
417
+ .. code-block:: python
418
+
419
+ >>> class Node:
420
+ ... def __init__(self, value, next=None):
421
+ ... self.value = value
422
+ ... self.next = next
423
+ ...
424
+ >>> # Create circular reference
425
+ >>> node1 = Node(1)
426
+ >>> node2 = Node(2, node1)
427
+ >>> node1.next = node2
428
+ ...
429
+ >>> # This will handle circular reference gracefully
430
+ >>> for item in yield_unique_pretty_repr_items(node1):
431
+ ... print(item)
298
432
  """
299
433
  if repr_object is None:
300
434
  repr_object = _default_repr_object