brainstate 0.1.8__py2.py3-none-any.whl → 0.1.10__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -156
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +588 -502
- brainstate/nn/_delay_test.py +238 -184
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1343
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1119
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/METADATA +91 -99
- brainstate-0.1.10.dist-info/RECORD +130 -0
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/WHEEL +1 -1
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.8.dist-info/RECORD +0 -132
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/top_level.txt +0 -0
brainstate/mixin.py
CHANGED
@@ -1,363 +1,365 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
from typing import (
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
'
|
29
|
-
'
|
30
|
-
|
31
|
-
|
32
|
-
'
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
'
|
37
|
-
|
38
|
-
|
39
|
-
'
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
"""
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
self.
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
self.
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
def
|
116
|
-
return self.
|
117
|
-
|
118
|
-
def
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
"""
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
def
|
161
|
-
self._conductance =
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
bases =
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
if
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
if len(parameters) ==
|
236
|
-
return
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
if
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
if len(parameters) ==
|
277
|
-
return
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
"""
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
"""
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
"""
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
"""
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
"""
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
"""
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
from typing import (
|
19
|
+
Sequence, Optional, TypeVar, _SpecialForm, _type_check, _remove_dups_flatten, _UnionGenericAlias
|
20
|
+
)
|
21
|
+
|
22
|
+
import jax
|
23
|
+
|
24
|
+
T = TypeVar('T')
|
25
|
+
ArrayLike = jax.typing.ArrayLike
|
26
|
+
|
27
|
+
__all__ = [
|
28
|
+
'Mixin',
|
29
|
+
'ParamDesc',
|
30
|
+
'ParamDescriber',
|
31
|
+
'AlignPost',
|
32
|
+
'BindCondData',
|
33
|
+
|
34
|
+
# types
|
35
|
+
'JointTypes',
|
36
|
+
'OneOfTypes',
|
37
|
+
|
38
|
+
# behavior modes
|
39
|
+
'Mode',
|
40
|
+
'JointMode',
|
41
|
+
'Batching',
|
42
|
+
'Training',
|
43
|
+
]
|
44
|
+
|
45
|
+
|
46
|
+
def hashable(x):
    """Report whether ``x`` can be hashed.

    Returns ``True`` when ``hash(x)`` succeeds and ``False`` when it raises
    ``TypeError``. Any other exception raised by ``hash`` propagates.
    """
    try:
        hash(x)
    except TypeError:
        return False
    else:
        return True
|
52
|
+
|
53
|
+
|
54
|
+
class Mixin(object):
    """Root class of every mixin in this module.

    The defining property of a :py:class:`~.Mixin` is that it has no
    initialization function; it only contributes behavioral functions.
    """
|
60
|
+
|
61
|
+
|
62
|
+
class ParamDesc(Mixin):
    """
    :py:class:`~.Mixin` that lets a class describe its initialization parameters.

    A subclass gains the classmethod ``desc``, which produces an instance
    of :py:class:`~.ParamDescriber` bound to that subclass.

    Note this Mixin can be applied to any Python object.
    """

    # Optional names of parameters that are not hashable; defaults to ``None``.
    # This class itself never reads the attribute.
    non_hashable_params: Optional[Sequence[str]] = None

    @classmethod
    def desc(cls, *args, **kwargs) -> 'ParamDescriber':
        """Return a :py:class:`~.ParamDescriber` for ``cls`` with the given arguments preset."""
        describer = ParamDescriber(cls, *args, **kwargs)
        return describer
|
77
|
+
|
78
|
+
|
79
|
+
class HashableDict(dict):
    """A ``dict`` subclass that supports ``hash()``.

    At construction, every value for which ``hashable`` returns ``False`` is
    replaced by its ``str()`` form so that all stored items can be hashed.

    The hash is computed from the current items each time it is requested,
    so mutating the dictionary after hashing changes its hash.
    """

    def __init__(self, the_dict: dict):
        """Copy ``the_dict``, converting non-hashable values to strings."""
        super().__init__(
            {k: (v if hashable(v) else str(v)) for k, v in the_dict.items()}
        )

    def __hash__(self):
        # A frozenset is order-independent and, unlike ``sorted``, does not
        # require the keys to be mutually orderable (e.g. ``1`` and ``'a'``).
        return hash(frozenset(self.items()))
|
90
|
+
|
91
|
+
|
92
|
+
class NoSubclassMeta(type):
    """Metaclass whose classes cannot be subclassed.

    Creating a class raises ``TypeError`` when any of its bases is itself an
    instance of this metaclass.
    """

    def __new__(cls, name, bases, classdict):
        # Find the first base that was itself created by this metaclass.
        sealed = next((b for b in bases if isinstance(b, NoSubclassMeta)), None)
        if sealed is not None:
            raise TypeError("type '{0}' is not an acceptable base type".format(sealed.__name__))
        return type.__new__(cls, name, bases, dict(classdict))
|
98
|
+
|
99
|
+
|
100
|
+
class ParamDescriber(metaclass=NoSubclassMeta):
    """
    A target class bundled with preset constructor arguments.

    Calling the describer instantiates the target class with the preset
    arguments followed by the arguments supplied at call time. The
    metaclass prevents this class from being subclassed.
    """

    def __init__(self, cls: T, *desc_tuple, **desc_dict):
        """Store the target class ``cls`` and the preset arguments."""
        self.cls: type = cls

        # preset positional and keyword arguments
        self.args = desc_tuple
        self.kwargs = desc_dict

        # read-only identity: (class, positional arguments, hashable keyword arguments)
        self._identifier = (cls, tuple(desc_tuple), HashableDict(desc_dict))

    def __call__(self, *args, **kwargs) -> T:
        """Instantiate the target class; preset arguments come first."""
        positional = (*self.args, *args)
        return self.cls(*positional, **self.kwargs, **kwargs)

    def init(self, *args, **kwargs):
        """Alias of calling the describer directly."""
        return self.__call__(*args, **kwargs)

    def __instancecheck__(self, instance):
        """Return ``True`` when ``instance`` is a describer whose class is a subclass of ``self.cls``."""
        return isinstance(instance, ParamDescriber) and issubclass(instance.cls, self.cls)

    @classmethod
    def __class_getitem__(cls, item: type):
        """``ParamDescriber[item]`` builds a describer for ``item`` with no preset arguments."""
        return ParamDescriber(item)

    @property
    def identifier(self):
        """The identifier tuple built at construction."""
        return self._identifier

    @identifier.setter
    def identifier(self, value: ArrayLike):
        # The identifier is fixed at construction and cannot be reassigned.
        raise AttributeError('Cannot set the identifier.')
|
139
|
+
|
140
|
+
|
141
|
+
class AlignPost(Mixin):
    """
    Align post MixIn.

    Declares ``align_post_input_add()``, the function for adding external
    currents. The implementation here always raises ``NotImplementedError``.
    """

    def align_post_input_add(self, *args, **kwargs):
        """Add external currents; raises ``NotImplementedError`` in this base mixin."""
        raise NotImplementedError()
|
151
|
+
|
152
|
+
|
153
|
+
class BindCondData(Mixin):
    """Mixin that binds temporary conductance data to an object."""

    # Declared without a default: the attribute exists only after
    # ``bind_cond`` or ``unbind_cond`` has been called.
    _conductance: Optional

    def bind_cond(self, conductance):
        """Store ``conductance`` on the instance."""
        self._conductance = conductance

    def unbind_cond(self):
        """Reset the stored conductance to ``None``."""
        self._conductance = None
|
165
|
+
|
166
|
+
|
167
|
+
def not_implemented(func):
    """Mark ``func`` as not implemented.

    Returns a replacement that ignores its arguments and always raises
    ``NotImplementedError`` naming ``func``. The replacement carries the
    attribute ``not_implemented = True`` so callers can detect the marker,
    and keeps the metadata of ``func`` (``__name__``, ``__doc__``, ...).
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        raise NotImplementedError(f'{func.__name__} is not implemented.')

    # Set after ``wraps`` so the marker is always present on the wrapper.
    wrapper.not_implemented = True
    return wrapper
|
173
|
+
|
174
|
+
|
175
|
+
class _MetaUnionType(type):
    """Metaclass whose classes match objects that satisfy *all* of their bases.

    ``isinstance`` and ``issubclass`` against a class created by this
    metaclass succeed only when the candidate is a subclass of every entry
    in that class's ``__bases__``.
    """

    def __new__(cls, name, bases, dct):
        # Accept a single type, or a list/tuple of types, as ``bases``.
        if isinstance(bases, type):
            normalized = (bases,)
        elif not isinstance(bases, (list, tuple)):
            raise TypeError(f'Must be type. But got {bases}')
        else:
            normalized = tuple(bases)
            for base in normalized:
                assert isinstance(base, type), f'Must be type. But got {base}'
        return super().__new__(cls, name, normalized, dct)

    def __instancecheck__(self, other):
        candidate = other.__class__
        return all(issubclass(candidate, base) for base in self.__bases__)

    def __subclasscheck__(self, subclass):
        return all(issubclass(subclass, base) for base in self.__bases__)
|
193
|
+
|
194
|
+
|
195
|
+
class _JointGenericAlias(_UnionGenericAlias, _root=True):
    """Generic alias whose subclass check requires *every* argument to match."""

    def __subclasscheck__(self, subclass):
        required = set(self.__args__)
        return all(issubclass(subclass, member) for member in required)
|
198
|
+
|
199
|
+
|
200
|
+
@_SpecialForm
def JointTypes(self, parameters):
    """Joint types; JointTypes[X, Y] means both X and Y.

    For an "either" relation use OneOfTypes[X, Y] instead.

    Details:

    - The arguments must be types and there must be at least one;
      ``JointTypes[()]`` raises ``TypeError``.
    - None as an argument is a special case and is replaced by `type(None)`.
    - Duplicate removal and flattening of nested arguments are delegated to
      ``typing._remove_dups_flatten``, e.g.::

        JointTypes[int, str, int]  # same arguments as JointTypes[int, str]

    - A joint of a single argument vanishes, e.g.::

        JointTypes[int] is int  # The constructor actually returns int

    - Exactly two arguments, one of them ``type(None)``, yield an
      ``Optional``-named union alias rather than a joint alias.
    - Every other case returns a ``_JointGenericAlias``, whose subclass check
      requires all arguments to match.
    """
    if parameters == ():
        raise TypeError("Cannot take a Joint of no types.")

    # A single subscript arrives bare; normalize it to a 1-tuple.
    raw = parameters if isinstance(parameters, tuple) else (parameters,)
    msg = "JointTypes[arg, ...]: each arg must be a type."
    members = _remove_dups_flatten(tuple(_type_check(p, msg) for p in raw))

    count = len(members)
    if count == 1:
        return members[0]
    if count == 2 and type(None) in members:
        return _UnionGenericAlias(self, members, name="Optional")
    return _JointGenericAlias(self, members)
|
240
|
+
|
241
|
+
|
242
|
+
@_SpecialForm
def OneOfTypes(self, parameters):
    """Sole type; OneOfTypes[X, Y] means either X or Y.

    Details:

    - The arguments must be types and there must be at least one;
      ``OneOfTypes[()]`` raises ``TypeError``.
    - None as an argument is a special case and is replaced by
      type(None).
    - Duplicate removal and flattening of nested arguments are delegated to
      ``typing._remove_dups_flatten``, e.g.::

        OneOfTypes[int, str, int]  # same arguments as OneOfTypes[int, str]

    - A union of a single argument vanishes, e.g.::

        assert OneOfTypes[int] is int  # The constructor actually returns int

    - Exactly two arguments, one of them ``type(None)``, yield an
      ``Optional``-named union alias.
    - Every other case returns a plain union alias over the arguments.
    """
    if parameters == ():
        raise TypeError("Cannot take a Sole of no types.")

    # A single subscript arrives bare; normalize it to a 1-tuple.
    raw = parameters if isinstance(parameters, tuple) else (parameters,)
    msg = "OneOfTypes[arg, ...]: each arg must be a type."
    members = _remove_dups_flatten(tuple(_type_check(p, msg) for p in raw))

    count = len(members)
    if count == 1:
        return members[0]
    if count == 2 and type(None) in members:
        return _UnionGenericAlias(self, members, name="Optional")
    return _UnionGenericAlias(self, members)
|
281
|
+
|
282
|
+
|
283
|
+
class Mode(Mixin):
    """
    Base class for computation behaviors.

    Two modes compare equal when they are instances of exactly the same
    class; instance attributes take no part in the comparison. Because
    ``__eq__`` is defined without ``__hash__``, instances are not hashable.
    """

    def __repr__(self):
        return self.__class__.__name__

    def __eq__(self, other: 'Mode'):
        # Comparing against a non-Mode object (``None``, a string, ...) must
        # not raise: return NotImplemented so Python falls back to its
        # default handling and ``mode == other`` evaluates to ``False``.
        if not isinstance(other, Mode):
            return NotImplemented
        return other.__class__ == self.__class__

    def is_a(self, mode: type):
        """
        Check whether the mode is exactly the desired mode.

        ``mode`` must be a type; subclasses of ``mode`` do not count.
        """
        assert isinstance(mode, type), 'Must be a type.'
        return self.__class__ == mode

    def has(self, mode: type):
        """
        Check whether the mode is included in the desired mode.

        ``mode`` must be a type; returns ``isinstance(self, mode)``.
        """
        assert isinstance(mode, type), 'Must be a type.'
        return isinstance(self, mode)
|
308
|
+
|
309
|
+
|
310
|
+
class JointMode(Mode):
    """
    Joint mode: a combination of several :py:class:`Mode` instances.

    ``modes`` keeps the member instances in the given order and ``types``
    holds the set of their classes.
    """

    def __init__(self, *modes: Mode):
        for candidate in modes:
            if not isinstance(candidate, Mode):
                raise TypeError(f'The supported type must be a tuple/list of Mode. But we got {candidate}')
        self.modes = tuple(modes)
        self.types = {member.__class__ for member in modes}

    def __repr__(self):
        inner = ", ".join(repr(member) for member in self.modes)
        return f'{self.__class__.__name__}({inner})'

    def has(self, mode: type):
        """
        Check whether the mode is included in the desired mode.

        True when any member class is a subclass of ``mode``.
        """
        assert isinstance(mode, type), 'Must be a type.'
        return any(issubclass(member_type, mode) for member_type in self.types)

    def is_a(self, cls: type):
        """
        Check whether the mode is exactly the desired mode.

        Compares ``cls`` with the ``JointTypes`` built from the member classes.
        """
        joint = JointTypes[tuple(self.types)]
        return joint == cls

    def __getattr__(self, item):
        """
        Get the attribute from the member modes.

        Python calls this only after normal attribute lookup has failed.
        The member modes are searched in order and the first one that has
        the attribute supplies the value; otherwise the lookup falls back
        to the base class, which raises ``AttributeError``.
        """
        # 'modes' and 'types' must never be resolved through ``self.modes``:
        # if they are not set yet, that would re-enter this method endlessly.
        if item not in ('modes', 'types'):
            for member in self.modes:
                if hasattr(member, item):
                    return getattr(member, item)
        return super().__getattribute__(item)
|
350
|
+
|
351
|
+
|
352
|
+
class Batching(Mode):
    """Batching mode.

    Stores the two constructor arguments as ``batch_size`` (default ``1``)
    and ``batch_axis`` (default ``0``).
    """

    def __init__(self, batch_size: int = 1, batch_axis: int = 0):
        self.batch_size = batch_size
        self.batch_axis = batch_axis

    def __repr__(self):
        # Label the fields with the real attribute / constructor names so the
        # repr reads as the call that would recreate the object.
        return f'{self.__class__.__name__}(batch_size={self.batch_size}, batch_axis={self.batch_axis})'
|
361
|
+
|
362
|
+
|
363
|
+
class Training(Mode):
    """Training mode; adds nothing to :py:class:`Mode` beyond its own type."""
|