brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1360 -1318
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
brainstate/mixin.py
CHANGED
@@ -15,360 +15,352 @@
|
|
15
15
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
|
18
|
-
from
|
19
|
-
|
18
|
+
from __future__ import annotations
|
19
|
+
|
20
|
+
from typing import (Sequence, Optional, TypeVar, _SpecialForm, _type_check, _remove_dups_flatten, _UnionGenericAlias)
|
20
21
|
|
21
22
|
from brainstate.typing import PyTree
|
22
23
|
|
23
24
|
# Generic type variable used in the annotations of ``ParamDescriber``.
T = TypeVar('T')

# Public names exported by ``from brainstate.mixin import *``.
__all__ = (
    # mixins and parameter description
    ['Mixin', 'ParamDesc', 'ParamDescriber', 'AlignPost', 'BindCondData', 'UpdateReturn']
    # types
    + ['JointTypes', 'OneOfTypes']
    # behavior modes
    + ['Mode', 'JointMode', 'Batching', 'Training']
)
class Mixin(object):
    """
    Base class of all mixins.

    A :py:class:`~.Mixin` defines no initialization function of its own; it
    only contributes behavioral methods to the classes that inherit from it.
    """
class ParamDesc(Mixin):
    """
    :py:class:`~.Mixin` that lets a class describe its initialization parameters.

    Subclasses gain a ``desc`` classmethod. Instead of instantiating the class
    right away, it captures the constructor arguments in a
    :py:class:`~.ParamDescriber`, which can be called later to build the
    instance. This mixin can be applied to any Python class.
    """

    # Not read anywhere in this class; presumably names of constructor
    # parameters that should not be hashed (TODO confirm against its users).
    non_hashable_params: Optional[Sequence[str]] = None

    @classmethod
    def desc(cls, *args, **kwargs) -> 'ParamDescriber':
        """Return a :py:class:`~.ParamDescriber` that defers ``cls(*args, **kwargs)``."""
        describer = ParamDescriber(cls, *args, **kwargs)
        return describer
class HashableDict(dict):
    """
    ``dict`` subclass that can be hashed, so it can be used inside a hashable key.

    The hash is computed from the current items and does not depend on
    insertion order. Every value must be hashable.

    .. warning::
       The hash changes if the dictionary is mutated; do not mutate an
       instance after it has been used as (part of) a key.
    """

    def __hash__(self):
        # ``frozenset`` is order-independent and, unlike sorting the items,
        # does not require the keys to be mutually comparable
        # (e.g. ``{1: ..., 'a': ...}`` used to raise ``TypeError``).
        return hash(frozenset(self.items()))
class NoSubclassMeta(type):
    """
    Metaclass that forbids subclassing.

    Creating a class whose bases include a class built with this metaclass
    raises ``TypeError`` naming the first such base.
    """

    def __new__(cls, name, bases, classdict):
        offending = next((base for base in bases if isinstance(base, NoSubclassMeta)), None)
        if offending is not None:
            raise TypeError("type '{0}' is not an acceptable base type".format(offending.__name__))
        return type.__new__(cls, name, bases, dict(classdict))
class ParamDescriber(metaclass=NoSubclassMeta):
    """
    Deferred call of a class constructor with pre-bound arguments.

    Calling the describer (or its :py:meth:`init` alias) instantiates ``cls``.
    The stored positional arguments come first, followed by the call-time ones.
    Both the stored and the call-time keyword arguments are passed; passing
    the same keyword twice is a ``TypeError``.

    ``ParamDescriber[SomeClass]`` returns an argument-free describer. It can be
    passed as the second argument of ``isinstance`` to test whether another
    describer targets ``SomeClass`` or a subclass of it. This class cannot be
    subclassed.
    """

    def __init__(self, cls: T, *desc_tuple, **desc_dict):
        # Arguments captured now and replayed on every call.
        self.args = desc_tuple
        self.kwargs = desc_dict
        self.cls: type = cls

        # (class, positional args, keyword args as a HashableDict) tuple.
        self._identifier = (cls, tuple(desc_tuple), HashableDict(desc_dict))

    def __call__(self, *args, **kwargs) -> T:
        positional = (*self.args, *args)
        return self.cls(*positional, **self.kwargs, **kwargs)

    def init(self, *args, **kwargs):
        """Alias of calling the describer."""
        return self.__call__(*args, **kwargs)

    def __instancecheck__(self, instance):
        # ``isinstance(x, describer)`` holds when ``x`` is itself a describer
        # whose target class is a subclass of this describer's target class.
        return isinstance(instance, ParamDescriber) and issubclass(instance.cls, self.cls)

    @classmethod
    def __class_getitem__(cls, item: type):
        return ParamDescriber(item)

    @property
    def identifier(self):
        """Read-only ``(cls, args, kwargs)`` tuple identifying this describer."""
        return self._identifier

    @identifier.setter
    def identifier(self, value):
        raise AttributeError('Cannot set the identifier.')
class AlignPost(Mixin):
    """
    Mixin exposing an ``align_post_input_add()`` hook for adding external currents.

    The default implementation raises ``NotImplementedError``; subclasses are
    expected to override it.
    """

    def align_post_input_add(self, *args, **kwargs):
        """Add external currents/inputs. Must be overridden by subclasses."""
        raise NotImplementedError
class BindCondData(Mixin):
    """
    Mixin for binding temporary conductance data to an object.

    The value passed to :py:meth:`bind_cond` is kept in ``_conductance``
    until :py:meth:`unbind_cond` resets it to ``None``.
    """

    # Currently bound conductance. It is only annotated here (no class-level
    # default), so unless a subclass sets it, reading it before the first
    # ``bind_cond`` call raises ``AttributeError``.
    _conductance: Optional

    def bind_cond(self, conductance):
        """Attach ``conductance`` to this object."""
        self._conductance = conductance

    def unbind_cond(self):
        """Detach the bound conductance by setting it to ``None``."""
        self._conductance = None
class UpdateReturn(Mixin):
    """
    Mixin declaring the return structure of a model's update function.

    Both methods must be overridden; the defaults raise ``NotImplementedError``.
    """

    def update_return(self) -> PyTree:
        """
        Describe what the model's update function returns.

        Returns
        -------
        PyTree
            A pytree whose leaves are ``jax.ShapeDtypeStruct`` or
            ``jax.core.ShapedArray`` objects.
        """
        method_name = self.update_return.__name__
        raise NotImplementedError(f'Must implement the "{method_name}()" function.')

    def update_return_info(self) -> PyTree:
        """
        Information about the model's update return.

        Returns
        -------
        PyTree
            A pytree whose leaves are ``jax.Array`` objects.

        .. note::
           Do not include the batch axis or the batch size; they are
           inferred from the ``mode`` attribute.
        """
        method_name = self.update_return_info.__name__
        raise NotImplementedError(f'Must implement the "{method_name}()" function.')
class _MetaUnionType(type):
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
177
|
+
def __new__(cls, name, bases, dct):
|
178
|
+
if isinstance(bases, type):
|
179
|
+
bases = (bases,)
|
180
|
+
elif isinstance(bases, (list, tuple)):
|
181
|
+
bases = tuple(bases)
|
182
|
+
for base in bases:
|
183
|
+
assert isinstance(base, type), f'Must be type. But got {base}'
|
184
|
+
else:
|
185
|
+
raise TypeError(f'Must be type. But got {bases}')
|
186
|
+
return super().__new__(cls, name, bases, dct)
|
195
187
|
|
196
|
-
|
197
|
-
|
198
|
-
|
188
|
+
def __instancecheck__(self, other):
|
189
|
+
cls_of_other = other.__class__
|
190
|
+
return all([issubclass(cls_of_other, cls) for cls in self.__bases__])
|
199
191
|
|
200
|
-
|
201
|
-
|
192
|
+
def __subclasscheck__(self, subclass):
|
193
|
+
return all([issubclass(subclass, cls) for cls in self.__bases__])
|
202
194
|
|
203
195
|
|
204
196
|
class _JointGenericAlias(_UnionGenericAlias, _root=True):
|
205
|
-
|
206
|
-
|
197
|
+
def __subclasscheck__(self, subclass):
|
198
|
+
return all([issubclass(subclass, cls) for cls in set(self.__args__)])
|
207
199
|
|
208
200
|
|
209
201
|
@_SpecialForm
def JointTypes(self, parameters):
    """Joint types; ``JointTypes[X, Y]`` means both X and Y.

    A class matches ``JointTypes[X, Y]`` (via ``issubclass``) only when it is a
    subclass of every argument. To define a union, use e.g. ``Union[int, str]``.

    Details:

    - The arguments must be types and there must be at least one.
    - None as an argument is a special case and is replaced by ``type(None)``.
    - Nested joints are flattened, e.g.::

        JointTypes[JointTypes[int, str], float] == JointTypes[int, str, float]

    - A joint of a single argument vanishes, e.g.::

        JointTypes[int] == int  # The constructor actually returns int

    - Redundant arguments are skipped, e.g.::

        JointTypes[int, str, int] == JointTypes[int, str]

    - When comparing joints, the argument order is ignored, e.g.::

        JointTypes[int, str] == JointTypes[str, int]

    - You cannot subclass or instantiate a JointTypes.
    - ``JointTypes[X, None]`` is the joint of ``X`` and ``NoneType``; it is
      *not* ``Optional[X]``.
    """
    if parameters == ():
        raise TypeError("Cannot take a Joint of no types.")
    if not isinstance(parameters, tuple):
        parameters = (parameters,)
    msg = "JointTypes[arg, ...]: each arg must be a type."
    parameters = tuple(_type_check(p, msg) for p in parameters)
    # NOTE(review): in CPython's typing, ``_remove_dups_flatten`` flattens any
    # ``_UnionGenericAlias`` argument. That flattens nested joints as intended,
    # but a ``typing.Union`` nested inside JointTypes is merged into the joint
    # too; confirm this is acceptable.
    parameters = _remove_dups_flatten(parameters)
    if len(parameters) == 1:
        return parameters[0]
    # Unlike ``typing.Union``, there is deliberately no ``Optional`` special
    # case: "both X and None" is not "X or None".
    return _JointGenericAlias(self, parameters)
@_SpecialForm
def OneOfTypes(self, parameters):
    """Sole type; ``OneOfTypes[X, Y]`` means either X or Y (a union).

    Details:

    - The arguments must be types and there must be at least one.
    - None as an argument is a special case and is replaced by ``type(None)``.
    - Unions of unions are flattened, e.g.::

        assert OneOfTypes[OneOfTypes[int, str], float] == OneOfTypes[int, str, float]

    - Unions of a single argument vanish, e.g.::

        assert OneOfTypes[int] == int  # The constructor actually returns int

    - Redundant arguments are skipped, e.g.::

        assert OneOfTypes[int, str, int] == OneOfTypes[int, str]

    - When comparing unions, the argument order is ignored, e.g.::

        assert OneOfTypes[int, str] == OneOfTypes[str, int]

    - You cannot subclass or instantiate a union.
    - ``OneOfTypes[X, None]`` produces an alias named ``Optional``.
    """
    if parameters == ():
        raise TypeError("Cannot take a Sole of no types.")
    params = parameters if isinstance(parameters, tuple) else (parameters,)
    msg = "OneOfTypes[arg, ...]: each arg must be a type."
    checked = _remove_dups_flatten(tuple(_type_check(p, msg) for p in params))
    if len(checked) == 1:
        return checked[0]
    is_optional = len(checked) == 2 and type(None) in checked
    if is_optional:
        return _UnionGenericAlias(self, checked, name="Optional")
    return _UnionGenericAlias(self, checked)
class Mode(Mixin):
    """
    Base class for computation behaviors.

    Two modes compare equal when they are instances of exactly the same class;
    any constructor arguments are ignored by ``==``. Modes are hashable, with
    a hash consistent with that equality.
    """

    def __repr__(self):
        return self.__class__.__name__

    def __eq__(self, other: 'Mode'):
        # Comparing with a non-Mode object is not an error: returning
        # NotImplemented lets Python fall back to identity, giving ``False``.
        if not isinstance(other, Mode):
            return NotImplemented
        return other.__class__ == self.__class__

    def __hash__(self):
        # Defining ``__eq__`` alone would set ``__hash__`` to None. Equal modes
        # share the same class, so hashing the class keeps hash/eq consistent.
        return hash(self.__class__)

    def is_a(self, mode: type):
        """
        Check whether the mode is exactly the desired mode.

        Parameters
        ----------
        mode : type
            The mode class to compare against (subclasses do not match).
        """
        assert isinstance(mode, type), 'Must be a type.'
        return self.__class__ == mode

    def has(self, mode: type):
        """
        Check whether the mode is included in the desired mode.

        Parameters
        ----------
        mode : type
            The mode class; returns ``True`` if ``self`` is an instance of it
            (subclasses match).
        """
        assert isinstance(mode, type), 'Must be a type.'
        return isinstance(self, mode)
class JointMode(Mode):
    """
    Joint mode: a combination of several :py:class:`Mode` instances.

    Attributes not found on the joint mode itself are looked up on the
    component modes, in the order they were given.

    NOTE(review): ``JointMode`` does not override ``__eq__``, so equality is
    inherited from :py:class:`Mode` (class-based). Two ``JointMode`` objects
    with different components may therefore compare equal; confirm this is
    intended.
    """

    def __init__(self, *modes: Mode):
        """
        Parameters
        ----------
        *modes : Mode
            The component modes.

        Raises
        ------
        TypeError
            If any component is not a :py:class:`Mode` instance.
        """
        for candidate in modes:
            if isinstance(candidate, Mode):
                continue
            raise TypeError(f'The supported type must be a tuple/list of Mode. But we got {candidate}')
        self.modes = tuple(modes)
        self.types = {component.__class__ for component in modes}

    def __repr__(self):
        inner = ", ".join([repr(m) for m in self.modes])
        return f'{self.__class__.__name__}({inner})'

    def has(self, mode: type):
        """
        Check whether the mode is included in the desired mode.

        Returns ``True`` if the class of any component is a subclass of ``mode``.
        """
        assert isinstance(mode, type), 'Must be a type.'
        matches = [issubclass(kind, mode) for kind in self.types]
        return any(matches)

    def is_a(self, cls: type):
        """
        Check whether the mode is exactly the desired mode.

        Compares ``JointTypes[...]`` built from the component classes with
        ``cls``.
        """
        expected = JointTypes[tuple(self.types)]
        return expected == cls

    def __getattr__(self, item):
        """
        Get the attribute from the component modes.

        This is only called when normal lookup fails. ``modes`` and ``types``
        are resolved directly; otherwise a missing ``modes`` (e.g. before
        ``__init__`` ran) would recurse through ``self.modes`` forever.
        """
        if item in ('modes', 'types'):
            return super().__getattribute__(item)
        for component in self.modes:
            if hasattr(component, item):
                return getattr(component, item)
        return super().__getattribute__(item)
class Batching(Mode):
    """
    Batching mode.

    Parameters
    ----------
    batch_size : int
        The batch size. Default ``1``.
    batch_axis : int
        The batch axis. Default ``0``.
    """

    def __init__(self, batch_size: int = 1, batch_axis: int = 0):
        self.batch_size = batch_size
        self.batch_axis = batch_axis

    def __repr__(self):
        # Label fields with the constructor's keyword names. The previous
        # ``in_size=`` label was a rename artifact and did not match any attribute.
        return f'{self.__class__.__name__}(batch_size={self.batch_size}, batch_axis={self.batch_axis})'
class Training(Mode):
    """
    Training mode.

    A marker mode: it defines no state or methods of its own, so its behavior
    is entirely inherited from :py:class:`Mode`.
    """