brainstate 0.0.2.post20240913__py2.py3-none-any.whl → 0.0.2.post20241009__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +4 -2
- brainstate/_module.py +102 -67
- brainstate/_state.py +2 -2
- brainstate/_visualization.py +47 -0
- brainstate/environ.py +116 -9
- brainstate/environ_test.py +56 -0
- brainstate/functional/_activations.py +134 -56
- brainstate/functional/_activations_test.py +331 -0
- brainstate/functional/_normalization.py +21 -10
- brainstate/init/_generic.py +4 -2
- brainstate/mixin.py +1 -1
- brainstate/nn/__init__.py +7 -2
- brainstate/nn/_base.py +2 -2
- brainstate/nn/_connections.py +4 -4
- brainstate/nn/_dynamics.py +5 -5
- brainstate/nn/_elementwise.py +9 -9
- brainstate/nn/_embedding.py +3 -3
- brainstate/nn/_normalizations.py +3 -3
- brainstate/nn/_others.py +2 -2
- brainstate/nn/_poolings.py +6 -6
- brainstate/nn/_rate_rnns.py +1 -1
- brainstate/nn/_readout.py +1 -1
- brainstate/nn/_synouts.py +1 -1
- brainstate/nn/event/__init__.py +25 -0
- brainstate/nn/event/_misc.py +34 -0
- brainstate/nn/event/csr.py +312 -0
- brainstate/nn/event/csr_test.py +118 -0
- brainstate/nn/event/fixed_probability.py +276 -0
- brainstate/nn/event/fixed_probability_test.py +127 -0
- brainstate/nn/event/linear.py +220 -0
- brainstate/nn/event/linear_test.py +111 -0
- brainstate/nn/metrics.py +390 -0
- brainstate/optim/__init__.py +5 -1
- brainstate/optim/_optax_optimizer.py +208 -0
- brainstate/optim/_optax_optimizer_test.py +14 -0
- brainstate/random/__init__.py +24 -0
- brainstate/{random.py → random/_rand_funs.py} +7 -1596
- brainstate/random/_rand_seed.py +169 -0
- brainstate/random/_rand_state.py +1491 -0
- brainstate/{_random_for_unit.py → random/_random_for_unit.py} +1 -1
- brainstate/{random_test.py → random/random_test.py} +208 -191
- brainstate/transform/_jit.py +1 -1
- brainstate/transform/_jit_test.py +19 -0
- brainstate/transform/_make_jaxpr.py +1 -1
- {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/METADATA +1 -1
- brainstate-0.0.2.post20241009.dist-info/RECORD +87 -0
- brainstate-0.0.2.post20240913.dist-info/RECORD +0 -70
- {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/top_level.txt +0 -0
brainstate/__init__.py
CHANGED
@@ -30,6 +30,8 @@ from . import surrogate
|
|
30
30
|
from . import transform
|
31
31
|
from . import typing
|
32
32
|
from . import util
|
33
|
+
from ._visualization import *
|
34
|
+
from ._visualization import __all__ as _visualization_all
|
33
35
|
from ._module import *
|
34
36
|
from ._module import __all__ as _module_all
|
35
37
|
from ._state import *
|
@@ -39,6 +41,6 @@ __all__ = (
|
|
39
41
|
['environ', 'share', 'nn', 'optim', 'random',
|
40
42
|
'surrogate', 'functional', 'init',
|
41
43
|
'mixin', 'transform', 'util', 'typing'] +
|
42
|
-
_module_all + _state_all
|
44
|
+
_module_all + _state_all + _visualization_all
|
43
45
|
)
|
44
|
-
del _module_all, _state_all
|
46
|
+
del _module_all, _state_all, _visualization_all
|
brainstate/_module.py
CHANGED
@@ -56,14 +56,13 @@ import jax
|
|
56
56
|
import jax.numpy as jnp
|
57
57
|
import numpy as np
|
58
58
|
|
59
|
-
from
|
60
|
-
from ._state import State, StateDictManager, visible_state_dict
|
61
|
-
from ._utils import set_module_as
|
62
|
-
from .mixin import Mixin, Mode, DelayedInit, JointTypes, Batching, UpdateReturn
|
63
|
-
from .transform import jit_error_if
|
64
|
-
from .typing import Size, ArrayLike, PyTree
|
65
|
-
from .util import unique_name, DictManager, get_unique_name
|
66
|
-
|
59
|
+
from brainstate import environ
|
60
|
+
from brainstate._state import State, StateDictManager, visible_state_dict
|
61
|
+
from brainstate._utils import set_module_as
|
62
|
+
from brainstate.mixin import Mixin, Mode, DelayedInit, JointTypes, Batching, UpdateReturn
|
63
|
+
from brainstate.transform import jit_error_if
|
64
|
+
from brainstate.typing import Size, ArrayLike, PyTree
|
65
|
+
from brainstate.util import unique_name, DictManager, get_unique_name
|
67
66
|
|
68
67
|
delay_identifier = '_*_delay_of_'
|
69
68
|
_DELAY_ROTATE = 'rotation'
|
@@ -94,7 +93,7 @@ __all__ = [
|
|
94
93
|
]
|
95
94
|
|
96
95
|
|
97
|
-
class
|
96
|
+
class Object:
|
98
97
|
"""
|
99
98
|
The Module class for the whole ecosystem.
|
100
99
|
|
@@ -118,43 +117,9 @@ class Module(object):
|
|
118
117
|
# the excluded nodes
|
119
118
|
_invisible_nodes: Tuple[str, ...] = ()
|
120
119
|
|
121
|
-
# # the supported computing modes
|
122
|
-
# supported_modes: Optional[Sequence[Mode]] = None
|
123
|
-
|
124
|
-
def __init__(self, name: str = None, mode: Mode = None):
|
125
|
-
super().__init__()
|
126
|
-
|
127
|
-
# check whether the object has a unique name.
|
128
|
-
self._name = unique_name(self=self, name=name)
|
129
|
-
|
130
|
-
# mode setting
|
131
|
-
self._mode = None
|
132
|
-
self.mode = mode if mode is not None else environ.get('mode')
|
133
|
-
|
134
120
|
def __repr__(self):
|
135
121
|
return f'{self.__class__.__name__}'
|
136
122
|
|
137
|
-
@property
|
138
|
-
def name(self):
|
139
|
-
"""Name of the model."""
|
140
|
-
return self._name
|
141
|
-
|
142
|
-
@name.setter
|
143
|
-
def name(self, name: str = None):
|
144
|
-
raise AttributeError('The name of the model is read-only.')
|
145
|
-
|
146
|
-
@property
|
147
|
-
def mode(self):
|
148
|
-
"""Mode of the model, which is useful to control the multiple behaviors of the model."""
|
149
|
-
return self._mode
|
150
|
-
|
151
|
-
@mode.setter
|
152
|
-
def mode(self, value):
|
153
|
-
if not isinstance(value, Mode):
|
154
|
-
raise ValueError(f'Must be instance of {Mode.__name__}, '
|
155
|
-
f'but we got {type(value)}: {value}')
|
156
|
-
self._mode = value
|
157
|
-
|
158
123
|
def states(
|
159
124
|
self,
|
160
125
|
method: str = 'absolute',
|
@@ -238,30 +203,6 @@ class Module(object):
|
|
238
203
|
nodes = nodes.unique()
|
239
204
|
return nodes
|
240
205
|
|
241
|
-
def update(self, *args, **kwargs):
|
242
|
-
"""
|
243
|
-
The function to specify the updating rule.
|
244
|
-
"""
|
245
|
-
raise NotImplementedError(f'Subclass of {self.__class__.__name__} must '
|
246
|
-
f'implement "update" function.')
|
247
|
-
|
248
|
-
def __call__(self, *args, **kwargs):
|
249
|
-
return self.update(*args, **kwargs)
|
250
|
-
|
251
|
-
def __rrshift__(self, other):
|
252
|
-
"""
|
253
|
-
Support using right shift operator to call modules.
|
254
|
-
|
255
|
-
Examples
|
256
|
-
--------
|
257
|
-
|
258
|
-
>>> import brainstate as bst
|
259
|
-
>>> x = bst.random.rand((10, 10))
|
260
|
-
>>> l = bst.nn.Activation(jax.numpy.tanh)
|
261
|
-
>>> y = x >> l
|
262
|
-
"""
|
263
|
-
return self.__call__(other)
|
264
|
-
|
265
206
|
def init_state(self, *args, **kwargs):
|
266
207
|
"""
|
267
208
|
State initialization function.
|
@@ -289,6 +230,23 @@ class Module(object):
|
|
289
230
|
missing_keys = list(keys2 - keys1)
|
290
231
|
return unexpected_keys, missing_keys
|
291
232
|
|
233
|
+
def __treescope_repr__(self, path, subtree_renderer):
|
234
|
+
import treescope # type: ignore[import-not-found,import-untyped]
|
235
|
+
children = {}
|
236
|
+
for name, value in vars(self).items():
|
237
|
+
if name.startswith('_'):
|
238
|
+
continue
|
239
|
+
children[name] = value
|
240
|
+
return treescope.repr_lib.render_object_constructor(
|
241
|
+
object_type=type(self),
|
242
|
+
attributes=children,
|
243
|
+
path=path,
|
244
|
+
subtree_renderer=subtree_renderer,
|
245
|
+
color=treescope.formatting_util.color_from_string(
|
246
|
+
type(self).__qualname__
|
247
|
+
)
|
248
|
+
)
|
249
|
+
|
292
250
|
|
293
251
|
def _find_nodes(self, method: str = 'absolute', level=-1, include_self=True, _lid=0, _edges=None) -> DictManager:
|
294
252
|
if _edges is None:
|
@@ -374,6 +332,83 @@ def _add_node_relative(self, k, v, _paths, gather, nodes):
|
|
374
332
|
nodes.append((k, v))
|
375
333
|
|
376
334
|
|
335
|
+
class Module(Object):
|
336
|
+
"""
|
337
|
+
The Module class for the whole ecosystem.
|
338
|
+
|
339
|
+
The ``Module`` is the base class for all the objects in the ecosystem. It
|
340
|
+
provides the basic functionalities for the objects, including:
|
341
|
+
|
342
|
+
- ``states()``: Collect all states in this node and the children nodes.
|
343
|
+
- ``nodes()``: Collect all children nodes.
|
344
|
+
- ``update()``: The function to specify the updating rule.
|
345
|
+
- ``init_state()``: State initialization function.
|
346
|
+
- ``save_state()``: Save states as a dictionary.
|
347
|
+
- ``load_state()``: Load states from the external objects.
|
348
|
+
|
349
|
+
"""
|
350
|
+
|
351
|
+
__module__ = 'brainstate'
|
352
|
+
|
353
|
+
def __init__(self, name: str = None, mode: Mode = None):
|
354
|
+
super().__init__()
|
355
|
+
|
356
|
+
# check whether the object has a unique name.
|
357
|
+
self._name = unique_name(self=self, name=name)
|
358
|
+
|
359
|
+
# mode setting
|
360
|
+
self._mode = None
|
361
|
+
self.mode = mode if mode is not None else environ.get('mode')
|
362
|
+
|
363
|
+
def __repr__(self):
|
364
|
+
return f'{self.__class__.__name__}'
|
365
|
+
|
366
|
+
@property
|
367
|
+
def name(self):
|
368
|
+
"""Name of the model."""
|
369
|
+
return self._name
|
370
|
+
|
371
|
+
@name.setter
|
372
|
+
def name(self, name: str = None):
|
373
|
+
raise AttributeError('The name of the model is read-only.')
|
374
|
+
|
375
|
+
@property
|
376
|
+
def mode(self):
|
377
|
+
"""Mode of the model, which is useful to control the multiple behaviors of the model."""
|
378
|
+
return self._mode
|
379
|
+
|
380
|
+
@mode.setter
|
381
|
+
def mode(self, value):
|
382
|
+
if not isinstance(value, Mode):
|
383
|
+
raise ValueError(f'Must be instance of {Mode.__name__}, '
|
384
|
+
f'but we got {type(value)}: {value}')
|
385
|
+
self._mode = value
|
386
|
+
|
387
|
+
def update(self, *args, **kwargs):
|
388
|
+
"""
|
389
|
+
The function to specify the updating rule.
|
390
|
+
"""
|
391
|
+
raise NotImplementedError(f'Subclass of {self.__class__.__name__} must '
|
392
|
+
f'implement "update" function.')
|
393
|
+
|
394
|
+
def __call__(self, *args, **kwargs):
|
395
|
+
return self.update(*args, **kwargs)
|
396
|
+
|
397
|
+
def __rrshift__(self, other):
|
398
|
+
"""
|
399
|
+
Support using right shift operator to call modules.
|
400
|
+
|
401
|
+
Examples
|
402
|
+
--------
|
403
|
+
|
404
|
+
>>> import brainstate as bst
|
405
|
+
>>> x = bst.random.rand((10, 10))
|
406
|
+
>>> l = bst.nn.Dropout(0.5)
|
407
|
+
>>> y = x >> l
|
408
|
+
"""
|
409
|
+
return self.__call__(other)
|
410
|
+
|
411
|
+
|
377
412
|
class Projection(Module):
|
378
413
|
"""
|
379
414
|
Base class to model synaptic projections.
|
brainstate/_state.py
CHANGED
@@ -22,8 +22,8 @@ import numpy as np
|
|
22
22
|
from jax.api_util import shaped_abstractify
|
23
23
|
from jax.extend import source_info_util
|
24
24
|
|
25
|
-
from .typing import ArrayLike, PyTree
|
26
|
-
from .util import DictManager
|
25
|
+
from brainstate.typing import ArrayLike, PyTree
|
26
|
+
from brainstate.util import DictManager
|
27
27
|
|
28
28
|
__all__ = [
|
29
29
|
'State', 'ShortTermState', 'LongTermState', 'ParamState',
|
@@ -0,0 +1,47 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
__all__ = [
|
18
|
+
'display',
|
19
|
+
]
|
20
|
+
|
21
|
+
|
22
|
+
import importlib.util
|
23
|
+
|
24
|
+
treescope_installed = importlib.util.find_spec('treescope') is not None
|
25
|
+
try:
|
26
|
+
from IPython import get_ipython
|
27
|
+
|
28
|
+
in_ipython = get_ipython() is not None
|
29
|
+
except ImportError:
|
30
|
+
in_ipython = False
|
31
|
+
|
32
|
+
|
33
|
+
def display(*args):
|
34
|
+
"""Display the given objects using the Treescope pretty-printer.
|
35
|
+
|
36
|
+
If treescope is not installed or the code is not running in IPython,
|
37
|
+
``display`` will print the objects instead.
|
38
|
+
"""
|
39
|
+
if not treescope_installed or not in_ipython:
|
40
|
+
for x in args:
|
41
|
+
print(x)
|
42
|
+
return
|
43
|
+
|
44
|
+
import treescope # type: ignore[import-not-found,import-untyped]
|
45
|
+
|
46
|
+
for x in args:
|
47
|
+
treescope.display(x, ignore_exceptions=True, autovisualize=True)
|
brainstate/environ.py
CHANGED
@@ -16,12 +16,14 @@ from .mixin import Mode
|
|
16
16
|
from .util import MemScaling, IdMemScaling
|
17
17
|
|
18
18
|
__all__ = [
|
19
|
-
|
20
|
-
'set_host_device_count', 'set_platform',
|
21
|
-
|
22
|
-
'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
|
23
|
-
|
19
|
+
# functions for environment settings
|
20
|
+
'set', 'context', 'get', 'all', 'set_host_device_count', 'set_platform',
|
21
|
+
# functions for getting default behaviors
|
22
|
+
'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
|
23
|
+
# functions for default data types
|
24
24
|
'dftype', 'ditype', 'dutype', 'dctype',
|
25
|
+
# others
|
26
|
+
'tolerance', 'register_default_behavior',
|
25
27
|
]
|
26
28
|
|
27
29
|
# Default, there are several shared arguments in the global context.
|
@@ -57,8 +59,9 @@ def context(**kwargs):
|
|
57
59
|
|
58
60
|
"""
|
59
61
|
if 'platform' in kwargs:
|
60
|
-
raise ValueError('
|
61
|
-
'
|
62
|
+
raise ValueError('\n'
|
63
|
+
'Cannot set platform in "context" environment. \n'
|
64
|
+
'You should set platform in the global environment by "set_platform()" or "set()".')
|
62
65
|
if 'host_device_count' in kwargs:
|
63
66
|
raise ValueError('Cannot set host_device_count in environment context. '
|
64
67
|
'Please use set_host_device_count() or set() for the global setting.')
|
@@ -135,6 +138,11 @@ def get(key: str, default: Any = _NOT_PROVIDE, desc: str = None):
|
|
135
138
|
def all() -> dict:
|
136
139
|
"""
|
137
140
|
Get all the current default computation environment.
|
141
|
+
|
142
|
+
Returns
|
143
|
+
-------
|
144
|
+
r: dict
|
145
|
+
The current default computation environment.
|
138
146
|
"""
|
139
147
|
r = dict()
|
140
148
|
for k, v in _environment_contexts.items():
|
@@ -227,6 +235,8 @@ def set(
|
|
227
235
|
"""
|
228
236
|
Set the global default computation environment.
|
229
237
|
|
238
|
+
|
239
|
+
|
230
240
|
Args:
|
231
241
|
platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
|
232
242
|
host_device_count: int. The number of host devices.
|
@@ -290,6 +300,12 @@ def set_platform(platform: str):
|
|
290
300
|
"""
|
291
301
|
Changes platform to CPU, GPU, or TPU. This utility only takes
|
292
302
|
effect at the beginning of your program.
|
303
|
+
|
304
|
+
Args:
|
305
|
+
platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
|
306
|
+
|
307
|
+
Raises:
|
308
|
+
ValueError: If the platform is not in ['cpu', 'gpu', 'tpu'].
|
293
309
|
"""
|
294
310
|
assert platform in ['cpu', 'gpu', 'tpu']
|
295
311
|
config.update("jax_platform_name", platform)
|
@@ -369,6 +385,23 @@ def _get_complex(precision: int):
|
|
369
385
|
def dftype() -> DTypeLike:
|
370
386
|
"""
|
371
387
|
Default floating data type.
|
388
|
+
|
389
|
+
This function returns the default floating data type based on the current precision.
|
390
|
+
If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
|
391
|
+
you can use this function to get the default floating data type, and create the data by using ``dtype=dftype()``.
|
392
|
+
|
393
|
+
For example, if the precision is set to 32, the default floating data type is ``np.float32``.
|
394
|
+
|
395
|
+
>>> import brainstate as bst
|
396
|
+
>>> import numpy as np
|
397
|
+
>>> with bst.environ.context(precision=32):
|
398
|
+
... a = np.zeros(1, dtype=bst.environ.dftype())
|
399
|
+
>>> print(a.dtype)
|
400
|
+
|
401
|
+
Returns
|
402
|
+
-------
|
403
|
+
float_dtype: DTypeLike
|
404
|
+
The default floating data type.
|
372
405
|
"""
|
373
406
|
return _get_float(get_precision())
|
374
407
|
|
@@ -376,6 +409,24 @@ def dftype() -> DTypeLike:
|
|
376
409
|
def ditype() -> DTypeLike:
|
377
410
|
"""
|
378
411
|
Default integer data type.
|
412
|
+
|
413
|
+
This function returns the default integer data type based on the current precision.
|
414
|
+
If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
|
415
|
+
you can use this function to get the default integer data type, and create the data by using ``dtype=ditype()``.
|
416
|
+
|
417
|
+
For example, if the precision is set to 32, the default integer data type is ``np.int32``.
|
418
|
+
|
419
|
+
>>> import brainstate as bst
|
420
|
+
>>> import numpy as np
|
421
|
+
>>> with bst.environ.context(precision=32):
|
422
|
+
... a = np.zeros(1, dtype=bst.environ.ditype())
|
423
|
+
>>> print(a.dtype)
|
424
|
+
int32
|
425
|
+
|
426
|
+
Returns
|
427
|
+
-------
|
428
|
+
int_dtype: DTypeLike
|
429
|
+
The default integer data type.
|
379
430
|
"""
|
380
431
|
return _get_int(get_precision())
|
381
432
|
|
@@ -383,6 +434,25 @@ def ditype() -> DTypeLike:
|
|
383
434
|
def dutype() -> DTypeLike:
|
384
435
|
"""
|
385
436
|
Default unsigned integer data type.
|
437
|
+
|
438
|
+
This function returns the default unsigned integer data type based on the current precision.
|
439
|
+
If you want the data dtype is changed with the setting of the precision
|
440
|
+
by ``brainstate.environ.set(precision)``, you can use this function to get the default
|
441
|
+
unsigned integer data type, and create the data by using ``dtype=dutype()``.
|
442
|
+
|
443
|
+
For example, if the precision is set to 32, the default unsigned integer data type is ``np.uint32``.
|
444
|
+
|
445
|
+
>>> import brainstate as bst
|
446
|
+
>>> import numpy as np
|
447
|
+
>>> with bst.environ.context(precision=32):
|
448
|
+
... a = np.zeros(1, dtype=bst.environ.dutype())
|
449
|
+
>>> print(a.dtype)
|
450
|
+
uint32
|
451
|
+
|
452
|
+
Returns
|
453
|
+
-------
|
454
|
+
uint_dtype: DTypeLike
|
455
|
+
The default unsigned integer data type.
|
386
456
|
"""
|
387
457
|
return _get_uint(get_precision())
|
388
458
|
|
@@ -390,6 +460,24 @@ def dutype() -> DTypeLike:
|
|
390
460
|
def dctype() -> DTypeLike:
|
391
461
|
"""
|
392
462
|
Default complex data type.
|
463
|
+
|
464
|
+
This function returns the default complex data type based on the current precision.
|
465
|
+
If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
|
466
|
+
you can use this function to get the default complex data type, and create the data by using ``dtype=dctype()``.
|
467
|
+
|
468
|
+
For example, if the precision is set to 32, the default complex data type is ``np.complex64``.
|
469
|
+
|
470
|
+
>>> import brainstate as bst
|
471
|
+
>>> import numpy as np
|
472
|
+
>>> with bst.environ.context(precision=32):
|
473
|
+
... a = np.zeros(1, dtype=bst.environ.dctype())
|
474
|
+
>>> print(a.dtype)
|
475
|
+
complex64
|
476
|
+
|
477
|
+
Returns
|
478
|
+
-------
|
479
|
+
complex_dtype: DTypeLike
|
480
|
+
The default complex data type.
|
393
481
|
"""
|
394
482
|
return _get_complex(get_precision())
|
395
483
|
|
@@ -405,7 +493,27 @@ def tolerance():
|
|
405
493
|
|
406
494
|
def register_default_behavior(key: str, behavior: Callable, replace_if_exist: bool = False):
|
407
495
|
"""
|
408
|
-
Register a default behavior for a specific key.
|
496
|
+
Register a default behavior for a specific global key parameter.
|
497
|
+
|
498
|
+
For example, you can register a default behavior for the key 'dt' by::
|
499
|
+
|
500
|
+
>>> import brainstate as bst
|
501
|
+
>>> def dt_behavior(dt):
|
502
|
+
... print(f'Set the default dt to {dt}.')
|
503
|
+
...
|
504
|
+
>>> bst.environ.register_default_behavior('dt', dt_behavior)
|
505
|
+
|
506
|
+
Then, when you set the default dt by `brainstate.environ.set(dt=0.1)`, the behavior
|
507
|
+
`dt_behavior` will be called with
|
508
|
+
`dt_behavior(0.1)`.
|
509
|
+
|
510
|
+
>>> bst.environ.set(dt=0.1)
|
511
|
+
Set the default dt to 0.1.
|
512
|
+
>>> with bst.environ.context(dt=0.2):
|
513
|
+
... pass
|
514
|
+
Set the default dt to 0.2.
|
515
|
+
Set the default dt to 0.1.
|
516
|
+
|
409
517
|
|
410
518
|
Args:
|
411
519
|
key: str. The key to register.
|
@@ -421,4 +529,3 @@ def register_default_behavior(key: str, behavior: Callable, replace_if_exist: bo
|
|
421
529
|
|
422
530
|
|
423
531
|
set(dt=0.1, precision=32, mode=Mode(), mem_scaling=IdMemScaling())
|
424
|
-
|
@@ -0,0 +1,56 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
import unittest
|
18
|
+
|
19
|
+
import jax.numpy as jnp
|
20
|
+
|
21
|
+
import brainstate as bst
|
22
|
+
|
23
|
+
|
24
|
+
class TestEnviron(unittest.TestCase):
|
25
|
+
def test_precision(self):
|
26
|
+
with bst.environ.context(precision=64):
|
27
|
+
a = bst.random.randn(1)
|
28
|
+
self.assertEqual(a.dtype, jnp.float64)
|
29
|
+
|
30
|
+
with bst.environ.context(precision=32):
|
31
|
+
a = bst.random.randn(1)
|
32
|
+
self.assertEqual(a.dtype, jnp.float32)
|
33
|
+
|
34
|
+
with bst.environ.context(precision=16):
|
35
|
+
a = bst.random.randn(1)
|
36
|
+
self.assertEqual(a.dtype, jnp.bfloat16)
|
37
|
+
|
38
|
+
def test_platform(self):
|
39
|
+
with self.assertRaises(ValueError):
|
40
|
+
with bst.environ.context(platform='cpu'):
|
41
|
+
a = bst.random.randn(1)
|
42
|
+
self.assertEqual(a.device(), 'cpu')
|
43
|
+
|
44
|
+
def test_register_default_behavior(self):
|
45
|
+
dt_ = 0.1
|
46
|
+
|
47
|
+
def dt_behavior(dt):
|
48
|
+
nonlocal dt_
|
49
|
+
dt_ = dt
|
50
|
+
print(f'dt: {dt}')
|
51
|
+
|
52
|
+
bst.environ.register_default_behavior('dt', dt_behavior)
|
53
|
+
|
54
|
+
with bst.environ.context(dt=0.2):
|
55
|
+
self.assertEqual(dt_, 0.2)
|
56
|
+
self.assertEqual(dt_, 0.1)
|