brainstate 0.0.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +45 -0
- brainstate/_module.py +1466 -0
- brainstate/_module_test.py +133 -0
- brainstate/_state.py +378 -0
- brainstate/_state_test.py +41 -0
- brainstate/_utils.py +21 -0
- brainstate/environ.py +375 -0
- brainstate/functional/__init__.py +25 -0
- brainstate/functional/_activations.py +754 -0
- brainstate/functional/_normalization.py +69 -0
- brainstate/functional/_spikes.py +90 -0
- brainstate/init/__init__.py +26 -0
- brainstate/init/_base.py +36 -0
- brainstate/init/_generic.py +175 -0
- brainstate/init/_random_inits.py +489 -0
- brainstate/init/_regular_inits.py +109 -0
- brainstate/math/__init__.py +21 -0
- brainstate/math/_einops.py +787 -0
- brainstate/math/_einops_parsing.py +169 -0
- brainstate/math/_einops_parsing_test.py +126 -0
- brainstate/math/_einops_test.py +346 -0
- brainstate/math/_misc.py +298 -0
- brainstate/math/_misc_test.py +58 -0
- brainstate/mixin.py +373 -0
- brainstate/mixin_test.py +73 -0
- brainstate/nn/__init__.py +68 -0
- brainstate/nn/_base.py +248 -0
- brainstate/nn/_connections.py +686 -0
- brainstate/nn/_dynamics.py +406 -0
- brainstate/nn/_elementwise.py +1437 -0
- brainstate/nn/_misc.py +132 -0
- brainstate/nn/_normalizations.py +389 -0
- brainstate/nn/_others.py +100 -0
- brainstate/nn/_poolings.py +1228 -0
- brainstate/nn/_poolings_test.py +231 -0
- brainstate/nn/_projection/__init__.py +32 -0
- brainstate/nn/_projection/_align_post.py +528 -0
- brainstate/nn/_projection/_align_pre.py +599 -0
- brainstate/nn/_projection/_delta.py +241 -0
- brainstate/nn/_projection/_utils.py +17 -0
- brainstate/nn/_projection/_vanilla.py +101 -0
- brainstate/nn/_rate_rnns.py +393 -0
- brainstate/nn/_readout.py +130 -0
- brainstate/nn/_synouts.py +166 -0
- brainstate/nn/functional/__init__.py +25 -0
- brainstate/nn/functional/_activations.py +754 -0
- brainstate/nn/functional/_normalization.py +69 -0
- brainstate/nn/functional/_spikes.py +90 -0
- brainstate/nn/init/__init__.py +26 -0
- brainstate/nn/init/_base.py +36 -0
- brainstate/nn/init/_generic.py +175 -0
- brainstate/nn/init/_random_inits.py +489 -0
- brainstate/nn/init/_regular_inits.py +109 -0
- brainstate/nn/surrogate.py +1740 -0
- brainstate/optim/__init__.py +23 -0
- brainstate/optim/_lr_scheduler.py +486 -0
- brainstate/optim/_lr_scheduler_test.py +36 -0
- brainstate/optim/_sgd_optimizer.py +1148 -0
- brainstate/random.py +5148 -0
- brainstate/random_test.py +576 -0
- brainstate/surrogate.py +1740 -0
- brainstate/transform/__init__.py +36 -0
- brainstate/transform/_autograd.py +585 -0
- brainstate/transform/_autograd_test.py +1183 -0
- brainstate/transform/_control.py +665 -0
- brainstate/transform/_controls_test.py +220 -0
- brainstate/transform/_jit.py +239 -0
- brainstate/transform/_jit_error.py +158 -0
- brainstate/transform/_jit_test.py +102 -0
- brainstate/transform/_make_jaxpr.py +573 -0
- brainstate/transform/_make_jaxpr_test.py +133 -0
- brainstate/transform/_progress_bar.py +113 -0
- brainstate/typing.py +69 -0
- brainstate/util.py +747 -0
- brainstate-0.0.1.dist-info/LICENSE +202 -0
- brainstate-0.0.1.dist-info/METADATA +101 -0
- brainstate-0.0.1.dist-info/RECORD +79 -0
- brainstate-0.0.1.dist-info/WHEEL +6 -0
- brainstate-0.0.1.dist-info/top_level.txt +1 -0
brainstate/environ.py
ADDED
@@ -0,0 +1,375 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
|
3
|
+
|
4
|
+
import contextlib
|
5
|
+
import functools
|
6
|
+
import os
|
7
|
+
import re
|
8
|
+
from collections import defaultdict
|
9
|
+
from typing import Any
|
10
|
+
|
11
|
+
import numpy as np
|
12
|
+
from jax import config, devices, numpy as jnp
|
13
|
+
from jax._src.typing import DTypeLike
|
14
|
+
|
15
|
+
from .mixin import Mode
|
16
|
+
from .util import MemScaling, IdMemScaling
|
17
|
+
|
18
|
+
# Public API of this module.
# NOTE(review): ``set`` and ``all`` shadow the Python builtins of the same
# name, so ``from <this module> import *`` rebinds ``set``/``all`` in the
# importing namespace. Prefer qualified access (``environ.set(...)``).
__all__ = [
    # global and scoped configuration
    'set', 'context', 'get', 'all',
    # process-wide device settings
    'set_host_device_count', 'set_platform',
    # typed getters
    'get_host_device_count', 'get_platform', 'get_dt', 'get_mode',
    'get_mem_scaling', 'get_precision',
    # precision-dependent helpers
    'tolerance',
    'dftype', 'ditype', 'dutype', 'dctype',
]
|
25
|
+
|
26
|
+
|
27
|
+
# Keys of the arguments that are, by convention, shared through the global
# environment.
I = 'i'  # index of the current computation step
T = 't'  # current time of the current computation
# NOTE(review): originally documented as "whether to record the current
# computation"; the key name suggests it toggles JIT error checking -- confirm.
JIT_ERROR_CHECK = 'jit_error_check'
FIT = 'fit'  # whether the model is being fitted

# Sentinel that lets ``get`` distinguish "no default given" from ``default=None``.
_NOT_PROVIDE = object()
# Global defaults, written by ``set``.
_environment_defaults = {}
# Per-key stacks of values pushed by ``context``; the last element is the active one.
_environment_contexts = defaultdict(list)
|
36
|
+
|
37
|
+
|
38
|
+
@contextlib.contextmanager
def context(**kwargs):
    r"""
    Context manager that temporarily overrides computation-environment settings.

    Each keyword argument is pushed onto a per-key stack when the ``with`` block
    starts and popped again when it ends. Contexts can therefore be nested, and
    :func:`get` returns the innermost value. Passing ``precision`` also switches
    JAX's ``jax_enable_x64`` flag (via ``_set_jax_precision``) for the duration
    of the block. The previous precision is restored on exit.

    ``platform`` and ``host_device_count`` are process-wide settings and cannot
    be changed here. Use :func:`set_platform`, :func:`set_host_device_count` or
    :func:`set` instead.

    For instance::

      >>> import brainstate as bst
      >>> with bst.environ.context(dt=0.1) as env:
      ...     dt = bst.environ.get('dt')
      ...     print(env)

    Yields:
      dict: A snapshot of all current environment values, as built by :func:`all`.

    Raises:
      ValueError: If ``platform`` or ``host_device_count`` is passed.
    """
    if 'platform' in kwargs:
        raise ValueError('Cannot set platform in environment context. '
                         'Please use set_platform() or set() for the global setting.')
    if 'host_device_count' in kwargs:
        raise ValueError('Cannot set host_device_count in environment context. '
                         'Please use set_host_device_count() or set() for the global setting.')

    switches_precision = 'precision' in kwargs
    if switches_precision:
        # Remember the current precision so it can be restored on exit.
        previous_precision = get_precision()
        _set_jax_precision(kwargs['precision'])

    try:
        for name, value in kwargs.items():
            _environment_contexts[name].append(value)
        yield all()
    finally:
        # Undo exactly one push per key, even if the body raised.
        for name in kwargs:
            _environment_contexts[name].pop()
        if switches_precision:
            _set_jax_precision(previous_precision)
|
79
|
+
|
80
|
+
|
81
|
+
def get(key: str, default: Any = _NOT_PROVIDE, desc: str = None):
    """
    Get one setting of the current computation environment.

    ``'platform'`` and ``'host_device_count'`` are computed on demand by
    :func:`get_platform` and :func:`get_host_device_count`. Any other key is
    looked up in the innermost active :func:`context` first, then in the
    global defaults installed by :func:`set`.

    Parameters
    ----------
    key: str
        Name of the setting, e.g. ``'dt'`` or ``'precision'``.
    default: Any
        Value returned when ``key`` is not set anywhere. If omitted, a missing
        key raises ``KeyError``.
    desc: str, optional
        Human-readable description of the setting. It is appended to the
        ``KeyError`` message.

    Returns
    -------
    item: Any
        The value of the setting.

    Raises
    ------
    KeyError
        If ``key`` is not set and no ``default`` was given.
    """
    if key == 'platform':
        return get_platform()
    if key == 'host_device_count':
        return get_host_device_count()

    # ``.get`` (unlike ``[]``) does not make the defaultdict create an empty stack.
    stack = _environment_contexts.get(key)
    if stack:
        # The innermost ``context`` pushed last, so its value wins.
        return stack[-1]
    if key in _environment_defaults:
        return _environment_defaults[key]

    if default is not _NOT_PROVIDE:
        return default

    # The setters live in ``brainstate.environ``; the message used to point to a
    # non-existent ``brainstate.share`` module.
    message = (
        f"'{key}' is not found in the context. \n"
        f"You can set it by `brainstate.environ.context({key}=value)` "
        f"locally or `brainstate.environ.set({key}=value)` globally."
    )
    if desc is not None:
        message += f" \nDescription: {desc}"
    raise KeyError(message)
|
117
|
+
|
118
|
+
|
119
|
+
def all() -> dict:
    """
    Return a snapshot of every environment setting currently in effect.

    For each key, the value pushed by the innermost active :func:`context` is
    used. Keys with no active context fall back to the global default installed
    by :func:`set`. ``'platform'`` and ``'host_device_count'`` are not stored in
    these tables (:func:`get` computes them on demand), so they do not appear here.

    Returns
    -------
    dict
        A newly built dict mapping setting names to their active values.
    """
    # Context values first (in the order their keys were first used) ...
    snapshot = {name: stack[-1] for name, stack in _environment_contexts.items() if stack}
    # ... then any global default not shadowed by a context.
    for name, value in _environment_defaults.items():
        snapshot.setdefault(name, value)
    return snapshot
|
131
|
+
|
132
|
+
|
133
|
+
def get_dt():
    """Return the current numerical-integration time step ``dt``.

    This is a shorthand for ``get('dt')``: a value set by :func:`context` takes
    precedence over the global default set by :func:`set`.

    Returns
    -------
    dt : float
        The active integration time step.
    """
    return get(key='dt')
|
142
|
+
|
143
|
+
|
144
|
+
def get_mode() -> Mode:
    """Return the active computing mode.

    This is a shorthand for ``get('mode')``: a mode set by :func:`context` takes
    precedence over the global default set by :func:`set`.

    Returns
    -------
    mode: Mode
        The active computing mode.
    """
    return get(key='mode')
|
153
|
+
|
154
|
+
|
155
|
+
def get_mem_scaling() -> MemScaling:
    """Return the active membrane-scaling object.

    This is a shorthand for ``get('mem_scaling')``: a value set by
    :func:`context` takes precedence over the global default set by :func:`set`.

    Returns
    -------
    mem_scaling: MemScaling
        The active membrane scaling.
    """
    return get(key='mem_scaling')
|
164
|
+
|
165
|
+
|
166
|
+
def get_platform() -> str:
    """Return the platform JAX is computing on.

    Returns
    -------
    platform: str
        The ``platform`` attribute of the first device reported by
        ``jax.devices()``, e.g. 'cpu', 'gpu' or 'tpu'.
    """
    first_device = devices()[0]
    return first_device.platform
|
175
|
+
|
176
|
+
|
177
|
+
def get_host_device_count():
    """
    Return the number of host (CPU) devices requested through ``XLA_FLAGS``.

    The value is parsed from the ``--xla_force_host_platform_device_count=N``
    entry of the ``XLA_FLAGS`` environment variable.

    Returns
    -------
    n: int
        ``N`` from that flag, or 1 when the flag (or the variable) is absent.
    """
    flags = os.environ.get("XLA_FLAGS", "")
    found = re.search(r"--xla_force_host_platform_device_count=(\d+)", flags)
    if found is None:
        return 1
    return int(found.group(1))
|
189
|
+
|
190
|
+
|
191
|
+
def get_precision() -> int:
    """
    Return the active numerical precision in bits.

    This is a shorthand for ``get('precision')``: a value set by :func:`context`
    takes precedence over the global default set by :func:`set`.

    Returns
    -------
    precision: int
        The active precision (64, 32, 16 or 8 when set through this module).
    """
    return get(key='precision')
|
201
|
+
|
202
|
+
|
203
|
+
def set(
    platform: str = None,
    host_device_count: int = None,
    mem_scaling: MemScaling = None,
    precision: int = None,
    mode: Mode = None,
    **kwargs
):
    """
    Set the global default computation environment.

    Every argument is validated before any global state is touched. A rejected
    value therefore leaves the platform, ``XLA_FLAGS``, JAX precision and stored
    defaults unchanged.

    Args:
      platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
      host_device_count: int. The number of host devices (rewrites ``XLA_FLAGS``).
      mem_scaling: MemScaling. The membrane scaling.
      precision: int. The default precision: 64, 32, 16 or 8. Also toggles
        JAX's ``jax_enable_x64`` flag.
      mode: Mode. The computing mode.
      **kwargs: dict. Other environment settings, e.g. ``dt=0.1``.

    Raises:
      AssertionError: If ``platform``, ``mem_scaling``, ``precision`` or
        ``mode`` has an invalid value.
    """
    # Validate before any side effect, so that a bad ``mem_scaling``/``mode``/
    # ``precision`` cannot leave the environment half-updated.
    if mem_scaling is not None:
        assert isinstance(mem_scaling, MemScaling), 'mem_scaling must be a MemScaling instance.'
    if precision is not None:
        # Same check as ``_set_jax_precision``, performed up front.
        assert precision in [64, 32, 16, 8], f'Precision must be in [64, 32, 16, 8]. But got {precision}.'
    if mode is not None:
        assert isinstance(mode, Mode), 'mode must be a Mode instance.'

    # Apply the process-wide side effects. ``set_platform`` checks ``platform``
    # itself before changing anything, and it is the first side effect.
    if platform is not None:
        set_platform(platform)
    if host_device_count is not None:
        set_host_device_count(host_device_count)
    if precision is not None:
        _set_jax_precision(precision)

    # Record the named settings alongside the free-form ones.
    if mem_scaling is not None:
        kwargs['mem_scaling'] = mem_scaling
    if precision is not None:
        kwargs['precision'] = precision
    if mode is not None:
        kwargs['mode'] = mode
    _environment_defaults.update(kwargs)
|
236
|
+
|
237
|
+
|
238
|
+
def set_host_device_count(n):
    """
    Tell XLA to expose ``n`` host (CPU) devices instead of one.

    By default XLA treats all CPU cores as a single device. Forcing several host
    devices lets parallel mapping such as :func:`jax.pmap` run on the CPU
    platform.

    .. note:: This utility only takes effect at the beginning of your program.
       It rewrites the ``XLA_FLAGS`` environment variable so that it starts
       with ``--xla_force_host_platform_device_count=n``. Any previous value of
       that flag is dropped; all other flags are kept.

    .. warning:: The side effects of the ``xla_force_host_platform_device_count``
       flag are not fully understood. If you observe strange behaviour when
       using this utility, please let us know through our issue tracker or
       forum. More information is available in this
       `JAX issue <https://github.com/google/jax/issues/1408>`_.

    :param int n: number of host devices to expose.
    """
    current = os.getenv("XLA_FLAGS", "")
    # Remove any existing device-count flag, keeping the remaining flags in order.
    remaining = re.sub(r"--xla_force_host_platform_device_count=\S+", "", current).split()
    count_flag = f"--xla_force_host_platform_device_count={n}"
    os.environ["XLA_FLAGS"] = " ".join([count_flag, *remaining])
|
260
|
+
|
261
|
+
|
262
|
+
def set_platform(platform: str):
    """
    Select the JAX platform by updating the ``jax_platform_name`` option.

    This utility only takes effect at the beginning of your program.

    Args:
      platform: str. One of 'cpu', 'gpu' or 'tpu'.

    Raises:
      AssertionError: If ``platform`` is not one of the supported names.
    """
    supported = ('cpu', 'gpu', 'tpu')
    assert platform in supported
    config.update("jax_platform_name", platform)
|
269
|
+
|
270
|
+
|
271
|
+
def _set_jax_precision(precision: int):
|
272
|
+
"""
|
273
|
+
Set the default precision.
|
274
|
+
|
275
|
+
Args:
|
276
|
+
precision: int. The default precision.
|
277
|
+
"""
|
278
|
+
assert precision in [64, 32, 16, 8], f'Precision must be in [64, 32, 16, 8]. But got {precision}.'
|
279
|
+
if precision == 64:
|
280
|
+
config.update("jax_enable_x64", True)
|
281
|
+
else:
|
282
|
+
config.update("jax_enable_x64", False)
|
283
|
+
|
284
|
+
|
285
|
+
@functools.lru_cache()
|
286
|
+
def _get_uint(precision: int):
|
287
|
+
if precision == 64:
|
288
|
+
return np.uint64
|
289
|
+
elif precision == 32:
|
290
|
+
return np.uint32
|
291
|
+
elif precision == 16:
|
292
|
+
return np.uint16
|
293
|
+
elif precision == 8:
|
294
|
+
return np.uint8
|
295
|
+
else:
|
296
|
+
raise ValueError(f'Unsupported precision: {precision}')
|
297
|
+
|
298
|
+
|
299
|
+
@functools.lru_cache()
|
300
|
+
def _get_int(precision: int):
|
301
|
+
if precision == 64:
|
302
|
+
return np.int64
|
303
|
+
elif precision == 32:
|
304
|
+
return np.int32
|
305
|
+
elif precision == 16:
|
306
|
+
return np.int16
|
307
|
+
elif precision == 8:
|
308
|
+
return np.int8
|
309
|
+
else:
|
310
|
+
raise ValueError(f'Unsupported precision: {precision}')
|
311
|
+
|
312
|
+
|
313
|
+
@functools.lru_cache()
|
314
|
+
def _get_float(precision: int):
|
315
|
+
if precision == 64:
|
316
|
+
return np.float64
|
317
|
+
elif precision == 32:
|
318
|
+
return np.float32
|
319
|
+
elif precision == 16:
|
320
|
+
return jnp.bfloat16
|
321
|
+
# return np.float16
|
322
|
+
else:
|
323
|
+
raise ValueError(f'Unsupported precision: {precision}')
|
324
|
+
|
325
|
+
|
326
|
+
@functools.lru_cache()
|
327
|
+
def _get_complex(precision: int):
|
328
|
+
if precision == 64:
|
329
|
+
return np.complex128
|
330
|
+
elif precision == 32:
|
331
|
+
return np.complex64
|
332
|
+
elif precision == 16:
|
333
|
+
return np.complex32
|
334
|
+
else:
|
335
|
+
raise ValueError(f'Unsupported precision: {precision}')
|
336
|
+
|
337
|
+
|
338
|
+
def dftype() -> DTypeLike:
    """
    Return the default floating-point data type for the active precision.

    Precision 64 gives ``np.float64``, 32 gives ``np.float32`` and 16 gives
    ``jnp.bfloat16``. Any other precision raises ``ValueError``.
    """
    precision = get_precision()
    return _get_float(precision)
|
343
|
+
|
344
|
+
|
345
|
+
def ditype() -> DTypeLike:
    """
    Return the default signed integer data type for the active precision.

    Precision 64/32/16/8 gives ``np.int64``/``np.int32``/``np.int16``/``np.int8``.
    Any other precision raises ``ValueError``.
    """
    precision = get_precision()
    return _get_int(precision)
|
350
|
+
|
351
|
+
|
352
|
+
def dutype() -> DTypeLike:
    """
    Return the default unsigned integer data type for the active precision.

    Precision 64/32/16/8 gives ``np.uint64``/``np.uint32``/``np.uint16``/``np.uint8``.
    Any other precision raises ``ValueError``.
    """
    precision = get_precision()
    return _get_uint(precision)
|
357
|
+
|
358
|
+
|
359
|
+
def dctype() -> DTypeLike:
    """
    Return the default complex data type for the active precision.

    Precision 64 gives ``np.complex128`` and 32 gives ``np.complex64``. Other
    precisions are handled by ``_get_complex``.
    """
    precision = get_precision()
    return _get_complex(precision)
|
364
|
+
|
365
|
+
|
366
|
+
def tolerance():
    """
    Return a small absolute tolerance suited to the active precision.

    Returns:
      A scalar JAX array: ``1e-12`` as float64 for precision 64, ``1e-5`` as
      float32 for precision 32, and ``1e-2`` as float16 for any other precision.
    """
    # NOTE(review): for precision 16, ``dftype()`` is bfloat16 but this
    # tolerance is float16; mixing the two promotes -- confirm this is intended.
    precision = get_precision()
    if precision == 64:
        return jnp.array(1e-12, dtype=np.float64)
    if precision == 32:
        return jnp.array(1e-5, dtype=np.float32)
    return jnp.array(1e-2, dtype=np.float16)
|
373
|
+
|
374
|
+
|
375
|
+
# Install the package-wide defaults at import time. Because ``precision=32`` is
# passed, ``set`` also calls ``_set_jax_precision(32)``, which sets JAX's
# ``jax_enable_x64`` flag to False when this module is imported.
set(
    dt=0.1,
    precision=32,
    mode=Mode(),
    mem_scaling=IdMemScaling(),
)
|
@@ -0,0 +1,25 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
from ._activations import *
|
18
|
+
from ._activations import __all__ as __activations_all__
|
19
|
+
from ._normalization import *
|
20
|
+
from ._normalization import __all__ as __others_all__
|
21
|
+
from ._spikes import *
|
22
|
+
from ._spikes import __all__ as __spikes_all__
|
23
|
+
|
24
|
+
# Re-export the public names of the three private submodules.
# ``__others_all__`` is the export list of ``._normalization``.
__all__ = (
    __spikes_all__
    + __others_all__
    + __activations_all__
)
|
25
|
+
|