brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175) hide show
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
brainstate/mixin.py CHANGED
@@ -15,360 +15,352 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
- from typing import Sequence, Optional, TypeVar
19
- from typing import (_SpecialForm, _type_check, _remove_dups_flatten, _UnionGenericAlias)
18
+ from __future__ import annotations
19
+
20
+ from typing import (Sequence, Optional, TypeVar, _SpecialForm, _type_check, _remove_dups_flatten, _UnionGenericAlias)
20
21
 
21
22
  from brainstate.typing import PyTree
22
23
 
23
24
  T = TypeVar('T')
24
- State = None
25
-
26
25
 
27
26
  __all__ = [
28
- 'Mixin',
29
- 'DelayedInit',
30
- 'DelayedInitializer',
31
- 'AlignPost',
32
- 'BindCondData',
33
- 'UpdateReturn',
34
-
35
- # types
36
- 'JointTypes',
37
- 'OneOfTypes',
38
-
39
- # behavior modes
40
- 'Mode',
41
- 'JointMode',
42
- 'Batching',
43
- 'Training',
27
+ 'Mixin',
28
+ 'ParamDesc',
29
+ 'ParamDescriber',
30
+ 'AlignPost',
31
+ 'BindCondData',
32
+ 'UpdateReturn',
33
+
34
+ # types
35
+ 'JointTypes',
36
+ 'OneOfTypes',
37
+
38
+ # behavior modes
39
+ 'Mode',
40
+ 'JointMode',
41
+ 'Batching',
42
+ 'Training',
44
43
  ]
45
44
 
46
45
 
47
- def _get_state():
48
- global State
49
- if State is None:
50
- from brainstate._state import State
51
- return State
52
-
53
-
54
46
  class Mixin(object):
55
- """Base Mixin object.
47
+ """Base Mixin object.
56
48
 
57
- The key for a :py:class:`~.Mixin` is that: no initialization function, only behavioral functions.
58
- """
59
- pass
49
+ The key for a :py:class:`~.Mixin` is that: no initialization function, only behavioral functions.
50
+ """
51
+ pass
60
52
 
61
53
 
62
- class DelayedInit(Mixin):
63
- """
64
- :py:class:`~.Mixin` indicates the function for describing initialization parameters.
54
+ class ParamDesc(Mixin):
55
+ """
56
+ :py:class:`~.Mixin` indicates the function for describing initialization parameters.
65
57
 
66
- This mixin enables the subclass has a classmethod ``delayed``, which
67
- produces an instance of :py:class:`~.DelayedInitializer`.
58
+ This mixin enables the subclass has a classmethod ``desc``, which
59
+ produces an instance of :py:class:`~.ParamDescriber`.
68
60
 
69
- Note this Mixin can be applied in any Python object.
70
- """
61
+ Note this Mixin can be applied in any Python object.
62
+ """
71
63
 
72
- non_hashable_params: Optional[Sequence[str]] = None
64
+ non_hashable_params: Optional[Sequence[str]] = None
73
65
 
74
- @classmethod
75
- def delayed(cls, *args, **kwargs) -> 'DelayedInitializer':
76
- return DelayedInitializer(cls, *args, **kwargs)
66
+ @classmethod
67
+ def desc(cls, *args, **kwargs) -> 'ParamDescriber':
68
+ return ParamDescriber(cls, *args, **kwargs)
77
69
 
78
70
 
79
71
  class HashableDict(dict):
80
- def __hash__(self):
81
- return hash(tuple(sorted(self.items())))
72
+ def __hash__(self):
73
+ return hash(tuple(sorted(self.items())))
82
74
 
83
75
 
84
76
  class NoSubclassMeta(type):
85
- def __new__(cls, name, bases, classdict):
86
- for b in bases:
87
- if isinstance(b, NoSubclassMeta):
88
- raise TypeError("type '{0}' is not an acceptable base type".format(b.__name__))
89
- return type.__new__(cls, name, bases, dict(classdict))
77
+ def __new__(cls, name, bases, classdict):
78
+ for b in bases:
79
+ if isinstance(b, NoSubclassMeta):
80
+ raise TypeError("type '{0}' is not an acceptable base type".format(b.__name__))
81
+ return type.__new__(cls, name, bases, dict(classdict))
90
82
 
91
83
 
92
- class DelayedInitializer(metaclass=NoSubclassMeta):
93
- """
94
- DelayedInit initialization for parameter describers.
95
- """
84
+ class ParamDescriber(metaclass=NoSubclassMeta):
85
+ """
86
+ ParamDesc initialization for parameter describers.
87
+ """
96
88
 
97
- def __init__(self, cls: T, *desc_tuple, **desc_dict):
98
- self.cls: type = cls
89
+ def __init__(self, cls: T, *desc_tuple, **desc_dict):
90
+ self.cls: type = cls
99
91
 
100
- # arguments
101
- self.args = desc_tuple
102
- self.kwargs = desc_dict
92
+ # arguments
93
+ self.args = desc_tuple
94
+ self.kwargs = desc_dict
103
95
 
104
- # identifier
105
- self._identifier = (cls, tuple(desc_tuple), HashableDict(desc_dict))
96
+ # identifier
97
+ self._identifier = (cls, tuple(desc_tuple), HashableDict(desc_dict))
106
98
 
107
- def __call__(self, *args, **kwargs) -> T:
108
- return self.cls(*self.args, *args, **self.kwargs, **kwargs)
99
+ def __call__(self, *args, **kwargs) -> T:
100
+ return self.cls(*self.args, *args, **self.kwargs, **kwargs)
109
101
 
110
- def init(self, *args, **kwargs):
111
- return self.__call__(*args, **kwargs)
102
+ def init(self, *args, **kwargs):
103
+ return self.__call__(*args, **kwargs)
112
104
 
113
- def __instancecheck__(self, instance):
114
- if not isinstance(instance, DelayedInitializer):
115
- return False
116
- if not issubclass(instance.cls, self.cls):
117
- return False
118
- return True
105
+ def __instancecheck__(self, instance):
106
+ if not isinstance(instance, ParamDescriber):
107
+ return False
108
+ if not issubclass(instance.cls, self.cls):
109
+ return False
110
+ return True
119
111
 
120
- @classmethod
121
- def __class_getitem__(cls, item: type):
122
- return DelayedInitializer(item)
112
+ @classmethod
113
+ def __class_getitem__(cls, item: type):
114
+ return ParamDescriber(item)
123
115
 
124
- @property
125
- def identifier(self):
126
- return self._identifier
116
+ @property
117
+ def identifier(self):
118
+ return self._identifier
127
119
 
128
- @identifier.setter
129
- def identifier(self, value):
130
- raise AttributeError('Cannot set the identifier.')
120
+ @identifier.setter
121
+ def identifier(self, value):
122
+ raise AttributeError('Cannot set the identifier.')
131
123
 
132
124
 
133
125
  class AlignPost(Mixin):
134
- """
135
- Align post MixIn.
126
+ """
127
+ Align post MixIn.
136
128
 
137
- This class provides a ``align_post_input_add()`` function for
138
- add external currents.
139
- """
129
+ This class provides a ``align_post_input_add()`` function for
130
+ add external currents.
131
+ """
140
132
 
141
- def align_post_input_add(self, *args, **kwargs):
142
- raise NotImplementedError
133
+ def align_post_input_add(self, *args, **kwargs):
134
+ raise NotImplementedError
143
135
 
144
136
 
145
137
  class BindCondData(Mixin):
146
- """Bind temporary conductance data.
138
+ """Bind temporary conductance data.
147
139
 
148
140
 
149
- """
150
- _conductance: Optional
141
+ """
142
+ _conductance: Optional
151
143
 
152
- def bind_cond(self, conductance):
153
- self._conductance = conductance
144
+ def bind_cond(self, conductance):
145
+ self._conductance = conductance
154
146
 
155
- def unbind_cond(self):
156
- self._conductance = None
147
+ def unbind_cond(self):
148
+ self._conductance = None
157
149
 
158
150
 
159
151
  class UpdateReturn(Mixin):
160
152
 
161
- def update_return(self) -> PyTree:
162
- """
163
- The update function return of the model.
153
+ def update_return(self) -> PyTree:
154
+ """
155
+ The update function return of the model.
164
156
 
165
- It should be a pytree, with each element as a ``jax.ShapeDtypeStruct`` or ``jax.core.ShapedArray``.
157
+ It should be a pytree, with each element as a ``jax.ShapeDtypeStruct`` or ``jax.core.ShapedArray``.
166
158
 
167
- """
168
- raise NotImplementedError(f'Must implement the "{self.update_return.__name__}()" function.')
159
+ """
160
+ raise NotImplementedError(f'Must implement the "{self.update_return.__name__}()" function.')
169
161
 
170
- def update_return_info(self) -> PyTree:
171
- """
172
- The update return information of the model.
162
+ def update_return_info(self) -> PyTree:
163
+ """
164
+ The update return information of the model.
173
165
 
174
- It should be a pytree, with each element as a ``jax.Array``.
166
+ It should be a pytree, with each element as a ``jax.Array``.
175
167
 
176
- .. note::
177
- Should not include the batch axis and batch size.
178
- These information will be inferred from the ``mode`` attribute.
168
+ .. note::
169
+ Should not include the batch axis and batch in_size.
170
+ These information will be inferred from the ``mode`` attribute.
179
171
 
180
- """
181
- raise NotImplementedError(f'Must implement the "{self.update_return_info.__name__}()" function.')
172
+ """
173
+ raise NotImplementedError(f'Must implement the "{self.update_return_info.__name__}()" function.')
182
174
 
183
175
 
184
176
  class _MetaUnionType(type):
185
- def __new__(cls, name, bases, dct):
186
- if isinstance(bases, type):
187
- bases = (bases,)
188
- elif isinstance(bases, (list, tuple)):
189
- bases = tuple(bases)
190
- for base in bases:
191
- assert isinstance(base, type), f'Must be type. But got {base}'
192
- else:
193
- raise TypeError(f'Must be type. But got {bases}')
194
- return super().__new__(cls, name, bases, dct)
177
+ def __new__(cls, name, bases, dct):
178
+ if isinstance(bases, type):
179
+ bases = (bases,)
180
+ elif isinstance(bases, (list, tuple)):
181
+ bases = tuple(bases)
182
+ for base in bases:
183
+ assert isinstance(base, type), f'Must be type. But got {base}'
184
+ else:
185
+ raise TypeError(f'Must be type. But got {bases}')
186
+ return super().__new__(cls, name, bases, dct)
195
187
 
196
- def __instancecheck__(self, other):
197
- cls_of_other = other.__class__
198
- return all([issubclass(cls_of_other, cls) for cls in self.__bases__])
188
+ def __instancecheck__(self, other):
189
+ cls_of_other = other.__class__
190
+ return all([issubclass(cls_of_other, cls) for cls in self.__bases__])
199
191
 
200
- def __subclasscheck__(self, subclass):
201
- return all([issubclass(subclass, cls) for cls in self.__bases__])
192
+ def __subclasscheck__(self, subclass):
193
+ return all([issubclass(subclass, cls) for cls in self.__bases__])
202
194
 
203
195
 
204
196
  class _JointGenericAlias(_UnionGenericAlias, _root=True):
205
- def __subclasscheck__(self, subclass):
206
- return all([issubclass(subclass, cls) for cls in set(self.__args__)])
197
+ def __subclasscheck__(self, subclass):
198
+ return all([issubclass(subclass, cls) for cls in set(self.__args__)])
207
199
 
208
200
 
209
201
  @_SpecialForm
210
202
  def JointTypes(self, parameters):
211
- """Joint types; JointTypes[X, Y] means both X and Y.
203
+ """Joint types; JointTypes[X, Y] means both X and Y.
212
204
 
213
- To define a union, use e.g. Union[int, str].
205
+ To define a union, use e.g. Union[int, str].
214
206
 
215
- Details:
216
- - The arguments must be types and there must be at least one.
217
- - None as an argument is a special case and is replaced by `type(None)`.
218
- - Unions of unions are flattened, e.g.::
207
+ Details:
208
+ - The arguments must be types and there must be at least one.
209
+ - None as an argument is a special case and is replaced by `type(None)`.
210
+ - Unions of unions are flattened, e.g.::
219
211
 
220
- JointTypes[JointTypes[int, str], float] == JointTypes[int, str, float]
212
+ JointTypes[JointTypes[int, str], float] == JointTypes[int, str, float]
221
213
 
222
- - Unions of a single argument vanish, e.g.::
214
+ - Unions of a single argument vanish, e.g.::
223
215
 
224
- JointTypes[int] == int # The constructor actually returns int
216
+ JointTypes[int] == int # The constructor actually returns int
225
217
 
226
- - Redundant arguments are skipped, e.g.::
218
+ - Redundant arguments are skipped, e.g.::
227
219
 
228
- JointTypes[int, str, int] == JointTypes[int, str]
220
+ JointTypes[int, str, int] == JointTypes[int, str]
229
221
 
230
- - When comparing unions, the argument order is ignored, e.g.::
222
+ - When comparing unions, the argument order is ignored, e.g.::
231
223
 
232
- JointTypes[int, str] == JointTypes[str, int]
224
+ JointTypes[int, str] == JointTypes[str, int]
233
225
 
234
- - You cannot subclass or instantiate a JointTypes.
235
- - You can use Optional[X] as a shorthand for JointTypes[X, None].
236
- """
237
- if parameters == ():
238
- raise TypeError("Cannot take a Joint of no types.")
239
- if not isinstance(parameters, tuple):
240
- parameters = (parameters,)
241
- msg = "JointTypes[arg, ...]: each arg must be a type."
242
- parameters = tuple(_type_check(p, msg) for p in parameters)
243
- parameters = _remove_dups_flatten(parameters)
244
- if len(parameters) == 1:
245
- return parameters[0]
246
- if len(parameters) == 2 and type(None) in parameters:
247
- return _UnionGenericAlias(self, parameters, name="Optional")
248
- return _JointGenericAlias(self, parameters)
226
+ - You cannot subclass or instantiate a JointTypes.
227
+ - You can use Optional[X] as a shorthand for JointTypes[X, None].
228
+ """
229
+ if parameters == ():
230
+ raise TypeError("Cannot take a Joint of no types.")
231
+ if not isinstance(parameters, tuple):
232
+ parameters = (parameters,)
233
+ msg = "JointTypes[arg, ...]: each arg must be a type."
234
+ parameters = tuple(_type_check(p, msg) for p in parameters)
235
+ parameters = _remove_dups_flatten(parameters)
236
+ if len(parameters) == 1:
237
+ return parameters[0]
238
+ if len(parameters) == 2 and type(None) in parameters:
239
+ return _UnionGenericAlias(self, parameters, name="Optional")
240
+ return _JointGenericAlias(self, parameters)
249
241
 
250
242
 
251
243
  @_SpecialForm
252
244
  def OneOfTypes(self, parameters):
253
- """Sole type; OneOfTypes[X, Y] means either X or Y.
245
+ """Sole type; OneOfTypes[X, Y] means either X or Y.
254
246
 
255
- To define a union, use e.g. OneOfTypes[int, str]. Details:
256
- - The arguments must be types and there must be at least one.
257
- - None as an argument is a special case and is replaced by
258
- type(None).
259
- - Unions of unions are flattened, e.g.::
247
+ To define a union, use e.g. OneOfTypes[int, str]. Details:
248
+ - The arguments must be types and there must be at least one.
249
+ - None as an argument is a special case and is replaced by
250
+ type(None).
251
+ - Unions of unions are flattened, e.g.::
260
252
 
261
- assert OneOfTypes[OneOfTypes[int, str], float] == OneOfTypes[int, str, float]
253
+ assert OneOfTypes[OneOfTypes[int, str], float] == OneOfTypes[int, str, float]
262
254
 
263
- - Unions of a single argument vanish, e.g.::
255
+ - Unions of a single argument vanish, e.g.::
264
256
 
265
- assert OneOfTypes[int] == int # The constructor actually returns int
257
+ assert OneOfTypes[int] == int # The constructor actually returns int
266
258
 
267
- - Redundant arguments are skipped, e.g.::
259
+ - Redundant arguments are skipped, e.g.::
268
260
 
269
- assert OneOfTypes[int, str, int] == OneOfTypes[int, str]
261
+ assert OneOfTypes[int, str, int] == OneOfTypes[int, str]
270
262
 
271
- - When comparing unions, the argument order is ignored, e.g.::
263
+ - When comparing unions, the argument order is ignored, e.g.::
272
264
 
273
- assert OneOfTypes[int, str] == OneOfTypes[str, int]
265
+ assert OneOfTypes[int, str] == OneOfTypes[str, int]
274
266
 
275
- - You cannot subclass or instantiate a union.
276
- - You can use Optional[X] as a shorthand for OneOfTypes[X, None].
277
- """
278
- if parameters == ():
279
- raise TypeError("Cannot take a Sole of no types.")
280
- if not isinstance(parameters, tuple):
281
- parameters = (parameters,)
282
- msg = "OneOfTypes[arg, ...]: each arg must be a type."
283
- parameters = tuple(_type_check(p, msg) for p in parameters)
284
- parameters = _remove_dups_flatten(parameters)
285
- if len(parameters) == 1:
286
- return parameters[0]
287
- if len(parameters) == 2 and type(None) in parameters:
288
- return _UnionGenericAlias(self, parameters, name="Optional")
289
- return _UnionGenericAlias(self, parameters)
267
+ - You cannot subclass or instantiate a union.
268
+ - You can use Optional[X] as a shorthand for OneOfTypes[X, None].
269
+ """
270
+ if parameters == ():
271
+ raise TypeError("Cannot take a Sole of no types.")
272
+ if not isinstance(parameters, tuple):
273
+ parameters = (parameters,)
274
+ msg = "OneOfTypes[arg, ...]: each arg must be a type."
275
+ parameters = tuple(_type_check(p, msg) for p in parameters)
276
+ parameters = _remove_dups_flatten(parameters)
277
+ if len(parameters) == 1:
278
+ return parameters[0]
279
+ if len(parameters) == 2 and type(None) in parameters:
280
+ return _UnionGenericAlias(self, parameters, name="Optional")
281
+ return _UnionGenericAlias(self, parameters)
290
282
 
291
283
 
292
284
  class Mode(Mixin):
293
- """
294
- Base class for computation behaviors.
295
- """
296
-
297
- def __repr__(self):
298
- return self.__class__.__name__
299
-
300
- def __eq__(self, other: 'Mode'):
301
- assert isinstance(other, Mode)
302
- return other.__class__ == self.__class__
303
-
304
- def is_a(self, mode: type):
305
285
  """
306
- Check whether the mode is exactly the desired mode.
286
+ Base class for computation behaviors.
307
287
  """
308
- assert isinstance(mode, type), 'Must be a type.'
309
- return self.__class__ == mode
310
288
 
311
- def has(self, mode: type):
312
- """
313
- Check whether the mode is included in the desired mode.
314
- """
315
- assert isinstance(mode, type), 'Must be a type.'
316
- return isinstance(self, mode)
289
+ def __repr__(self):
290
+ return self.__class__.__name__
317
291
 
292
+ def __eq__(self, other: 'Mode'):
293
+ assert isinstance(other, Mode)
294
+ return other.__class__ == self.__class__
318
295
 
319
- class JointMode(Mode):
320
- """
321
- Joint mode.
322
- """
296
+ def is_a(self, mode: type):
297
+ """
298
+ Check whether the mode is exactly the desired mode.
299
+ """
300
+ assert isinstance(mode, type), 'Must be a type.'
301
+ return self.__class__ == mode
323
302
 
324
- def __init__(self, *modes: Mode):
325
- for m_ in modes:
326
- if not isinstance(m_, Mode):
327
- raise TypeError(f'The supported type must be a tuple/list of Mode. But we got {m_}')
328
- self.modes = tuple(modes)
329
- self.types = set([m.__class__ for m in modes])
303
+ def has(self, mode: type):
304
+ """
305
+ Check whether the mode is included in the desired mode.
306
+ """
307
+ assert isinstance(mode, type), 'Must be a type.'
308
+ return isinstance(self, mode)
330
309
 
331
- def __repr__(self):
332
- return f'{self.__class__.__name__}({", ".join([repr(m) for m in self.modes])})'
333
310
 
334
- def has(self, mode: type):
335
- """
336
- Check whether the mode is included in the desired mode.
337
- """
338
- assert isinstance(mode, type), 'Must be a type.'
339
- return any([issubclass(cls, mode) for cls in self.types])
340
-
341
- def is_a(self, cls: type):
342
- """
343
- Check whether the mode is exactly the desired mode.
311
+ class JointMode(Mode):
344
312
  """
345
- return JointTypes[tuple(self.types)] == cls
346
-
347
- def __getattr__(self, item):
313
+ Joint mode.
348
314
  """
349
- Get the attribute from the mode.
350
315
 
351
- If the attribute is not found in the mode, then it will be searched in the base class.
352
- """
353
- if item in ['modes', 'types']:
354
- return super().__getattribute__(item)
355
- for m in self.modes:
356
- if hasattr(m, item):
357
- return getattr(m, item)
358
- return super().__getattribute__(item)
316
+ def __init__(self, *modes: Mode):
317
+ for m_ in modes:
318
+ if not isinstance(m_, Mode):
319
+ raise TypeError(f'The supported type must be a tuple/list of Mode. But we got {m_}')
320
+ self.modes = tuple(modes)
321
+ self.types = set([m.__class__ for m in modes])
322
+
323
+ def __repr__(self):
324
+ return f'{self.__class__.__name__}({", ".join([repr(m) for m in self.modes])})'
325
+
326
+ def has(self, mode: type):
327
+ """
328
+ Check whether the mode is included in the desired mode.
329
+ """
330
+ assert isinstance(mode, type), 'Must be a type.'
331
+ return any([issubclass(cls, mode) for cls in self.types])
332
+
333
+ def is_a(self, cls: type):
334
+ """
335
+ Check whether the mode is exactly the desired mode.
336
+ """
337
+ return JointTypes[tuple(self.types)] == cls
338
+
339
+ def __getattr__(self, item):
340
+ """
341
+ Get the attribute from the mode.
342
+
343
+ If the attribute is not found in the mode, then it will be searched in the base class.
344
+ """
345
+ if item in ['modes', 'types']:
346
+ return super().__getattribute__(item)
347
+ for m in self.modes:
348
+ if hasattr(m, item):
349
+ return getattr(m, item)
350
+ return super().__getattribute__(item)
359
351
 
360
352
 
361
353
  class Batching(Mode):
362
- """Batching mode."""
354
+ """Batching mode."""
363
355
 
364
- def __init__(self, batch_size: int = 1, batch_axis: int = 0):
365
- self.batch_size = batch_size
366
- self.batch_axis = batch_axis
356
+ def __init__(self, batch_size: int = 1, batch_axis: int = 0):
357
+ self.batch_size = batch_size
358
+ self.batch_axis = batch_axis
367
359
 
368
- def __repr__(self):
369
- return f'{self.__class__.__name__}(size={self.batch_size}, axis={self.batch_axis})'
360
+ def __repr__(self):
361
+ return f'{self.__class__.__name__}(in_size={self.batch_size}, axis={self.batch_axis})'
370
362
 
371
363
 
372
364
  class Training(Mode):
373
- """Training mode."""
374
- pass
365
+ """Training mode."""
366
+ pass