brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
brainstate/mixin.py CHANGED
@@ -15,360 +15,352 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
- from typing import Sequence, Optional, TypeVar
19
- from typing import (_SpecialForm, _type_check, _remove_dups_flatten, _UnionGenericAlias)
18
+ from __future__ import annotations
19
+
20
+ from typing import (Sequence, Optional, TypeVar, _SpecialForm, _type_check, _remove_dups_flatten, _UnionGenericAlias)
20
21
 
21
22
  from brainstate.typing import PyTree
22
23
 
23
24
  T = TypeVar('T')
24
- State = None
25
-
26
25
 
27
26
  __all__ = [
28
- 'Mixin',
29
- 'DelayedInit',
30
- 'DelayedInitializer',
31
- 'AlignPost',
32
- 'BindCondData',
33
- 'UpdateReturn',
34
-
35
- # types
36
- 'JointTypes',
37
- 'OneOfTypes',
38
-
39
- # behavior modes
40
- 'Mode',
41
- 'JointMode',
42
- 'Batching',
43
- 'Training',
27
+ 'Mixin',
28
+ 'ParamDesc',
29
+ 'ParamDescriber',
30
+ 'AlignPost',
31
+ 'BindCondData',
32
+ 'UpdateReturn',
33
+
34
+ # types
35
+ 'JointTypes',
36
+ 'OneOfTypes',
37
+
38
+ # behavior modes
39
+ 'Mode',
40
+ 'JointMode',
41
+ 'Batching',
42
+ 'Training',
44
43
  ]
45
44
 
46
45
 
47
- def _get_state():
48
- global State
49
- if State is None:
50
- from brainstate._state import State
51
- return State
52
-
53
-
54
46
  class Mixin(object):
55
- """Base Mixin object.
47
+ """Base Mixin object.
56
48
 
57
- The key for a :py:class:`~.Mixin` is that: no initialization function, only behavioral functions.
58
- """
59
- pass
49
+ The key for a :py:class:`~.Mixin` is that: no initialization function, only behavioral functions.
50
+ """
51
+ pass
60
52
 
61
53
 
62
- class DelayedInit(Mixin):
63
- """
64
- :py:class:`~.Mixin` indicates the function for describing initialization parameters.
54
+ class ParamDesc(Mixin):
55
+ """
56
+ :py:class:`~.Mixin` indicates the function for describing initialization parameters.
65
57
 
66
- This mixin enables the subclass has a classmethod ``delayed``, which
67
- produces an instance of :py:class:`~.DelayedInitializer`.
58
+ This mixin enables the subclass has a classmethod ``desc``, which
59
+ produces an instance of :py:class:`~.ParamDescriber`.
68
60
 
69
- Note this Mixin can be applied in any Python object.
70
- """
61
+ Note this Mixin can be applied in any Python object.
62
+ """
71
63
 
72
- non_hashable_params: Optional[Sequence[str]] = None
64
+ non_hashable_params: Optional[Sequence[str]] = None
73
65
 
74
- @classmethod
75
- def delayed(cls, *args, **kwargs) -> 'DelayedInitializer':
76
- return DelayedInitializer(cls, *args, **kwargs)
66
+ @classmethod
67
+ def desc(cls, *args, **kwargs) -> 'ParamDescriber':
68
+ return ParamDescriber(cls, *args, **kwargs)
77
69
 
78
70
 
79
71
  class HashableDict(dict):
80
- def __hash__(self):
81
- return hash(tuple(sorted(self.items())))
72
+ def __hash__(self):
73
+ return hash(tuple(sorted(self.items())))
82
74
 
83
75
 
84
76
  class NoSubclassMeta(type):
85
- def __new__(cls, name, bases, classdict):
86
- for b in bases:
87
- if isinstance(b, NoSubclassMeta):
88
- raise TypeError("type '{0}' is not an acceptable base type".format(b.__name__))
89
- return type.__new__(cls, name, bases, dict(classdict))
77
+ def __new__(cls, name, bases, classdict):
78
+ for b in bases:
79
+ if isinstance(b, NoSubclassMeta):
80
+ raise TypeError("type '{0}' is not an acceptable base type".format(b.__name__))
81
+ return type.__new__(cls, name, bases, dict(classdict))
90
82
 
91
83
 
92
- class DelayedInitializer(metaclass=NoSubclassMeta):
93
- """
94
- DelayedInit initialization for parameter describers.
95
- """
84
+ class ParamDescriber(metaclass=NoSubclassMeta):
85
+ """
86
+ ParamDesc initialization for parameter describers.
87
+ """
96
88
 
97
- def __init__(self, cls: T, *desc_tuple, **desc_dict):
98
- self.cls: type = cls
89
+ def __init__(self, cls: T, *desc_tuple, **desc_dict):
90
+ self.cls: type = cls
99
91
 
100
- # arguments
101
- self.args = desc_tuple
102
- self.kwargs = desc_dict
92
+ # arguments
93
+ self.args = desc_tuple
94
+ self.kwargs = desc_dict
103
95
 
104
- # identifier
105
- self._identifier = (cls, tuple(desc_tuple), HashableDict(desc_dict))
96
+ # identifier
97
+ self._identifier = (cls, tuple(desc_tuple), HashableDict(desc_dict))
106
98
 
107
- def __call__(self, *args, **kwargs) -> T:
108
- return self.cls(*self.args, *args, **self.kwargs, **kwargs)
99
+ def __call__(self, *args, **kwargs) -> T:
100
+ return self.cls(*self.args, *args, **self.kwargs, **kwargs)
109
101
 
110
- def init(self, *args, **kwargs):
111
- return self.__call__(*args, **kwargs)
102
+ def init(self, *args, **kwargs):
103
+ return self.__call__(*args, **kwargs)
112
104
 
113
- def __instancecheck__(self, instance):
114
- if not isinstance(instance, DelayedInitializer):
115
- return False
116
- if not issubclass(instance.cls, self.cls):
117
- return False
118
- return True
105
+ def __instancecheck__(self, instance):
106
+ if not isinstance(instance, ParamDescriber):
107
+ return False
108
+ if not issubclass(instance.cls, self.cls):
109
+ return False
110
+ return True
119
111
 
120
- @classmethod
121
- def __class_getitem__(cls, item: type):
122
- return DelayedInitializer(item)
112
+ @classmethod
113
+ def __class_getitem__(cls, item: type):
114
+ return ParamDescriber(item)
123
115
 
124
- @property
125
- def identifier(self):
126
- return self._identifier
116
+ @property
117
+ def identifier(self):
118
+ return self._identifier
127
119
 
128
- @identifier.setter
129
- def identifier(self, value):
130
- raise AttributeError('Cannot set the identifier.')
120
+ @identifier.setter
121
+ def identifier(self, value):
122
+ raise AttributeError('Cannot set the identifier.')
131
123
 
132
124
 
133
125
  class AlignPost(Mixin):
134
- """
135
- Align post MixIn.
126
+ """
127
+ Align post MixIn.
136
128
 
137
- This class provides a ``align_post_input_add()`` function for
138
- add external currents.
139
- """
129
+ This class provides a ``align_post_input_add()`` function for
130
+ add external currents.
131
+ """
140
132
 
141
- def align_post_input_add(self, *args, **kwargs):
142
- raise NotImplementedError
133
+ def align_post_input_add(self, *args, **kwargs):
134
+ raise NotImplementedError
143
135
 
144
136
 
145
137
  class BindCondData(Mixin):
146
- """Bind temporary conductance data.
138
+ """Bind temporary conductance data.
147
139
 
148
140
 
149
- """
150
- _conductance: Optional
141
+ """
142
+ _conductance: Optional
151
143
 
152
- def bind_cond(self, conductance):
153
- self._conductance = conductance
144
+ def bind_cond(self, conductance):
145
+ self._conductance = conductance
154
146
 
155
- def unbind_cond(self):
156
- self._conductance = None
147
+ def unbind_cond(self):
148
+ self._conductance = None
157
149
 
158
150
 
159
151
  class UpdateReturn(Mixin):
160
152
 
161
- def update_return(self) -> PyTree:
162
- """
163
- The update function return of the model.
153
+ def update_return(self) -> PyTree:
154
+ """
155
+ The update function return of the model.
164
156
 
165
- It should be a pytree, with each element as a ``jax.ShapeDtypeStruct`` or ``jax.core.ShapedArray``.
157
+ It should be a pytree, with each element as a ``jax.ShapeDtypeStruct`` or ``jax.core.ShapedArray``.
166
158
 
167
- """
168
- raise NotImplementedError(f'Must implement the "{self.update_return.__name__}()" function.')
159
+ """
160
+ raise NotImplementedError(f'Must implement the "{self.update_return.__name__}()" function.')
169
161
 
170
- def update_return_info(self) -> PyTree:
171
- """
172
- The update return information of the model.
162
+ def update_return_info(self) -> PyTree:
163
+ """
164
+ The update return information of the model.
173
165
 
174
- It should be a pytree, with each element as a ``jax.Array``.
166
+ It should be a pytree, with each element as a ``jax.Array``.
175
167
 
176
- .. note::
177
- Should not include the batch axis and batch size.
178
- These information will be inferred from the ``mode`` attribute.
168
+ .. note::
169
+ Should not include the batch axis and batch in_size.
170
+ These information will be inferred from the ``mode`` attribute.
179
171
 
180
- """
181
- raise NotImplementedError(f'Must implement the "{self.update_return_info.__name__}()" function.')
172
+ """
173
+ raise NotImplementedError(f'Must implement the "{self.update_return_info.__name__}()" function.')
182
174
 
183
175
 
184
176
  class _MetaUnionType(type):
185
- def __new__(cls, name, bases, dct):
186
- if isinstance(bases, type):
187
- bases = (bases,)
188
- elif isinstance(bases, (list, tuple)):
189
- bases = tuple(bases)
190
- for base in bases:
191
- assert isinstance(base, type), f'Must be type. But got {base}'
192
- else:
193
- raise TypeError(f'Must be type. But got {bases}')
194
- return super().__new__(cls, name, bases, dct)
177
+ def __new__(cls, name, bases, dct):
178
+ if isinstance(bases, type):
179
+ bases = (bases,)
180
+ elif isinstance(bases, (list, tuple)):
181
+ bases = tuple(bases)
182
+ for base in bases:
183
+ assert isinstance(base, type), f'Must be type. But got {base}'
184
+ else:
185
+ raise TypeError(f'Must be type. But got {bases}')
186
+ return super().__new__(cls, name, bases, dct)
195
187
 
196
- def __instancecheck__(self, other):
197
- cls_of_other = other.__class__
198
- return all([issubclass(cls_of_other, cls) for cls in self.__bases__])
188
+ def __instancecheck__(self, other):
189
+ cls_of_other = other.__class__
190
+ return all([issubclass(cls_of_other, cls) for cls in self.__bases__])
199
191
 
200
- def __subclasscheck__(self, subclass):
201
- return all([issubclass(subclass, cls) for cls in self.__bases__])
192
+ def __subclasscheck__(self, subclass):
193
+ return all([issubclass(subclass, cls) for cls in self.__bases__])
202
194
 
203
195
 
204
196
  class _JointGenericAlias(_UnionGenericAlias, _root=True):
205
- def __subclasscheck__(self, subclass):
206
- return all([issubclass(subclass, cls) for cls in set(self.__args__)])
197
+ def __subclasscheck__(self, subclass):
198
+ return all([issubclass(subclass, cls) for cls in set(self.__args__)])
207
199
 
208
200
 
209
201
  @_SpecialForm
210
202
  def JointTypes(self, parameters):
211
- """Joint types; JointTypes[X, Y] means both X and Y.
203
+ """Joint types; JointTypes[X, Y] means both X and Y.
212
204
 
213
- To define a union, use e.g. Union[int, str].
205
+ To define a union, use e.g. Union[int, str].
214
206
 
215
- Details:
216
- - The arguments must be types and there must be at least one.
217
- - None as an argument is a special case and is replaced by `type(None)`.
218
- - Unions of unions are flattened, e.g.::
207
+ Details:
208
+ - The arguments must be types and there must be at least one.
209
+ - None as an argument is a special case and is replaced by `type(None)`.
210
+ - Unions of unions are flattened, e.g.::
219
211
 
220
- JointTypes[JointTypes[int, str], float] == JointTypes[int, str, float]
212
+ JointTypes[JointTypes[int, str], float] == JointTypes[int, str, float]
221
213
 
222
- - Unions of a single argument vanish, e.g.::
214
+ - Unions of a single argument vanish, e.g.::
223
215
 
224
- JointTypes[int] == int # The constructor actually returns int
216
+ JointTypes[int] == int # The constructor actually returns int
225
217
 
226
- - Redundant arguments are skipped, e.g.::
218
+ - Redundant arguments are skipped, e.g.::
227
219
 
228
- JointTypes[int, str, int] == JointTypes[int, str]
220
+ JointTypes[int, str, int] == JointTypes[int, str]
229
221
 
230
- - When comparing unions, the argument order is ignored, e.g.::
222
+ - When comparing unions, the argument order is ignored, e.g.::
231
223
 
232
- JointTypes[int, str] == JointTypes[str, int]
224
+ JointTypes[int, str] == JointTypes[str, int]
233
225
 
234
- - You cannot subclass or instantiate a JointTypes.
235
- - You can use Optional[X] as a shorthand for JointTypes[X, None].
236
- """
237
- if parameters == ():
238
- raise TypeError("Cannot take a Joint of no types.")
239
- if not isinstance(parameters, tuple):
240
- parameters = (parameters,)
241
- msg = "JointTypes[arg, ...]: each arg must be a type."
242
- parameters = tuple(_type_check(p, msg) for p in parameters)
243
- parameters = _remove_dups_flatten(parameters)
244
- if len(parameters) == 1:
245
- return parameters[0]
246
- if len(parameters) == 2 and type(None) in parameters:
247
- return _UnionGenericAlias(self, parameters, name="Optional")
248
- return _JointGenericAlias(self, parameters)
226
+ - You cannot subclass or instantiate a JointTypes.
227
+ - You can use Optional[X] as a shorthand for JointTypes[X, None].
228
+ """
229
+ if parameters == ():
230
+ raise TypeError("Cannot take a Joint of no types.")
231
+ if not isinstance(parameters, tuple):
232
+ parameters = (parameters,)
233
+ msg = "JointTypes[arg, ...]: each arg must be a type."
234
+ parameters = tuple(_type_check(p, msg) for p in parameters)
235
+ parameters = _remove_dups_flatten(parameters)
236
+ if len(parameters) == 1:
237
+ return parameters[0]
238
+ if len(parameters) == 2 and type(None) in parameters:
239
+ return _UnionGenericAlias(self, parameters, name="Optional")
240
+ return _JointGenericAlias(self, parameters)
249
241
 
250
242
 
251
243
  @_SpecialForm
252
244
  def OneOfTypes(self, parameters):
253
- """Sole type; OneOfTypes[X, Y] means either X or Y.
245
+ """Sole type; OneOfTypes[X, Y] means either X or Y.
254
246
 
255
- To define a union, use e.g. OneOfTypes[int, str]. Details:
256
- - The arguments must be types and there must be at least one.
257
- - None as an argument is a special case and is replaced by
258
- type(None).
259
- - Unions of unions are flattened, e.g.::
247
+ To define a union, use e.g. OneOfTypes[int, str]. Details:
248
+ - The arguments must be types and there must be at least one.
249
+ - None as an argument is a special case and is replaced by
250
+ type(None).
251
+ - Unions of unions are flattened, e.g.::
260
252
 
261
- assert OneOfTypes[OneOfTypes[int, str], float] == OneOfTypes[int, str, float]
253
+ assert OneOfTypes[OneOfTypes[int, str], float] == OneOfTypes[int, str, float]
262
254
 
263
- - Unions of a single argument vanish, e.g.::
255
+ - Unions of a single argument vanish, e.g.::
264
256
 
265
- assert OneOfTypes[int] == int # The constructor actually returns int
257
+ assert OneOfTypes[int] == int # The constructor actually returns int
266
258
 
267
- - Redundant arguments are skipped, e.g.::
259
+ - Redundant arguments are skipped, e.g.::
268
260
 
269
- assert OneOfTypes[int, str, int] == OneOfTypes[int, str]
261
+ assert OneOfTypes[int, str, int] == OneOfTypes[int, str]
270
262
 
271
- - When comparing unions, the argument order is ignored, e.g.::
263
+ - When comparing unions, the argument order is ignored, e.g.::
272
264
 
273
- assert OneOfTypes[int, str] == OneOfTypes[str, int]
265
+ assert OneOfTypes[int, str] == OneOfTypes[str, int]
274
266
 
275
- - You cannot subclass or instantiate a union.
276
- - You can use Optional[X] as a shorthand for OneOfTypes[X, None].
277
- """
278
- if parameters == ():
279
- raise TypeError("Cannot take a Sole of no types.")
280
- if not isinstance(parameters, tuple):
281
- parameters = (parameters,)
282
- msg = "OneOfTypes[arg, ...]: each arg must be a type."
283
- parameters = tuple(_type_check(p, msg) for p in parameters)
284
- parameters = _remove_dups_flatten(parameters)
285
- if len(parameters) == 1:
286
- return parameters[0]
287
- if len(parameters) == 2 and type(None) in parameters:
288
- return _UnionGenericAlias(self, parameters, name="Optional")
289
- return _UnionGenericAlias(self, parameters)
267
+ - You cannot subclass or instantiate a union.
268
+ - You can use Optional[X] as a shorthand for OneOfTypes[X, None].
269
+ """
270
+ if parameters == ():
271
+ raise TypeError("Cannot take a Sole of no types.")
272
+ if not isinstance(parameters, tuple):
273
+ parameters = (parameters,)
274
+ msg = "OneOfTypes[arg, ...]: each arg must be a type."
275
+ parameters = tuple(_type_check(p, msg) for p in parameters)
276
+ parameters = _remove_dups_flatten(parameters)
277
+ if len(parameters) == 1:
278
+ return parameters[0]
279
+ if len(parameters) == 2 and type(None) in parameters:
280
+ return _UnionGenericAlias(self, parameters, name="Optional")
281
+ return _UnionGenericAlias(self, parameters)
290
282
 
291
283
 
292
284
  class Mode(Mixin):
293
- """
294
- Base class for computation behaviors.
295
- """
296
-
297
- def __repr__(self):
298
- return self.__class__.__name__
299
-
300
- def __eq__(self, other: 'Mode'):
301
- assert isinstance(other, Mode)
302
- return other.__class__ == self.__class__
303
-
304
- def is_a(self, mode: type):
305
285
  """
306
- Check whether the mode is exactly the desired mode.
286
+ Base class for computation behaviors.
307
287
  """
308
- assert isinstance(mode, type), 'Must be a type.'
309
- return self.__class__ == mode
310
288
 
311
- def has(self, mode: type):
312
- """
313
- Check whether the mode is included in the desired mode.
314
- """
315
- assert isinstance(mode, type), 'Must be a type.'
316
- return isinstance(self, mode)
289
+ def __repr__(self):
290
+ return self.__class__.__name__
317
291
 
292
+ def __eq__(self, other: 'Mode'):
293
+ assert isinstance(other, Mode)
294
+ return other.__class__ == self.__class__
318
295
 
319
- class JointMode(Mode):
320
- """
321
- Joint mode.
322
- """
296
+ def is_a(self, mode: type):
297
+ """
298
+ Check whether the mode is exactly the desired mode.
299
+ """
300
+ assert isinstance(mode, type), 'Must be a type.'
301
+ return self.__class__ == mode
323
302
 
324
- def __init__(self, *modes: Mode):
325
- for m_ in modes:
326
- if not isinstance(m_, Mode):
327
- raise TypeError(f'The supported type must be a tuple/list of Mode. But we got {m_}')
328
- self.modes = tuple(modes)
329
- self.types = set([m.__class__ for m in modes])
303
+ def has(self, mode: type):
304
+ """
305
+ Check whether the mode is included in the desired mode.
306
+ """
307
+ assert isinstance(mode, type), 'Must be a type.'
308
+ return isinstance(self, mode)
330
309
 
331
- def __repr__(self):
332
- return f'{self.__class__.__name__}({", ".join([repr(m) for m in self.modes])})'
333
310
 
334
- def has(self, mode: type):
335
- """
336
- Check whether the mode is included in the desired mode.
337
- """
338
- assert isinstance(mode, type), 'Must be a type.'
339
- return any([issubclass(cls, mode) for cls in self.types])
340
-
341
- def is_a(self, cls: type):
342
- """
343
- Check whether the mode is exactly the desired mode.
311
+ class JointMode(Mode):
344
312
  """
345
- return JointTypes[tuple(self.types)] == cls
346
-
347
- def __getattr__(self, item):
313
+ Joint mode.
348
314
  """
349
- Get the attribute from the mode.
350
315
 
351
- If the attribute is not found in the mode, then it will be searched in the base class.
352
- """
353
- if item in ['modes', 'types']:
354
- return super().__getattribute__(item)
355
- for m in self.modes:
356
- if hasattr(m, item):
357
- return getattr(m, item)
358
- return super().__getattribute__(item)
316
+ def __init__(self, *modes: Mode):
317
+ for m_ in modes:
318
+ if not isinstance(m_, Mode):
319
+ raise TypeError(f'The supported type must be a tuple/list of Mode. But we got {m_}')
320
+ self.modes = tuple(modes)
321
+ self.types = set([m.__class__ for m in modes])
322
+
323
+ def __repr__(self):
324
+ return f'{self.__class__.__name__}({", ".join([repr(m) for m in self.modes])})'
325
+
326
+ def has(self, mode: type):
327
+ """
328
+ Check whether the mode is included in the desired mode.
329
+ """
330
+ assert isinstance(mode, type), 'Must be a type.'
331
+ return any([issubclass(cls, mode) for cls in self.types])
332
+
333
+ def is_a(self, cls: type):
334
+ """
335
+ Check whether the mode is exactly the desired mode.
336
+ """
337
+ return JointTypes[tuple(self.types)] == cls
338
+
339
+ def __getattr__(self, item):
340
+ """
341
+ Get the attribute from the mode.
342
+
343
+ If the attribute is not found in the mode, then it will be searched in the base class.
344
+ """
345
+ if item in ['modes', 'types']:
346
+ return super().__getattribute__(item)
347
+ for m in self.modes:
348
+ if hasattr(m, item):
349
+ return getattr(m, item)
350
+ return super().__getattribute__(item)
359
351
 
360
352
 
361
353
  class Batching(Mode):
362
- """Batching mode."""
354
+ """Batching mode."""
363
355
 
364
- def __init__(self, batch_size: int = 1, batch_axis: int = 0):
365
- self.batch_size = batch_size
366
- self.batch_axis = batch_axis
356
+ def __init__(self, batch_size: int = 1, batch_axis: int = 0):
357
+ self.batch_size = batch_size
358
+ self.batch_axis = batch_axis
367
359
 
368
- def __repr__(self):
369
- return f'{self.__class__.__name__}(size={self.batch_size}, axis={self.batch_axis})'
360
+ def __repr__(self):
361
+ return f'{self.__class__.__name__}(in_size={self.batch_size}, axis={self.batch_axis})'
370
362
 
371
363
 
372
364
  class Training(Mode):
373
- """Training mode."""
374
- pass
365
+ """Training mode."""
366
+ pass