brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
brainstate/environ.py
CHANGED
@@ -1,29 +1,47 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
1
16
|
# -*- coding: utf-8 -*-
|
2
17
|
|
18
|
+
from __future__ import annotations
|
3
19
|
|
4
20
|
import contextlib
|
21
|
+
import dataclasses
|
5
22
|
import functools
|
6
23
|
import os
|
7
24
|
import re
|
25
|
+
import threading
|
8
26
|
from collections import defaultdict
|
9
|
-
from typing import Any, Callable
|
27
|
+
from typing import Any, Callable, Dict, Hashable
|
10
28
|
|
11
29
|
import numpy as np
|
12
30
|
from jax import config, devices, numpy as jnp
|
13
|
-
from jax.
|
31
|
+
from jax.typing import DTypeLike
|
14
32
|
|
15
33
|
from .mixin import Mode
|
16
|
-
from .util import MemScaling
|
34
|
+
from .util import MemScaling
|
17
35
|
|
18
36
|
__all__ = [
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
37
|
+
# functions for environment settings
|
38
|
+
'set', 'context', 'get', 'all', 'set_host_device_count', 'set_platform',
|
39
|
+
# functions for getting default behaviors
|
40
|
+
'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
|
41
|
+
# functions for default data types
|
42
|
+
'dftype', 'ditype', 'dutype', 'dctype',
|
43
|
+
# others
|
44
|
+
'tolerance', 'register_default_behavior',
|
27
45
|
]
|
28
46
|
|
29
47
|
# Default, there are several shared arguments in the global context.
|
@@ -32,196 +50,205 @@ T = 't' # the current time of the current computation.
|
|
32
50
|
JIT_ERROR_CHECK = 'jit_error_check' # whether to record the current computation.
|
33
51
|
FIT = 'fit' # whether to fit the model.
|
34
52
|
|
53
|
+
|
54
|
+
@dataclasses.dataclass
|
55
|
+
class DefaultContext(threading.local):
|
56
|
+
# default environment settings
|
57
|
+
settings: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
|
58
|
+
# current environment settings
|
59
|
+
contexts: defaultdict[Hashable, Any] = dataclasses.field(default_factory=lambda: defaultdict(list))
|
60
|
+
# environment functions
|
61
|
+
functions: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
|
62
|
+
|
63
|
+
|
64
|
+
DFAULT = DefaultContext()
|
35
65
|
_NOT_PROVIDE = object()
|
36
|
-
_environment_defaults = dict() # default environment settings
|
37
|
-
_environment_contexts = defaultdict(list) # current environment settings
|
38
|
-
_environment_functions = dict() # environment functions
|
39
66
|
|
40
67
|
|
41
68
|
@contextlib.contextmanager
|
42
69
|
def context(**kwargs):
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
if 'precision' in kwargs:
|
70
|
-
last_precision = get_precision()
|
71
|
-
_set_jax_precision(kwargs['precision'])
|
72
|
-
|
73
|
-
try:
|
74
|
-
for k, v in kwargs.items():
|
70
|
+
r"""
|
71
|
+
Context-manager that sets a computing environment for brain dynamics computation.
|
72
|
+
|
73
|
+
In BrainPy, there are several basic computation settings when constructing models,
|
74
|
+
including ``mode`` for controlling model computing behavior, ``dt`` for numerical
|
75
|
+
integration, ``int_`` for integer precision, and ``float_`` for floating precision.
|
76
|
+
:py:class:`~.environment`` provides a context for model construction and
|
77
|
+
computation. In this temporal environment, models are constructed with the given
|
78
|
+
``mode``, ``dt``, ``int_``, etc., environment settings.
|
79
|
+
|
80
|
+
For instance::
|
81
|
+
|
82
|
+
>>> import brainstate as bst
|
83
|
+
>>> with bst.environ.context(dt=0.1) as env:
|
84
|
+
... dt = bst.environ.get('dt')
|
85
|
+
... print(env)
|
86
|
+
|
87
|
+
"""
|
88
|
+
if 'platform' in kwargs:
|
89
|
+
raise ValueError('\n'
|
90
|
+
'Cannot set platform in "context" environment. \n'
|
91
|
+
'You should set platform in the global environment by "set_platform()" or "set()".')
|
92
|
+
if 'host_device_count' in kwargs:
|
93
|
+
raise ValueError('Cannot set host_device_count in environment context. '
|
94
|
+
'Please use set_host_device_count() or set() for the global setting.')
|
75
95
|
|
76
|
-
|
77
|
-
|
96
|
+
if 'precision' in kwargs:
|
97
|
+
last_precision = get_precision()
|
98
|
+
_set_jax_precision(kwargs['precision'])
|
78
99
|
|
79
|
-
|
80
|
-
|
81
|
-
_environment_functions[k](v)
|
100
|
+
try:
|
101
|
+
for k, v in kwargs.items():
|
82
102
|
|
83
|
-
|
84
|
-
|
85
|
-
finally:
|
103
|
+
# update the current environment
|
104
|
+
DFAULT.contexts[k].append(v)
|
86
105
|
|
87
|
-
|
106
|
+
# restore the environment functions
|
107
|
+
if k in DFAULT.functions:
|
108
|
+
DFAULT.functions[k](v)
|
88
109
|
|
89
|
-
|
90
|
-
|
110
|
+
# yield the current all environment information
|
111
|
+
yield all()
|
112
|
+
finally:
|
91
113
|
|
92
|
-
|
93
|
-
if k in _environment_functions:
|
94
|
-
_environment_functions[k](get(k))
|
114
|
+
for k, v in kwargs.items():
|
95
115
|
|
96
|
-
|
97
|
-
|
116
|
+
# restore the current environment
|
117
|
+
DFAULT.contexts[k].pop()
|
118
|
+
|
119
|
+
# restore the environment functions
|
120
|
+
if k in DFAULT.functions:
|
121
|
+
DFAULT.functions[k](get(k))
|
122
|
+
|
123
|
+
if 'precision' in kwargs:
|
124
|
+
_set_jax_precision(last_precision)
|
98
125
|
|
99
126
|
|
100
127
|
def get(key: str, default: Any = _NOT_PROVIDE, desc: str = None):
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
128
|
+
"""
|
129
|
+
Get one of the default computation environment.
|
130
|
+
|
131
|
+
Returns
|
132
|
+
-------
|
133
|
+
item: Any
|
134
|
+
The default computation environment.
|
135
|
+
"""
|
136
|
+
if key == 'platform':
|
137
|
+
return get_platform()
|
138
|
+
|
139
|
+
if key == 'host_device_count':
|
140
|
+
return get_host_device_count()
|
141
|
+
|
142
|
+
if key in DFAULT.contexts:
|
143
|
+
if len(DFAULT.contexts[key]) > 0:
|
144
|
+
return DFAULT.contexts[key][-1]
|
145
|
+
if key in DFAULT.settings:
|
146
|
+
return DFAULT.settings[key]
|
147
|
+
|
148
|
+
if default is _NOT_PROVIDE:
|
149
|
+
if desc is not None:
|
150
|
+
raise KeyError(
|
151
|
+
f"'{key}' is not found in the context. \n"
|
152
|
+
f"You can set it by `brainstate.share.context({key}=value)` "
|
153
|
+
f"locally or `brainstate.share.set({key}=value)` globally. \n"
|
154
|
+
f"Description: {desc}"
|
155
|
+
)
|
156
|
+
else:
|
157
|
+
raise KeyError(
|
158
|
+
f"'{key}' is not found in the context. \n"
|
159
|
+
f"You can set it by `brainstate.share.context({key}=value)` "
|
160
|
+
f"locally or `brainstate.share.set({key}=value)` globally."
|
161
|
+
)
|
162
|
+
return default
|
136
163
|
|
137
164
|
|
138
165
|
def all() -> dict:
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
166
|
+
"""
|
167
|
+
Get all the current default computation environment.
|
168
|
+
|
169
|
+
Returns
|
170
|
+
-------
|
171
|
+
r: dict
|
172
|
+
The current default computation environment.
|
173
|
+
"""
|
174
|
+
r = dict()
|
175
|
+
for k, v in DFAULT.contexts.items():
|
176
|
+
if v:
|
177
|
+
r[k] = v[-1]
|
178
|
+
for k, v in DFAULT.settings.items():
|
179
|
+
if k not in r:
|
180
|
+
r[k] = v
|
181
|
+
return r
|
155
182
|
|
156
183
|
|
157
184
|
def get_dt():
|
158
|
-
|
185
|
+
"""Get the numerical integrator precision.
|
159
186
|
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
187
|
+
Returns
|
188
|
+
-------
|
189
|
+
dt : float
|
190
|
+
Numerical integration precision.
|
191
|
+
"""
|
192
|
+
return get('dt')
|
166
193
|
|
167
194
|
|
168
195
|
def get_mode() -> Mode:
|
169
|
-
|
196
|
+
"""Get the default computing mode.
|
170
197
|
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
198
|
+
References
|
199
|
+
----------
|
200
|
+
mode: Mode
|
201
|
+
The default computing mode.
|
202
|
+
"""
|
203
|
+
return get('mode')
|
177
204
|
|
178
205
|
|
179
206
|
def get_mem_scaling() -> MemScaling:
|
180
|
-
|
207
|
+
"""Get the default computing membrane_scaling.
|
181
208
|
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
209
|
+
Returns
|
210
|
+
-------
|
211
|
+
membrane_scaling: MemScaling
|
212
|
+
The default computing membrane_scaling.
|
213
|
+
"""
|
214
|
+
return get('mem_scaling')
|
188
215
|
|
189
216
|
|
190
217
|
def get_platform() -> str:
|
191
|
-
|
218
|
+
"""Get the computing platform.
|
192
219
|
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
220
|
+
Returns
|
221
|
+
-------
|
222
|
+
platform: str
|
223
|
+
Either 'cpu', 'gpu' or 'tpu'.
|
224
|
+
"""
|
225
|
+
return devices()[0].platform
|
199
226
|
|
200
227
|
|
201
228
|
def get_host_device_count():
|
202
|
-
|
203
|
-
|
229
|
+
"""
|
230
|
+
Get the number of host devices.
|
204
231
|
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
232
|
+
Returns
|
233
|
+
-------
|
234
|
+
n: int
|
235
|
+
The number of host devices.
|
236
|
+
"""
|
237
|
+
xla_flags = os.getenv("XLA_FLAGS", "")
|
238
|
+
match = re.search(r"--xla_force_host_platform_device_count=(\d+)", xla_flags)
|
239
|
+
return int(match.group(1)) if match else 1
|
213
240
|
|
214
241
|
|
215
242
|
def get_precision() -> int:
|
216
|
-
|
217
|
-
|
243
|
+
"""
|
244
|
+
Get the default precision.
|
218
245
|
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
246
|
+
Returns
|
247
|
+
-------
|
248
|
+
precision: int
|
249
|
+
The default precision.
|
250
|
+
"""
|
251
|
+
return get('precision')
|
225
252
|
|
226
253
|
|
227
254
|
def set(
|
@@ -232,300 +259,300 @@ def set(
|
|
232
259
|
mode: Mode = None,
|
233
260
|
**kwargs
|
234
261
|
):
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
262
|
+
"""
|
263
|
+
Set the global default computation environment.
|
264
|
+
|
265
|
+
|
266
|
+
|
267
|
+
Args:
|
268
|
+
platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
|
269
|
+
host_device_count: int. The number of host devices.
|
270
|
+
mem_scaling: MemScaling. The membrane scaling.
|
271
|
+
precision: int. The default precision.
|
272
|
+
mode: Mode. The computing mode.
|
273
|
+
**kwargs: dict. Other environment settings.
|
274
|
+
"""
|
275
|
+
if platform is not None:
|
276
|
+
set_platform(platform)
|
277
|
+
if host_device_count is not None:
|
278
|
+
set_host_device_count(host_device_count)
|
279
|
+
if mem_scaling is not None:
|
280
|
+
assert isinstance(mem_scaling, MemScaling), 'mem_scaling must be a MemScaling instance.'
|
281
|
+
kwargs['mem_scaling'] = mem_scaling
|
282
|
+
if precision is not None:
|
283
|
+
_set_jax_precision(precision)
|
284
|
+
kwargs['precision'] = precision
|
285
|
+
if mode is not None:
|
286
|
+
assert isinstance(mode, Mode), 'mode must be a Mode instance.'
|
287
|
+
kwargs['mode'] = mode
|
288
|
+
|
289
|
+
# set default environment
|
290
|
+
DFAULT.settings.update(kwargs)
|
291
|
+
|
292
|
+
# update the environment functions
|
293
|
+
for k, v in kwargs.items():
|
294
|
+
if k in DFAULT.functions:
|
295
|
+
DFAULT.functions[k](v)
|
269
296
|
|
270
297
|
|
271
298
|
def set_host_device_count(n):
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
299
|
+
"""
|
300
|
+
By default, XLA considers all CPU cores as one device. This utility tells XLA
|
301
|
+
that there are `n` host (CPU) devices available to use. As a consequence, this
|
302
|
+
allows parallel mapping in JAX :func:`jax.pmap` to work in CPU platform.
|
276
303
|
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
304
|
+
.. note:: This utility only takes effect at the beginning of your program.
|
305
|
+
Under the hood, this sets the environment variable
|
306
|
+
`XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
|
307
|
+
`[num_device]` is the desired number of CPU devices `n`.
|
281
308
|
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
309
|
+
.. warning:: Our understanding of the side effects of using the
|
310
|
+
`xla_force_host_platform_device_count` flag in XLA is incomplete. If you
|
311
|
+
observe some strange phenomenon when using this utility, please let us
|
312
|
+
know through our issue or forum page. More information is available in this
|
313
|
+
`JAX issue <https://github.com/google/jax/issues/1408>`_.
|
287
314
|
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
315
|
+
:param int n: number of devices to use.
|
316
|
+
"""
|
317
|
+
xla_flags = os.getenv("XLA_FLAGS", "")
|
318
|
+
xla_flags = re.sub(r"--xla_force_host_platform_device_count=\S+", "", xla_flags).split()
|
319
|
+
os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(n)] + xla_flags)
|
293
320
|
|
294
|
-
|
295
|
-
|
296
|
-
|
321
|
+
# update the environment functions
|
322
|
+
if 'host_device_count' in DFAULT.functions:
|
323
|
+
DFAULT.functions['host_device_count'](n)
|
297
324
|
|
298
325
|
|
299
326
|
def set_platform(platform: str):
|
300
|
-
|
301
|
-
|
302
|
-
|
327
|
+
"""
|
328
|
+
Changes platform to CPU, GPU, or TPU. This utility only takes
|
329
|
+
effect at the beginning of your program.
|
303
330
|
|
304
|
-
|
305
|
-
|
331
|
+
Args:
|
332
|
+
platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
|
306
333
|
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
334
|
+
Raises:
|
335
|
+
ValueError: If the platform is not in ['cpu', 'gpu', 'tpu'].
|
336
|
+
"""
|
337
|
+
assert platform in ['cpu', 'gpu', 'tpu']
|
338
|
+
config.update("jax_platform_name", platform)
|
312
339
|
|
313
|
-
|
314
|
-
|
315
|
-
|
340
|
+
# update the environment functions
|
341
|
+
if 'platform' in DFAULT.functions:
|
342
|
+
DFAULT.functions['platform'](platform)
|
316
343
|
|
317
344
|
|
318
345
|
def _set_jax_precision(precision: int):
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
346
|
+
"""
|
347
|
+
Set the default precision.
|
348
|
+
|
349
|
+
Args:
|
350
|
+
precision: int. The default precision.
|
351
|
+
"""
|
352
|
+
assert precision in [64, 32, 16, 8], f'Precision must be in [64, 32, 16, 8]. But got {precision}.'
|
353
|
+
if precision == 64:
|
354
|
+
config.update("jax_enable_x64", True)
|
355
|
+
else:
|
356
|
+
config.update("jax_enable_x64", False)
|
330
357
|
|
331
358
|
|
332
359
|
@functools.lru_cache()
|
333
360
|
def _get_uint(precision: int):
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
361
|
+
if precision == 64:
|
362
|
+
return np.uint64
|
363
|
+
elif precision == 32:
|
364
|
+
return np.uint32
|
365
|
+
elif precision == 16:
|
366
|
+
return np.uint16
|
367
|
+
elif precision == 8:
|
368
|
+
return np.uint8
|
369
|
+
else:
|
370
|
+
raise ValueError(f'Unsupported precision: {precision}')
|
344
371
|
|
345
372
|
|
346
373
|
@functools.lru_cache()
|
347
374
|
def _get_int(precision: int):
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
375
|
+
if precision == 64:
|
376
|
+
return np.int64
|
377
|
+
elif precision == 32:
|
378
|
+
return np.int32
|
379
|
+
elif precision == 16:
|
380
|
+
return np.int16
|
381
|
+
elif precision == 8:
|
382
|
+
return np.int8
|
383
|
+
else:
|
384
|
+
raise ValueError(f'Unsupported precision: {precision}')
|
358
385
|
|
359
386
|
|
360
387
|
@functools.lru_cache()
|
361
388
|
def _get_float(precision: int):
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
389
|
+
if precision == 64:
|
390
|
+
return np.float64
|
391
|
+
elif precision == 32:
|
392
|
+
return np.float32
|
393
|
+
elif precision == 16:
|
394
|
+
return jnp.bfloat16
|
395
|
+
# return np.float16
|
396
|
+
else:
|
397
|
+
raise ValueError(f'Unsupported precision: {precision}')
|
371
398
|
|
372
399
|
|
373
400
|
@functools.lru_cache()
|
374
401
|
def _get_complex(precision: int):
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
402
|
+
if precision == 64:
|
403
|
+
return np.complex128
|
404
|
+
elif precision == 32:
|
405
|
+
return np.complex64
|
406
|
+
elif precision == 16:
|
407
|
+
return np.complex32
|
408
|
+
else:
|
409
|
+
raise ValueError(f'Unsupported precision: {precision}')
|
383
410
|
|
384
411
|
|
385
412
|
def dftype() -> DTypeLike:
|
386
|
-
|
387
|
-
|
413
|
+
"""
|
414
|
+
Default floating data type.
|
388
415
|
|
389
|
-
|
390
|
-
|
391
|
-
|
416
|
+
This function returns the default floating data type based on the current precision.
|
417
|
+
If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
|
418
|
+
you can use this function to get the default floating data type, and create the data by using ``dtype=dftype()``.
|
392
419
|
|
393
|
-
|
420
|
+
For example, if the precision is set to 32, the default floating data type is ``np.float32``.
|
394
421
|
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
422
|
+
>>> import brainstate as bst
|
423
|
+
>>> import numpy as np
|
424
|
+
>>> with bst.environ.context(precision=32):
|
425
|
+
... a = np.zeros(1, dtype=bst.environ.dftype())
|
426
|
+
>>> print(a.dtype)
|
400
427
|
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
428
|
+
Returns
|
429
|
+
-------
|
430
|
+
float_dtype: DTypeLike
|
431
|
+
The default floating data type.
|
432
|
+
"""
|
433
|
+
return _get_float(get_precision())
|
407
434
|
|
408
435
|
|
409
436
|
def ditype() -> DTypeLike:
|
410
|
-
|
411
|
-
|
437
|
+
"""
|
438
|
+
Default integer data type.
|
412
439
|
|
413
|
-
|
414
|
-
|
415
|
-
|
440
|
+
This function returns the default integer data type based on the current precision.
|
441
|
+
If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
|
442
|
+
you can use this function to get the default integer data type, and create the data by using ``dtype=ditype()``.
|
416
443
|
|
417
|
-
|
444
|
+
For example, if the precision is set to 32, the default integer data type is ``np.int32``.
|
418
445
|
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
446
|
+
>>> import brainstate as bst
|
447
|
+
>>> import numpy as np
|
448
|
+
>>> with bst.environ.context(precision=32):
|
449
|
+
... a = np.zeros(1, dtype=bst.environ.ditype())
|
450
|
+
>>> print(a.dtype)
|
451
|
+
int32
|
425
452
|
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
453
|
+
Returns
|
454
|
+
-------
|
455
|
+
int_dtype: DTypeLike
|
456
|
+
The default integer data type.
|
457
|
+
"""
|
458
|
+
return _get_int(get_precision())
|
432
459
|
|
433
460
|
|
434
461
|
def dutype() -> DTypeLike:
|
435
|
-
|
436
|
-
|
462
|
+
"""
|
463
|
+
Default unsigned integer data type.
|
437
464
|
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
465
|
+
This function returns the default unsigned integer data type based on the current precision.
|
466
|
+
If you want the data dtype is changed with the setting of the precision
|
467
|
+
by ``brainstate.environ.set(precision)``, you can use this function to get the default
|
468
|
+
unsigned integer data type, and create the data by using ``dtype=dutype()``.
|
442
469
|
|
443
|
-
|
470
|
+
For example, if the precision is set to 32, the default unsigned integer data type is ``np.uint32``.
|
444
471
|
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
472
|
+
>>> import brainstate as bst
|
473
|
+
>>> import numpy as np
|
474
|
+
>>> with bst.environ.context(precision=32):
|
475
|
+
... a = np.zeros(1, dtype=bst.environ.dutype())
|
476
|
+
>>> print(a.dtype)
|
477
|
+
uint32
|
451
478
|
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
479
|
+
Returns
|
480
|
+
-------
|
481
|
+
uint_dtype: DTypeLike
|
482
|
+
The default unsigned integer data type.
|
483
|
+
"""
|
484
|
+
return _get_uint(get_precision())
|
458
485
|
|
459
486
|
|
460
487
|
def dctype() -> DTypeLike:
|
461
|
-
|
462
|
-
|
488
|
+
"""
|
489
|
+
Default complex data type.
|
463
490
|
|
464
|
-
|
465
|
-
|
466
|
-
|
491
|
+
This function returns the default complex data type based on the current precision.
|
492
|
+
If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
|
493
|
+
you can use this function to get the default complex data type, and create the data by using ``dtype=dctype()``.
|
467
494
|
|
468
|
-
|
495
|
+
For example, if the precision is set to 32, the default complex data type is ``np.complex64``.
|
469
496
|
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
497
|
+
>>> import brainstate as bst
|
498
|
+
>>> import numpy as np
|
499
|
+
>>> with bst.environ.context(precision=32):
|
500
|
+
... a = np.zeros(1, dtype=bst.environ.dctype())
|
501
|
+
>>> print(a.dtype)
|
502
|
+
complex64
|
476
503
|
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
504
|
+
Returns
|
505
|
+
-------
|
506
|
+
complex_dtype: DTypeLike
|
507
|
+
The default complex data type.
|
508
|
+
"""
|
509
|
+
return _get_complex(get_precision())
|
483
510
|
|
484
511
|
|
485
512
|
def tolerance():
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
513
|
+
if get_precision() == 64:
|
514
|
+
return jnp.array(1e-12, dtype=np.float64)
|
515
|
+
elif get_precision() == 32:
|
516
|
+
return jnp.array(1e-5, dtype=np.float32)
|
517
|
+
else:
|
518
|
+
return jnp.array(1e-2, dtype=np.float16)
|
492
519
|
|
493
520
|
|
494
521
|
def register_default_behavior(key: str, behavior: Callable, replace_if_exist: bool = False):
|
495
|
-
|
496
|
-
|
522
|
+
"""
|
523
|
+
Register a default behavior for a specific global key parameter.
|
497
524
|
|
498
|
-
|
525
|
+
For example, you can register a default behavior for the key 'dt' by::
|
499
526
|
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
527
|
+
>>> import brainstate as bst
|
528
|
+
>>> def dt_behavior(dt):
|
529
|
+
... print(f'Set the default dt to {dt}.')
|
530
|
+
...
|
531
|
+
>>> bst.environ.register_default_behavior('dt', dt_behavior)
|
505
532
|
|
506
|
-
|
507
|
-
|
508
|
-
|
533
|
+
Then, when you set the default dt by `brainstate.environ.set(dt=0.1)`, the behavior
|
534
|
+
`dt_behavior` will be called with
|
535
|
+
`dt_behavior(0.1)`.
|
509
536
|
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
537
|
+
>>> bst.environ.set(dt=0.1)
|
538
|
+
Set the default dt to 0.1.
|
539
|
+
>>> with bst.environ.context(dt=0.2):
|
540
|
+
... pass
|
541
|
+
Set the default dt to 0.2.
|
542
|
+
Set the default dt to 0.1.
|
516
543
|
|
517
544
|
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
545
|
+
Args:
|
546
|
+
key: str. The key to register.
|
547
|
+
behavior: Callable. The behavior to register. It should be a callable.
|
548
|
+
replace_if_exist: bool. Whether to replace the behavior if the key has been registered.
|
522
549
|
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
550
|
+
"""
|
551
|
+
assert isinstance(key, str), 'key must be a string.'
|
552
|
+
assert callable(behavior), 'behavior must be a callable.'
|
553
|
+
if not replace_if_exist:
|
554
|
+
assert key not in DFAULT.functions, f'{key} has been registered.'
|
555
|
+
DFAULT.functions[key] = behavior
|
529
556
|
|
530
557
|
|
531
|
-
set(
|
558
|
+
set(precision=32)
|