brainstate 0.1.8__py2.py3-none-any.whl → 0.1.10__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -156
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +588 -502
- brainstate/nn/_delay_test.py +238 -184
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1343
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1119
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/METADATA +91 -99
- brainstate-0.1.10.dist-info/RECORD +130 -0
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/WHEEL +1 -1
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.8.dist-info/RECORD +0 -132
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/top_level.txt +0 -0
brainstate/environ.py
CHANGED
@@ -1,563 +1,563 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
import contextlib
|
19
|
-
import dataclasses
|
20
|
-
import functools
|
21
|
-
import os
|
22
|
-
import re
|
23
|
-
import threading
|
24
|
-
from collections import defaultdict
|
25
|
-
from typing import Any, Callable, Dict, Hashable
|
26
|
-
|
27
|
-
import numpy as np
|
28
|
-
from jax import config, devices, numpy as jnp
|
29
|
-
from jax.typing import DTypeLike
|
30
|
-
|
31
|
-
from .mixin import Mode
|
32
|
-
|
33
|
-
# Public API of this module, grouped by purpose.
__all__ = [
    # functions for environment settings
    'set',
    'context',
    'get',
    'all',
    'set_host_device_count',
    'set_platform',
    # functions for getting default behaviors
    'get_host_device_count',
    'get_platform',
    'get_dt',
    'get_mode',
    'get_precision',
    # functions for default data types
    'dftype',
    'ditype',
    'dutype',
    'dctype',
    # others
    'tolerance',
    'register_default_behavior',
]
|
43
|
-
|
44
|
-
# Names of several settings that are commonly shared through the global context.

# Index of the current computation.
I = 'i'

# Current time of the current computation.
T = 't'

# Original note: "whether to record the current computation".
# NOTE(review): the name suggests a flag for JIT-time error checking; confirm.
JIT_ERROR_CHECK = 'jit_error_check'

# Whether to fit the model.
FIT = 'fit'
|
49
|
-
|
50
|
-
|
51
|
-
@dataclasses.dataclass
class DefaultContext(threading.local):
    """Per-thread store of global settings, context stacks and registered behaviors."""

    # Global values installed by ``set()``, keyed by setting name.
    settings: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)

    # Per-key stacks of values pushed by ``context()``; the last element of a
    # stack is the innermost active value. Unknown keys yield a fresh empty list.
    contexts: defaultdict[Hashable, Any] = dataclasses.field(
        default_factory=functools.partial(defaultdict, list)
    )

    # Callbacks registered by ``register_default_behavior()``, keyed by setting name.
    functions: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
|
59
|
-
|
60
|
-
|
61
|
-
# Module-wide environment registry read and written by every function below.
# ``DefaultContext`` subclasses ``threading.local``, so its attributes are
# per-thread.
# NOTE(review): a new thread therefore starts with empty settings (no
# 'precision' either), so import-time ``set(...)`` calls only affect the
# importing thread; confirm this is intended.
DFAULT = DefaultContext()
# Sentinel that lets ``get`` tell "no default given" apart from ``default=None``.
_NOT_PROVIDE = object()
|
63
|
-
|
64
|
-
|
65
|
-
@contextlib.contextmanager
def context(**kwargs):
    r"""
    Context-manager that temporarily overrides computing-environment settings.

    Each keyword argument is pushed onto a per-key stack for the duration of the
    ``with`` block, so :func:`get` returns the innermost value while the block is
    active. Typical keys include ``mode`` (model computing behavior), ``dt``
    (numerical integration step) and ``precision`` (default numerical precision,
    which also toggles JAX 64-bit mode). If a behavior was registered for a key
    via :func:`register_default_behavior`, it is called with the new value on
    entry and with the restored value on exit.

    ``platform`` and ``host_device_count`` are global-only settings; change them
    with :func:`set_platform`, :func:`set_host_device_count` or :func:`set`.

    For instance::

        >>> import brainstate as brainstate
        >>> with brainstate.environ.context(dt=0.1) as env:
        ...     dt = brainstate.environ.get('dt')
        ...     print(env)

    Args:
        **kwargs: The environment settings to apply inside the block.

    Yields:
        dict: A snapshot of all current settings, as returned by :func:`all`.

    Raises:
        ValueError: If ``platform`` or ``host_device_count`` is passed.
    """
    if 'platform' in kwargs:
        raise ValueError('\n'
                         'Cannot set platform in "context" environment. \n'
                         'You should set platform in the global environment by "set_platform()" or "set()".')
    if 'host_device_count' in kwargs:
        raise ValueError('Cannot set host_device_count in environment context. '
                         'Please use set_host_device_count() or set() for the global setting.')

    if 'precision' in kwargs:
        last_precision = _get_precision()
        _set_jax_precision(kwargs['precision'])

    # Keys actually pushed so far. If pushing a later key (or running its
    # registered behavior) raises, only these are popped again; popping every
    # key in ``kwargs`` would remove values owned by an enclosing context, or
    # raise IndexError on an empty stack and mask the original error.
    pushed = []
    try:
        for k, v in kwargs.items():
            # update the current environment
            DFAULT.contexts[k].append(v)
            pushed.append(k)

            # notify the registered behavior of the new value
            if k in DFAULT.functions:
                DFAULT.functions[k](v)

        # yield the current all environment information
        yield all()
    finally:
        missing = object()
        for k in pushed:
            # restore the current environment
            DFAULT.contexts[k].pop()

            # notify the registered behavior of the restored value; when no
            # enclosing or global value remains there is nothing to restore to,
            # and calling ``get(k)`` would raise KeyError on exit.
            if k in DFAULT.functions:
                restored = get(k, missing)
                if restored is not missing:
                    DFAULT.functions[k](restored)

        if 'precision' in kwargs:
            _set_jax_precision(last_precision)
|
122
|
-
|
123
|
-
|
124
|
-
def get(key: str, default: Any = _NOT_PROVIDE, desc: str = None):
    """
    Get the current value of one computation-environment setting.

    The innermost value pushed by an active :func:`context` wins; otherwise the
    global value installed by :func:`set` is used. The keys ``'platform'`` and
    ``'host_device_count'`` are always answered live by :func:`get_platform`
    and :func:`get_host_device_count`.

    Parameters
    ----------
    key: str
        Name of the setting.
    default: Any, optional
        Value returned when ``key`` is not set. If omitted, ``KeyError`` is raised.
    desc: str, optional
        Human-readable description of the setting, appended to the ``KeyError``
        message to help the user supply the missing value.

    Returns
    -------
    item: Any
        The current value of the setting, or ``default``.

    Raises
    ------
    KeyError
        If ``key`` is not set anywhere and no ``default`` is given.
    """
    if key == 'platform':
        return get_platform()

    if key == 'host_device_count':
        return get_host_device_count()

    # ``.get`` (rather than indexing) avoids inserting an empty stack into the
    # defaultdict for keys that were never pushed.
    stack = DFAULT.contexts.get(key)
    if stack:
        return stack[-1]
    if key in DFAULT.settings:
        return DFAULT.settings[key]

    if default is _NOT_PROVIDE:
        # Point users at the public ``brainstate.environ`` API of this module.
        message = (
            f"'{key}' is not found in the context. \n"
            f"You can set it by `brainstate.environ.context({key}=value)` "
            f"locally or `brainstate.environ.set({key}=value)` globally."
        )
        if desc is not None:
            message += f" \nDescription: {desc}"
        raise KeyError(message)
    return default
|
160
|
-
|
161
|
-
|
162
|
-
def all() -> dict:
    """
    Collect every setting that is currently in effect.

    For each key with a non-empty context stack, the innermost :func:`context`
    value is used; global values installed by :func:`set` fill in the keys that
    have no active context value.

    Returns
    -------
    r: dict
        A fresh mapping from setting name to its current value.
    """
    snapshot = {name: stack[-1] for name, stack in DFAULT.contexts.items() if stack}
    for name, value in DFAULT.settings.items():
        snapshot.setdefault(name, value)
    return snapshot
|
179
|
-
|
180
|
-
|
181
|
-
def get_dt():
    """Return the current numerical integration time step ``dt``.

    Equivalent to ``get('dt')``; raises ``KeyError`` when ``dt`` has been set
    neither globally nor by an enclosing :func:`context`.

    Returns
    -------
    dt : float
        The numerical integration time step.
    """
    step = get('dt')
    return step
|
190
|
-
|
191
|
-
|
192
|
-
def get_mode() -> Mode:
    """Return the default computing mode currently in effect.

    Equivalent to ``get('mode')``.

    Returns
    -------
    mode: Mode
        The default computing mode.
    """
    current_mode = get('mode')
    return current_mode
|
201
|
-
|
202
|
-
|
203
|
-
def get_platform() -> str:
    """Return the computing platform of the first device JAX reports.

    Returns
    -------
    platform: str
        The ``platform`` attribute of ``jax.devices()[0]``, e.g. 'cpu', 'gpu'
        or 'tpu'.
    """
    first_device = devices()[0]
    return first_device.platform
|
212
|
-
|
213
|
-
|
214
|
-
def get_host_device_count():
    """
    Read the number of host devices requested through ``XLA_FLAGS``.

    The value comes from the first ``--xla_force_host_platform_device_count=N``
    flag found in the ``XLA_FLAGS`` environment variable.

    Returns
    -------
    n: int
        The requested number of host devices, or 1 when the flag is absent.
    """
    flags = os.environ.get("XLA_FLAGS", "")
    found = re.search(r"--xla_force_host_platform_device_count=(\d+)", flags)
    if found is None:
        return 1
    return int(found.group(1))
|
226
|
-
|
227
|
-
|
228
|
-
def _get_precision() -> int | str:
    """
    Return the raw ``precision`` setting without normalizing it.

    Unlike :func:`get_precision`, string values such as ``'bf16'`` or ``'32'``
    are returned unchanged.

    Returns
    -------
    precision: int | str
        The current precision setting.
    """
    raw_precision = get('precision')
    return raw_precision
|
238
|
-
|
239
|
-
|
240
|
-
def get_precision() -> int:
|
241
|
-
"""
|
242
|
-
Get the default precision.
|
243
|
-
|
244
|
-
Returns
|
245
|
-
-------
|
246
|
-
precision: int
|
247
|
-
The default precision.
|
248
|
-
"""
|
249
|
-
precision = get('precision')
|
250
|
-
if precision == 'bf16':
|
251
|
-
return 16
|
252
|
-
if isinstance(precision, int):
|
253
|
-
return precision
|
254
|
-
if isinstance(precision, str):
|
255
|
-
return int(precision)
|
256
|
-
raise ValueError(f'Unsupported precision: {precision}')
|
257
|
-
|
258
|
-
|
259
|
-
def set(
    platform: str = None,
    host_device_count: int = None,
    precision: int | str = None,
    mode: Mode = None,
    **kwargs
):
    """
    Set the global default computation environment.

    ``platform`` and ``host_device_count`` are applied immediately through
    :func:`set_platform` and :func:`set_host_device_count`. ``precision`` also
    toggles JAX 64-bit mode. ``precision``, ``mode`` and every extra keyword are
    stored as global defaults. Any behavior registered for them through
    :func:`register_default_behavior` is then called with the new value.

    Args:
        platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
        host_device_count: int. The number of host devices.
        precision: int, str. The default precision.
        mode: Mode. The computing mode.
        **kwargs: dict. Other environment settings.
    """
    if platform is not None:
        set_platform(platform)
    if host_device_count is not None:
        set_host_device_count(host_device_count)
    if precision is not None:
        _set_jax_precision(precision)
        kwargs['precision'] = precision
    if mode is not None:
        assert isinstance(mode, Mode), 'mode must be a Mode instance.'
        kwargs['mode'] = mode

    # install the new global defaults
    DFAULT.settings.update(kwargs)

    # let registered behaviors react to the new values, in insertion order
    registry = DFAULT.functions
    for name in kwargs:
        if name in registry:
            registry[name](kwargs[name])
|
296
|
-
|
297
|
-
|
298
|
-
def set_host_device_count(n):
    """
    Ask XLA to expose ``n`` host (CPU) devices.

    By default, XLA treats all CPU cores as a single device. Declaring ``n``
    host devices allows parallel mapping such as :func:`jax.pmap` to work on
    the CPU platform.

    .. note:: This utility only takes effect at the beginning of your program.
       It rewrites the ``XLA_FLAGS`` environment variable so that it contains
       ``--xla_force_host_platform_device_count=[num_devices]`` (replacing any
       earlier value of that flag), where ``[num_devices]`` is ``n``.

    .. warning:: The side effects of the ``xla_force_host_platform_device_count``
       flag in XLA are not fully understood. If you observe strange behavior
       when using this utility, please report it through our issue or forum
       page. More information is available in this
       `JAX issue <https://github.com/google/jax/issues/1408>`_.

    Any behavior registered for ``'host_device_count'`` is called with ``n``.

    :param int n: number of devices to use.
    """
    existing = os.getenv("XLA_FLAGS", "")
    remaining = re.sub(r"--xla_force_host_platform_device_count=\S+", "", existing).split()
    flags = [f"--xla_force_host_platform_device_count={n}", *remaining]
    os.environ["XLA_FLAGS"] = " ".join(flags)

    # update the environment functions
    hook_key = 'host_device_count'
    if hook_key in DFAULT.functions:
        DFAULT.functions[hook_key](n)
|
324
|
-
|
325
|
-
|
326
|
-
def set_platform(platform: str):
    """
    Changes platform to CPU, GPU, or TPU. This utility only takes
    effect at the beginning of your program.

    Any behavior registered for ``'platform'`` is called with the new platform.

    Args:
        platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.

    Raises:
        ValueError: If the platform is not in ['cpu', 'gpu', 'tpu'].
    """
    # Explicit check instead of ``assert``: the documented contract is a
    # ValueError, and assertions are stripped when Python runs with ``-O``.
    if platform not in ['cpu', 'gpu', 'tpu']:
        raise ValueError(
            f"Unsupported platform: {platform!r}. Expected one of ['cpu', 'gpu', 'tpu']."
        )
    config.update("jax_platform_name", platform)

    # update the environment functions
    if 'platform' in DFAULT.functions:
        DFAULT.functions['platform'](platform)
|
343
|
-
|
344
|
-
|
345
|
-
def _set_jax_precision(precision: int | str):
|
346
|
-
"""
|
347
|
-
Set the default precision.
|
348
|
-
|
349
|
-
Args:
|
350
|
-
precision: int. The default precision.
|
351
|
-
"""
|
352
|
-
# assert precision in [64, 32, 16, 'bf16', 8], f'Precision must be in [64, 32, 16, "bf16", 8]. But got {precision}.'
|
353
|
-
if precision in [64, '64']:
|
354
|
-
config.update("jax_enable_x64", True)
|
355
|
-
else:
|
356
|
-
config.update("jax_enable_x64", False)
|
357
|
-
|
358
|
-
|
359
|
-
@functools.lru_cache()
def _get_uint(precision: int):
    """Map a precision setting (int or its string form) to a NumPy unsigned integer dtype.

    ``'bf16'`` shares the 16-bit width. Raises ``ValueError`` for anything else.
    """
    table = (
        ((64, '64'), np.uint64),
        ((32, '32'), np.uint32),
        ((16, '16', 'bf16'), np.uint16),
        ((8, '8'), np.uint8),
    )
    for accepted, dtype in table:
        if precision in accepted:
            return dtype
    raise ValueError(f'Unsupported precision: {precision}')
|
371
|
-
|
372
|
-
|
373
|
-
@functools.lru_cache()
def _get_int(precision: int):
    """Map a precision setting (int or its string form) to a NumPy signed integer dtype.

    ``'bf16'`` shares the 16-bit width. Raises ``ValueError`` for anything else.
    """
    table = (
        ((64, '64'), np.int64),
        ((32, '32'), np.int32),
        ((16, '16', 'bf16'), np.int16),
        ((8, '8'), np.int8),
    )
    for accepted, dtype in table:
        if precision in accepted:
            return dtype
    raise ValueError(f'Unsupported precision: {precision}')
|
385
|
-
|
386
|
-
|
387
|
-
@functools.lru_cache()
def _get_float(precision: int):
    """Map a precision setting (int or its string form) to a floating dtype.

    64/32/16 map to the NumPy IEEE types, ``'bf16'`` to ``jnp.bfloat16`` and
    8 to ``jnp.float8_e5m2``. Raises ``ValueError`` for anything else.
    """
    table = (
        ((64, '64'), np.float64),
        ((32, '32'), np.float32),
        ((16, '16'), np.float16),
        (('bf16',), jnp.bfloat16),
        ((8, '8'), jnp.float8_e5m2),
    )
    for accepted, dtype in table:
        if precision in accepted:
            return dtype
    raise ValueError(f'Unsupported precision: {precision}')
|
401
|
-
|
402
|
-
|
403
|
-
@functools.lru_cache()
def _get_complex(precision: int):
    """
    Map a precision setting (int or its string form) to a complex dtype.

    A precision of 64 maps to ``np.complex128``. NumPy has no complex type
    narrower than ``np.complex64``, so 32, 16, ``'bf16'`` and 8 all map to
    ``np.complex64``.

    Raises:
        ValueError: If the precision is not one of the supported values.
    """
    # Membership tests, matching the sibling ``_get_uint``/``_get_int``/``_get_float``
    # helpers (the conditions of the 64/32/8 branches were truncated here).
    if precision in [64, '64']:
        return np.complex128
    elif precision in [32, '32']:
        return np.complex64
    elif precision in [16, '16', 'bf16']:
        return np.complex64
    elif precision in [8, '8']:
        return np.complex64
    else:
        raise ValueError(f'Unsupported precision: {precision}')
|
415
|
-
|
416
|
-
|
417
|
-
def dftype() -> DTypeLike:
    """
    Default floating data type.

    Returns the floating dtype that matches the current precision setting, so
    arrays created with ``dtype=dftype()`` follow changes made through
    ``brainstate.environ.set(precision=...)`` or
    ``brainstate.environ.context(precision=...)``. The mapping is
    64 -> ``np.float64``, 32 -> ``np.float32``, 16 -> ``np.float16``,
    ``'bf16'`` -> ``jnp.bfloat16`` and 8 -> ``jnp.float8_e5m2``.

    For example, if the precision is set to 32, the default floating data type is ``np.float32``.

    >>> import brainstate as brainstate
    >>> import numpy as np
    >>> with brainstate.environ.context(precision=32):
    ...     a = np.zeros(1, dtype=brainstate.environ.dftype())
    >>> print(a.dtype)
    float32

    Returns
    -------
    float_dtype: DTypeLike
        The default floating data type.
    """
    current = _get_precision()
    return _get_float(current)
|
439
|
-
|
440
|
-
|
441
|
-
def ditype() -> DTypeLike:
    """
    Default integer data type.

    Returns the signed integer dtype that matches the current precision
    setting, so arrays created with ``dtype=ditype()`` follow changes made
    through ``brainstate.environ.set(precision=...)`` or
    ``brainstate.environ.context(precision=...)``. A precision of ``'bf16'``
    maps to the 16-bit integer type.

    For example, if the precision is set to 32, the default integer data type is ``np.int32``.

    >>> import brainstate as brainstate
    >>> import numpy as np
    >>> with brainstate.environ.context(precision=32):
    ...     a = np.zeros(1, dtype=brainstate.environ.ditype())
    >>> print(a.dtype)
    int32

    Returns
    -------
    int_dtype: DTypeLike
        The default integer data type.
    """
    current = _get_precision()
    return _get_int(current)
|
464
|
-
|
465
|
-
|
466
|
-
def dutype() -> DTypeLike:
    """
    Default unsigned integer data type.

    Returns the unsigned integer dtype that matches the current precision
    setting, so arrays created with ``dtype=dutype()`` follow changes made
    through ``brainstate.environ.set(precision=...)`` or
    ``brainstate.environ.context(precision=...)``. A precision of ``'bf16'``
    maps to the 16-bit unsigned type.

    For example, if the precision is set to 32, the default unsigned integer data type is ``np.uint32``.

    >>> import brainstate as brainstate
    >>> import numpy as np
    >>> with brainstate.environ.context(precision=32):
    ...     a = np.zeros(1, dtype=brainstate.environ.dutype())
    >>> print(a.dtype)
    uint32

    Returns
    -------
    uint_dtype: DTypeLike
        The default unsigned integer data type.
    """
    current = _get_precision()
    return _get_uint(current)
|
490
|
-
|
491
|
-
|
492
|
-
def dctype() -> DTypeLike:
    """
    Default complex data type.

    Returns the complex dtype that matches the current precision setting, so
    arrays created with ``dtype=dctype()`` follow changes made through
    ``brainstate.environ.set(precision=...)`` or
    ``brainstate.environ.context(precision=...)``. A precision of 64 gives
    ``np.complex128``; every lower precision gives ``np.complex64``.

    For example, if the precision is set to 32, the default complex data type is ``np.complex64``.

    >>> import brainstate as brainstate
    >>> import numpy as np
    >>> with brainstate.environ.context(precision=32):
    ...     a = np.zeros(1, dtype=brainstate.environ.dctype())
    >>> print(a.dtype)
    complex64

    Returns
    -------
    complex_dtype: DTypeLike
        The default complex data type.
    """
    current = _get_precision()
    return _get_complex(current)
|
515
|
-
|
516
|
-
|
517
|
-
def tolerance():
    """
    Return a small numerical tolerance matched to the current precision.

    Returns
    -------
    tol: jax.Array
        A scalar ``1e-12`` (float64) when the precision is 64, ``1e-5``
        (float32) when it is 32, and ``1e-2`` (float16) for any other precision.
    """
    bits = get_precision()
    if bits == 64:
        value, dtype = 1e-12, np.float64
    elif bits == 32:
        value, dtype = 1e-5, np.float32
    else:
        value, dtype = 1e-2, np.float16
    return jnp.array(value, dtype=dtype)
|
524
|
-
|
525
|
-
|
526
|
-
def register_default_behavior(key: str, behavior: Callable, replace_if_exist: bool = False):
    """
    Register a callback that runs whenever a global setting changes.

    The callback is invoked with the new value by ``brainstate.environ.set``.
    Inside ``brainstate.environ.context``, it is invoked with the context value
    on entry and with the restored value on exit.

    For example, register a behavior for the key 'dt'::

        >>> import brainstate as brainstate
        >>> def dt_behavior(dt):
        ...     print(f'Set the default dt to {dt}.')
        ...
        >>> brainstate.environ.register_default_behavior('dt', dt_behavior)

    Setting dt then calls ``dt_behavior(0.1)``:

        >>> brainstate.environ.set(dt=0.1)
        Set the default dt to 0.1.
        >>> with brainstate.environ.context(dt=0.2):
        ...     pass
        Set the default dt to 0.2.
        Set the default dt to 0.1.

    Args:
        key: str. The key to register.
        behavior: Callable. The behavior to register. It should be a callable.
        replace_if_exist: bool. Whether to replace the behavior if the key has been registered.
    """
    assert isinstance(key, str), 'key must be a string.'
    assert callable(behavior), 'behavior must be a callable.'
    registry = DFAULT.functions
    # refuse to silently overwrite an existing behavior unless asked to
    assert replace_if_exist or key not in registry, f'{key} has been registered.'
    registry[key] = behavior
|
561
|
-
|
562
|
-
|
563
|
-
# Install a 32-bit default precision at import time. This disables JAX 64-bit
# mode and gives ``_get_precision()`` a value, which ``context(precision=...)``
# reads before overriding it.
set(precision=32)
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
import contextlib
|
19
|
+
import dataclasses
|
20
|
+
import functools
|
21
|
+
import os
|
22
|
+
import re
|
23
|
+
import threading
|
24
|
+
from collections import defaultdict
|
25
|
+
from typing import Any, Callable, Dict, Hashable
|
26
|
+
|
27
|
+
import numpy as np
|
28
|
+
from jax import config, devices, numpy as jnp
|
29
|
+
from jax.typing import DTypeLike
|
30
|
+
|
31
|
+
from .mixin import Mode
|
32
|
+
|
33
|
+
# Public API of this module, grouped by purpose.
__all__ = [
    # functions for environment settings
    'set',
    'context',
    'get',
    'all',
    'set_host_device_count',
    'set_platform',
    # functions for getting default behaviors
    'get_host_device_count',
    'get_platform',
    'get_dt',
    'get_mode',
    'get_precision',
    # functions for default data types
    'dftype',
    'ditype',
    'dutype',
    'dctype',
    # others
    'tolerance',
    'register_default_behavior',
]
|
43
|
+
|
44
|
+
# Names of several settings that are commonly shared through the global context.

# Index of the current computation.
I = 'i'

# Current time of the current computation.
T = 't'

# Original note: "whether to record the current computation".
# NOTE(review): the name suggests a flag for JIT-time error checking; confirm.
JIT_ERROR_CHECK = 'jit_error_check'

# Whether to fit the model.
FIT = 'fit'
|
49
|
+
|
50
|
+
|
51
|
+
@dataclasses.dataclass
class DefaultContext(threading.local):
    """Per-thread store of global settings, context stacks and registered behaviors."""

    # Global values installed by ``set()``, keyed by setting name.
    settings: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)

    # Per-key stacks of values pushed by ``context()``; the last element of a
    # stack is the innermost active value. Unknown keys yield a fresh empty list.
    contexts: defaultdict[Hashable, Any] = dataclasses.field(
        default_factory=functools.partial(defaultdict, list)
    )

    # Callbacks registered by ``register_default_behavior()``, keyed by setting name.
    functions: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
|
59
|
+
|
60
|
+
|
61
|
+
# Module-wide environment registry read and written by every function below.
# ``DefaultContext`` subclasses ``threading.local``, so its attributes are
# per-thread.
# NOTE(review): a new thread therefore starts with empty settings (no
# 'precision' either), so import-time ``set(...)`` calls only affect the
# importing thread; confirm this is intended.
DFAULT = DefaultContext()
# Sentinel that lets ``get`` tell "no default given" apart from ``default=None``.
_NOT_PROVIDE = object()
|
63
|
+
|
64
|
+
|
65
|
+
@contextlib.contextmanager
def context(**kwargs):
    r"""
    Context-manager that temporarily overrides computing-environment settings.

    Each keyword argument is pushed onto a per-key stack for the duration of the
    ``with`` block, so :func:`get` returns the innermost value while the block is
    active. Typical keys include ``mode`` (model computing behavior), ``dt``
    (numerical integration step) and ``precision`` (default numerical precision,
    which also toggles JAX 64-bit mode). If a behavior was registered for a key
    via :func:`register_default_behavior`, it is called with the new value on
    entry and with the restored value on exit.

    ``platform`` and ``host_device_count`` are global-only settings; change them
    with :func:`set_platform`, :func:`set_host_device_count` or :func:`set`.

    For instance::

        >>> import brainstate as brainstate
        >>> with brainstate.environ.context(dt=0.1) as env:
        ...     dt = brainstate.environ.get('dt')
        ...     print(env)

    Args:
        **kwargs: The environment settings to apply inside the block.

    Yields:
        dict: A snapshot of all current settings, as returned by :func:`all`.

    Raises:
        ValueError: If ``platform`` or ``host_device_count`` is passed.
    """
    if 'platform' in kwargs:
        raise ValueError('\n'
                         'Cannot set platform in "context" environment. \n'
                         'You should set platform in the global environment by "set_platform()" or "set()".')
    if 'host_device_count' in kwargs:
        raise ValueError('Cannot set host_device_count in environment context. '
                         'Please use set_host_device_count() or set() for the global setting.')

    if 'precision' in kwargs:
        last_precision = _get_precision()
        _set_jax_precision(kwargs['precision'])

    # Keys actually pushed so far. If pushing a later key (or running its
    # registered behavior) raises, only these are popped again; popping every
    # key in ``kwargs`` would remove values owned by an enclosing context, or
    # raise IndexError on an empty stack and mask the original error.
    pushed = []
    try:
        for k, v in kwargs.items():
            # update the current environment
            DFAULT.contexts[k].append(v)
            pushed.append(k)

            # notify the registered behavior of the new value
            if k in DFAULT.functions:
                DFAULT.functions[k](v)

        # yield the current all environment information
        yield all()
    finally:
        missing = object()
        for k in pushed:
            # restore the current environment
            DFAULT.contexts[k].pop()

            # notify the registered behavior of the restored value; when no
            # enclosing or global value remains there is nothing to restore to,
            # and calling ``get(k)`` would raise KeyError on exit.
            if k in DFAULT.functions:
                restored = get(k, missing)
                if restored is not missing:
                    DFAULT.functions[k](restored)

        if 'precision' in kwargs:
            _set_jax_precision(last_precision)
|
122
|
+
|
123
|
+
|
124
|
+
def get(key: str, default: Any = _NOT_PROVIDE, desc: str = None):
    """
    Get the current value of one computation-environment setting.

    Lookup order: the special keys ``'platform'`` and ``'host_device_count'``
    are answered by :func:`get_platform` and :func:`get_host_device_count`;
    otherwise the innermost value pushed by an active ``context(...)`` wins,
    then the global value stored by ``set(...)``, then ``default``.

    Parameters
    ----------
    key: str
        The name of the setting, e.g. ``'dt'`` or ``'precision'``.
    default: Any
        Value returned when ``key`` is found neither in an active context nor
        in the global settings. When omitted, a ``KeyError`` is raised instead.
    desc: str, optional
        Human-readable description of ``key``; appended to the ``KeyError``
        message to explain what is missing.

    Returns
    -------
    item: Any
        The current value of the setting.

    Raises
    ------
    KeyError
        If ``key`` is not set and no ``default`` is provided.
    """
    if key == 'platform':
        return get_platform()

    if key == 'host_device_count':
        return get_host_device_count()

    # The innermost ``context(...)`` value takes precedence over the global setting.
    if key in DFAULT.contexts:
        if len(DFAULT.contexts[key]) > 0:
            return DFAULT.contexts[key][-1]
    if key in DFAULT.settings:
        return DFAULT.settings[key]

    if default is _NOT_PROVIDE:
        # Point users at the public API of this module (``brainstate.environ``).
        msg = (
            f"'{key}' is not found in the context. \n"
            f"You can set it by `brainstate.environ.context({key}=value)` "
            f"locally or `brainstate.environ.set({key}=value)` globally."
        )
        if desc is not None:
            msg += f" \nDescription: {desc}"
        raise KeyError(msg)
    return default
|
160
|
+
|
161
|
+
|
162
|
+
def all() -> dict:
    """
    Collect every currently effective computation-environment setting.

    Values pushed by active ``context(...)`` blocks take precedence over the
    global values stored by ``set(...)``. Keys whose context stack is empty
    fall back to their global value.

    Returns
    -------
    r: dict
        A new dict mapping each setting name to its current value.
    """
    merged = {
        name: stack[-1]
        for name, stack in DFAULT.contexts.items()
        if stack
    }
    for name, value in DFAULT.settings.items():
        # only fill in keys not already overridden by a context
        merged.setdefault(name, value)
    return merged
|
179
|
+
|
180
|
+
|
181
|
+
def get_dt():
    """
    Return the current numerical integration time step ``dt``.

    Returns
    -------
    dt : float
        The integration step, looked up via ``get('dt')``. A ``KeyError`` is
        raised if it has not been set.
    """
    dt = get('dt')
    return dt
|
190
|
+
|
191
|
+
|
192
|
+
def get_mode() -> Mode:
    """Get the default computing mode.

    Returns
    -------
    mode: Mode
        The current computing mode, looked up via ``get('mode')``. A
        ``KeyError`` is raised if it has not been set.
    """
    return get('mode')
|
201
|
+
|
202
|
+
|
203
|
+
def get_platform() -> str:
    """
    Return the platform of the first visible device.

    Returns
    -------
    platform: str
        The ``platform`` attribute of ``devices()[0]``, e.g. 'cpu', 'gpu' or 'tpu'.
    """
    first_device = devices()[0]
    return first_device.platform
|
212
|
+
|
213
|
+
|
214
|
+
def get_host_device_count():
    """
    Read the forced host (CPU) device count from the ``XLA_FLAGS`` environment variable.

    Returns
    -------
    n: int
        The value of ``--xla_force_host_platform_device_count`` found in
        ``XLA_FLAGS``, or 1 when the flag is absent.
    """
    flags = os.getenv("XLA_FLAGS", "")
    found = re.search(r"--xla_force_host_platform_device_count=(\d+)", flags)
    if found is None:
        return 1
    return int(found.group(1))
|
226
|
+
|
227
|
+
|
228
|
+
def _get_precision() -> int | str:
    """
    Return the raw precision setting exactly as it was stored.

    Unlike :func:`get_precision`, no normalization is applied. The result may
    be an int (e.g. ``32``) or a string (e.g. ``'bf16'`` or ``'32'``).

    Returns
    -------
    precision: int | str
        The current ``'precision'`` setting.
    """
    raw_precision = get('precision')
    return raw_precision
|
238
|
+
|
239
|
+
|
240
|
+
def get_precision() -> int:
    """
    Return the current precision as a number of bits.

    ``'bf16'`` is reported as ``16``. Ints are returned unchanged, and other
    strings are converted with ``int()`` (so ``'32'`` gives ``32``).

    Returns
    -------
    precision: int
        The current precision in bits.

    Raises
    ------
    ValueError
        If the stored precision is neither an int nor a string, or is a string
        that ``int()`` cannot parse.
    """
    raw = get('precision')
    if raw == 'bf16':
        return 16
    if isinstance(raw, int):
        return raw
    if not isinstance(raw, str):
        raise ValueError(f'Unsupported precision: {raw}')
    return int(raw)
|
257
|
+
|
258
|
+
|
259
|
+
def set(
    platform: str = None,
    host_device_count: int = None,
    precision: int | str = None,
    mode: Mode = None,
    **kwargs
):
    """
    Set the global default computation environment.

    ``platform`` and ``host_device_count`` are forwarded to :func:`set_platform`
    and :func:`set_host_device_count`. Every other value, including
    ``precision`` and ``mode``, is stored as a global setting. Any behavior
    registered with :func:`register_default_behavior` for one of those keys is
    then called with the new value.

    Args:
      platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
      host_device_count: int. The number of host devices.
      precision: int, str. The default precision. It is also passed to
        ``_set_jax_precision`` to configure JAX.
      mode: Mode. The computing mode. Must be a ``Mode`` instance.
      **kwargs: dict. Other environment settings.
    """
    if platform is not None:
        set_platform(platform)
    if host_device_count is not None:
        set_host_device_count(host_device_count)

    updates = dict(kwargs)
    if precision is not None:
        # Configure JAX before the value is recorded.
        _set_jax_precision(precision)
        updates['precision'] = precision
    if mode is not None:
        assert isinstance(mode, Mode), 'mode must be a Mode instance.'
        updates['mode'] = mode

    # Record the new global defaults.
    DFAULT.settings.update(updates)

    # Fire the behaviors registered for any of the updated keys.
    for name in updates:
        if name in DFAULT.functions:
            DFAULT.functions[name](updates[name])
|
296
|
+
|
297
|
+
|
298
|
+
def set_host_device_count(n):
    """
    Tell XLA to expose ``n`` host (CPU) devices.

    By default, XLA considers all CPU cores as one device. This utility tells
    XLA that there are ``n`` host (CPU) devices available to use, which allows
    parallel mapping such as :func:`jax.pmap` to work on the CPU platform.

    .. note:: This utility only takes effect at the beginning of your program.
       Under the hood, it sets the environment variable
       ``XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]``, where
       ``[num_devices]`` is ``n``. Any previous value of that flag is replaced;
       the other flags already in ``XLA_FLAGS`` are kept after it.

    .. warning:: Our understanding of the side effects of using the
       ``xla_force_host_platform_device_count`` flag in XLA is incomplete. If you
       observe some strange phenomenon when using this utility, please let us
       know through our issue or forum page. More information is available in this
       `JAX issue <https://github.com/google/jax/issues/1408>`_.

    :param int n: number of devices to use.
    """
    # Drop any existing device-count flag, keep the rest, and put the new flag first.
    remaining = re.sub(
        r"--xla_force_host_platform_device_count=\S+", "", os.getenv("XLA_FLAGS", "")
    ).split()
    new_flag = "--xla_force_host_platform_device_count={}".format(n)
    os.environ["XLA_FLAGS"] = " ".join([new_flag, *remaining])

    # notify any behavior registered for this key
    if 'host_device_count' in DFAULT.functions:
        DFAULT.functions['host_device_count'](n)
|
324
|
+
|
325
|
+
|
326
|
+
def set_platform(platform: str):
    """
    Change the JAX computing platform to CPU, GPU, or TPU.

    This utility only takes effect at the beginning of your program.

    Args:
      platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.

    Raises:
      ValueError: If the platform is not in ['cpu', 'gpu', 'tpu'].
    """
    # Validate explicitly instead of with ``assert``. The documented contract is
    # a ValueError, and assertions are stripped when Python runs with ``-O``.
    if platform not in ('cpu', 'gpu', 'tpu'):
        raise ValueError(
            f"platform must be one of 'cpu', 'gpu' or 'tpu', but got {platform!r}."
        )
    config.update("jax_platform_name", platform)

    # notify any behavior registered for the 'platform' key
    if 'platform' in DFAULT.functions:
        DFAULT.functions['platform'](platform)
|
343
|
+
|
344
|
+
|
345
|
+
def _set_jax_precision(precision: int | str):
    """
    Toggle JAX's 64-bit mode to match ``precision``.

    ``jax_enable_x64`` is switched on for ``64`` / ``'64'`` and off for every
    other value. No validation of ``precision`` is performed here.

    Args:
      precision: int, str. The requested precision.
    """
    use_x64 = precision in [64, '64']
    config.update("jax_enable_x64", use_x64)
|
357
|
+
|
358
|
+
|
359
|
+
@functools.lru_cache()
|
360
|
+
def _get_uint(precision: int):
|
361
|
+
if precision in [64, '64']:
|
362
|
+
return np.uint64
|
363
|
+
elif precision in [32, '32']:
|
364
|
+
return np.uint32
|
365
|
+
elif precision in [16, '16', 'bf16']:
|
366
|
+
return np.uint16
|
367
|
+
elif precision in [8, '8']:
|
368
|
+
return np.uint8
|
369
|
+
else:
|
370
|
+
raise ValueError(f'Unsupported precision: {precision}')
|
371
|
+
|
372
|
+
|
373
|
+
@functools.lru_cache()
|
374
|
+
def _get_int(precision: int):
|
375
|
+
if precision in [64, '64']:
|
376
|
+
return np.int64
|
377
|
+
elif precision in [32, '32']:
|
378
|
+
return np.int32
|
379
|
+
elif precision in [16, '16', 'bf16']:
|
380
|
+
return np.int16
|
381
|
+
elif precision in [8, '8']:
|
382
|
+
return np.int8
|
383
|
+
else:
|
384
|
+
raise ValueError(f'Unsupported precision: {precision}')
|
385
|
+
|
386
|
+
|
387
|
+
@functools.lru_cache()
|
388
|
+
def _get_float(precision: int):
|
389
|
+
if precision in [64, '64']:
|
390
|
+
return np.float64
|
391
|
+
elif precision in [32, '32']:
|
392
|
+
return np.float32
|
393
|
+
elif precision in [16, '16']:
|
394
|
+
return np.float16
|
395
|
+
elif precision in ['bf16']:
|
396
|
+
return jnp.bfloat16
|
397
|
+
elif precision in [8, '8']:
|
398
|
+
return jnp.float8_e5m2
|
399
|
+
else:
|
400
|
+
raise ValueError(f'Unsupported precision: {precision}')
|
401
|
+
|
402
|
+
|
403
|
+
@functools.lru_cache()
|
404
|
+
def _get_complex(precision: int):
|
405
|
+
if precision in [64, '64']:
|
406
|
+
return np.complex128
|
407
|
+
elif precision in [32, '32']:
|
408
|
+
return np.complex64
|
409
|
+
elif precision in [16, '16', 'bf16']:
|
410
|
+
return np.complex64
|
411
|
+
elif precision in [8, '8']:
|
412
|
+
return np.complex64
|
413
|
+
else:
|
414
|
+
raise ValueError(f'Unsupported precision: {precision}')
|
415
|
+
|
416
|
+
|
417
|
+
def dftype() -> DTypeLike:
    """
    Default floating data type.

    This function returns the default floating data type based on the current precision.
    If you want the dtype of your data to follow the precision set by
    ``brainstate.environ.set(precision=...)``, use this function to get the default
    floating data type and create the data with ``dtype=dftype()``.

    For example, if the precision is set to 32, the default floating data type is ``np.float32``.

    >>> import brainstate as brainstate
    >>> import numpy as np
    >>> with brainstate.environ.context(precision=32):
    ...     a = np.zeros(1, dtype=brainstate.environ.dftype())
    >>> print(a.dtype)
    float32

    Returns
    -------
    float_dtype: DTypeLike
        The default floating data type: float64, float32, float16, bfloat16 or
        float8_e5m2 for precision 64, 32, 16, 'bf16' or 8 respectively.
    """
    return _get_float(_get_precision())
|
439
|
+
|
440
|
+
|
441
|
+
def ditype() -> DTypeLike:
    """
    Default integer data type.

    Returns the signed integer dtype whose width matches the current precision
    (``'bf16'`` counts as 16 bits). Use ``dtype=ditype()`` when creating data
    whose dtype should follow ``brainstate.environ.set(precision=...)``.

    For example, if the precision is set to 32, the default integer data type is ``np.int32``.

    >>> import brainstate as brainstate
    >>> import numpy as np
    >>> with brainstate.environ.context(precision=32):
    ...     a = np.zeros(1, dtype=brainstate.environ.ditype())
    >>> print(a.dtype)
    int32

    Returns
    -------
    int_dtype: DTypeLike
        The default integer data type.
    """
    current = _get_precision()
    return _get_int(current)
|
464
|
+
|
465
|
+
|
466
|
+
def dutype() -> DTypeLike:
    """
    Default unsigned integer data type.

    Returns the unsigned integer dtype whose width matches the current precision
    (``'bf16'`` counts as 16 bits). Use ``dtype=dutype()`` when creating data
    whose dtype should follow ``brainstate.environ.set(precision=...)``.

    For example, if the precision is set to 32, the default unsigned integer data type is ``np.uint32``.

    >>> import brainstate as brainstate
    >>> import numpy as np
    >>> with brainstate.environ.context(precision=32):
    ...     a = np.zeros(1, dtype=brainstate.environ.dutype())
    >>> print(a.dtype)
    uint32

    Returns
    -------
    uint_dtype: DTypeLike
        The default unsigned integer data type.
    """
    current = _get_precision()
    return _get_uint(current)
|
490
|
+
|
491
|
+
|
492
|
+
def dctype() -> DTypeLike:
    """
    Default complex data type.

    Returns ``np.complex128`` when the current precision is 64 and
    ``np.complex64`` for every narrower supported precision. Use
    ``dtype=dctype()`` when creating data whose dtype should follow
    ``brainstate.environ.set(precision=...)``.

    For example, if the precision is set to 32, the default complex data type is ``np.complex64``.

    >>> import brainstate as brainstate
    >>> import numpy as np
    >>> with brainstate.environ.context(precision=32):
    ...     a = np.zeros(1, dtype=brainstate.environ.dctype())
    >>> print(a.dtype)
    complex64

    Returns
    -------
    complex_dtype: DTypeLike
        The default complex data type.
    """
    current = _get_precision()
    return _get_complex(current)
|
515
|
+
|
516
|
+
|
517
|
+
def tolerance():
    """
    Return a default numerical tolerance matching the current precision.

    Returns
    -------
    tol: jax.Array
        A scalar array:

        - ``1e-12`` as float64 for 64-bit precision;
        - ``1e-5`` as float32 for 32-bit precision;
        - ``1e-2`` as float16 otherwise (16-bit, ``'bf16'`` and 8-bit).
    """
    # Look the precision up once instead of once per branch.
    precision = get_precision()
    if precision == 64:
        return jnp.array(1e-12, dtype=np.float64)
    if precision == 32:
        return jnp.array(1e-5, dtype=np.float32)
    return jnp.array(1e-2, dtype=np.float16)
|
524
|
+
|
525
|
+
|
526
|
+
def register_default_behavior(key: str, behavior: Callable, replace_if_exist: bool = False):
    """
    Register a callback invoked whenever the given global key parameter changes.

    For example, you can register a default behavior for the key 'dt' by::

      >>> import brainstate as brainstate
      >>> def dt_behavior(dt):
      ...     print(f'Set the default dt to {dt}.')
      ...
      >>> brainstate.environ.register_default_behavior('dt', dt_behavior)

    Then, when you set the default dt with ``brainstate.environ.set(dt=0.1)``,
    ``dt_behavior(0.1)`` is called. Entering a ``context(dt=...)`` calls it with
    the temporary value. Leaving the context calls it again with the value that
    becomes current.

      >>> brainstate.environ.set(dt=0.1)
      Set the default dt to 0.1.
      >>> with brainstate.environ.context(dt=0.2):
      ...     pass
      Set the default dt to 0.2.
      Set the default dt to 0.1.

    Args:
      key: str. The key to register.
      behavior: Callable. Called with the new value of ``key``.
      replace_if_exist: bool. Whether to replace a behavior already registered
        for ``key``. If False, registering the same key twice fails.
    """
    assert isinstance(key, str), 'key must be a string.'
    assert callable(behavior), 'behavior must be a callable.'
    # Short-circuits: the registry is only consulted when replacement is not allowed.
    assert replace_if_exist or key not in DFAULT.functions, f'{key} has been registered.'
    DFAULT.functions[key] = behavior
|
561
|
+
|
562
|
+
|
563
|
+
# Initialize the global default precision to 32 bits when this module is imported.
# Going through ``set`` also calls ``_set_jax_precision(32)``, which turns JAX's
# ``jax_enable_x64`` flag off.
set(precision=32)
|