brainstate 0.0.2.post20240814__py2.py3-none-any.whl → 0.0.2.post20240825__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/environ.py +54 -5
- {brainstate-0.0.2.post20240814.dist-info → brainstate-0.0.2.post20240825.dist-info}/METADATA +5 -5
- {brainstate-0.0.2.post20240814.dist-info → brainstate-0.0.2.post20240825.dist-info}/RECORD +6 -6
- {brainstate-0.0.2.post20240814.dist-info → brainstate-0.0.2.post20240825.dist-info}/WHEEL +1 -1
- {brainstate-0.0.2.post20240814.dist-info → brainstate-0.0.2.post20240825.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20240814.dist-info → brainstate-0.0.2.post20240825.dist-info}/top_level.txt +0 -0
brainstate/environ.py
CHANGED
@@ -6,7 +6,7 @@ import functools
|
|
6
6
|
import os
|
7
7
|
import re
|
8
8
|
from collections import defaultdict
|
9
|
-
from typing import Any
|
9
|
+
from typing import Any, Callable
|
10
10
|
|
11
11
|
import numpy as np
|
12
12
|
from jax import config, devices, numpy as jnp
|
@@ -20,7 +20,7 @@ __all__ = [
|
|
20
20
|
'set_host_device_count', 'set_platform',
|
21
21
|
'get_host_device_count', 'get_platform',
|
22
22
|
'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
|
23
|
-
'tolerance',
|
23
|
+
'tolerance', 'register_default_behavior',
|
24
24
|
'dftype', 'ditype', 'dutype', 'dctype',
|
25
25
|
]
|
26
26
|
|
@@ -31,8 +31,9 @@ JIT_ERROR_CHECK = 'jit_error_check' # whether to record the current computation
|
|
31
31
|
FIT = 'fit' # whether to fit the model.
|
32
32
|
|
33
33
|
_NOT_PROVIDE = object()
|
34
|
-
_environment_defaults = dict()
|
35
|
-
_environment_contexts = defaultdict(list)
|
34
|
+
_environment_defaults = dict() # default environment settings
|
35
|
+
_environment_contexts = defaultdict(list) # current environment settings
|
36
|
+
_environment_functions = dict() # environment functions
|
36
37
|
|
37
38
|
|
38
39
|
@contextlib.contextmanager
|
@@ -61,19 +62,34 @@ def context(**kwargs):
|
|
61
62
|
if 'host_device_count' in kwargs:
|
62
63
|
raise ValueError('Cannot set host_device_count in environment context. '
|
63
64
|
'Please use set_host_device_count() or set() for the global setting.')
|
65
|
+
|
64
66
|
if 'precision' in kwargs:
|
65
67
|
last_precision = get_precision()
|
66
68
|
_set_jax_precision(kwargs['precision'])
|
67
69
|
|
68
70
|
try:
|
69
|
-
# update the current environment
|
70
71
|
for k, v in kwargs.items():
|
72
|
+
|
73
|
+
# update the current environment
|
71
74
|
_environment_contexts[k].append(v)
|
75
|
+
|
76
|
+
# restore the environment functions
|
77
|
+
if k in _environment_functions:
|
78
|
+
_environment_functions[k](v)
|
79
|
+
|
72
80
|
# yield the current all environment information
|
73
81
|
yield all()
|
74
82
|
finally:
|
83
|
+
|
75
84
|
for k, v in kwargs.items():
|
85
|
+
|
86
|
+
# restore the current environment
|
76
87
|
_environment_contexts[k].pop()
|
88
|
+
|
89
|
+
# restore the environment functions
|
90
|
+
if k in _environment_functions:
|
91
|
+
_environment_functions[k](get(k))
|
92
|
+
|
77
93
|
if 'precision' in kwargs:
|
78
94
|
_set_jax_precision(last_precision)
|
79
95
|
|
@@ -232,8 +248,15 @@ def set(
|
|
232
248
|
if mode is not None:
|
233
249
|
assert isinstance(mode, Mode), 'mode must be a Mode instance.'
|
234
250
|
kwargs['mode'] = mode
|
251
|
+
|
252
|
+
# set default environment
|
235
253
|
_environment_defaults.update(kwargs)
|
236
254
|
|
255
|
+
# update the environment functions
|
256
|
+
for k, v in kwargs.items():
|
257
|
+
if k in _environment_functions:
|
258
|
+
_environment_functions[k](v)
|
259
|
+
|
237
260
|
|
238
261
|
def set_host_device_count(n):
|
239
262
|
"""
|
@@ -258,6 +281,10 @@ def set_host_device_count(n):
|
|
258
281
|
xla_flags = re.sub(r"--xla_force_host_platform_device_count=\S+", "", xla_flags).split()
|
259
282
|
os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(n)] + xla_flags)
|
260
283
|
|
284
|
+
# update the environment functions
|
285
|
+
if 'host_device_count' in _environment_functions:
|
286
|
+
_environment_functions['host_device_count'](n)
|
287
|
+
|
261
288
|
|
262
289
|
def set_platform(platform: str):
|
263
290
|
"""
|
@@ -267,6 +294,10 @@ def set_platform(platform: str):
|
|
267
294
|
assert platform in ['cpu', 'gpu', 'tpu']
|
268
295
|
config.update("jax_platform_name", platform)
|
269
296
|
|
297
|
+
# update the environment functions
|
298
|
+
if 'platform' in _environment_functions:
|
299
|
+
_environment_functions['platform'](platform)
|
300
|
+
|
270
301
|
|
271
302
|
def _set_jax_precision(precision: int):
|
272
303
|
"""
|
@@ -372,4 +403,22 @@ def tolerance():
|
|
372
403
|
return jnp.array(1e-2, dtype=np.float16)
|
373
404
|
|
374
405
|
|
406
|
+
def register_default_behavior(key: str, behavior: Callable, replace_if_exist: bool = False):
|
407
|
+
"""
|
408
|
+
Register a default behavior for a specific key.
|
409
|
+
|
410
|
+
Args:
|
411
|
+
key: str. The key to register.
|
412
|
+
behavior: Callable. The behavior to register. It should be a callable.
|
413
|
+
replace_if_exist: bool. Whether to replace the behavior if the key has been registered.
|
414
|
+
|
415
|
+
"""
|
416
|
+
assert isinstance(key, str), 'key must be a string.'
|
417
|
+
assert callable(behavior), 'behavior must be a callable.'
|
418
|
+
if not replace_if_exist:
|
419
|
+
assert key not in _environment_functions, f'{key} has been registered.'
|
420
|
+
_environment_functions[key] = behavior
|
421
|
+
|
422
|
+
|
375
423
|
set(dt=0.1, precision=32, mode=Mode(), mem_scaling=IdMemScaling())
|
424
|
+
|
{brainstate-0.0.2.post20240814.dist-info → brainstate-0.0.2.post20240825.dist-info}/METADATA
RENAMED
@@ -1,14 +1,14 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.0.2.
|
3
|
+
Version: 0.0.2.post20240825
|
4
4
|
Summary: A State-based Transformation System for Brain Dynamics Programming.
|
5
5
|
Home-page: https://github.com/brainpy/brainstate
|
6
6
|
Author: BDP
|
7
|
-
Author-email:
|
7
|
+
Author-email: BrainState Developers <chao.brain@qq.com>
|
8
8
|
License: Apache-2.0 license
|
9
9
|
Project-URL: homepage, http://github.com/brainpy
|
10
10
|
Project-URL: repository, http://github.com/brainpy/brainstate
|
11
|
-
Keywords:
|
11
|
+
Keywords: computational neuroscience,brain-inspired computation,brain dynamics programming
|
12
12
|
Classifier: Natural Language :: English
|
13
13
|
Classifier: Operating System :: OS Independent
|
14
14
|
Classifier: Development Status :: 4 - Beta
|
@@ -31,11 +31,11 @@ License-File: LICENSE
|
|
31
31
|
Requires-Dist: jax
|
32
32
|
Requires-Dist: jaxlib
|
33
33
|
Requires-Dist: numpy
|
34
|
-
Requires-Dist: brainunit
|
34
|
+
Requires-Dist: brainunit >=0.0.2
|
35
35
|
Provides-Extra: cpu
|
36
36
|
Requires-Dist: jaxlib ; extra == 'cpu'
|
37
37
|
Provides-Extra: cuda12
|
38
|
-
Requires-Dist: jaxlib[
|
38
|
+
Requires-Dist: jaxlib[cuda12] ; extra == 'cuda12'
|
39
39
|
Provides-Extra: testing
|
40
40
|
Requires-Dist: pytest ; extra == 'testing'
|
41
41
|
Provides-Extra: tpu
|
@@ -5,7 +5,7 @@ brainstate/_random_for_unit.py,sha256=1rHr7gfH_bYrJfpxbDhQUk_j00Yosx-GzyZCXrLxsd
|
|
5
5
|
brainstate/_state.py,sha256=C0widCOj_ca6zfqh95jzFXf_G5vi0hJyuQ5GIqEqOUs,12102
|
6
6
|
brainstate/_state_test.py,sha256=HDdipndRLhEHWEdTmyT1ayEBkbv6qJKykfCWKI6yJ_E,1253
|
7
7
|
brainstate/_utils.py,sha256=RLorgGJkt2BhbX4C-ygd-PPG0wfcGCghjSP93sRvzqM,833
|
8
|
-
brainstate/environ.py,sha256=
|
8
|
+
brainstate/environ.py,sha256=k0p1oyi9jbsPfuvqrPL-_zgSd7VW3LRs0LboxlaaIfc,11806
|
9
9
|
brainstate/mixin.py,sha256=OumTTSVyYSbtudjfS_MRThsBaeVJ_0JggeMClY7xtBA,10758
|
10
10
|
brainstate/mixin_test.py,sha256=-Ej9oUOu8O1M4oy37SVMj7xNRYhHHyAHwrjS_aISayo,2923
|
11
11
|
brainstate/random.py,sha256=BqEBYVD9TGe8dSzp8U0suK0O4r6Ox59GCq0mwfUndVQ,188073
|
@@ -59,8 +59,8 @@ brainstate/transform/_jit_test.py,sha256=5ltT7izh_OS9dcHnRymmVhq01QomjwZGdA8XzwJ
|
|
59
59
|
brainstate/transform/_make_jaxpr.py,sha256=ZkrOZu4_0xcILuPUA3RFEkorJ-xbDuDtXorJI_qVThE,30450
|
60
60
|
brainstate/transform/_make_jaxpr_test.py,sha256=K3vRUBroDTCCx0lnmhgHtgrlWvWglJO2f1K2phTvU70,3819
|
61
61
|
brainstate/transform/_progress_bar.py,sha256=VGoRZPRBmB8ELNwLc6c7S8QhUUTvn0FY46IbBm9cuYM,3502
|
62
|
-
brainstate-0.0.2.
|
63
|
-
brainstate-0.0.2.
|
64
|
-
brainstate-0.0.2.
|
65
|
-
brainstate-0.0.2.
|
66
|
-
brainstate-0.0.2.
|
62
|
+
brainstate-0.0.2.post20240825.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
63
|
+
brainstate-0.0.2.post20240825.dist-info/METADATA,sha256=COfpxoCL7w1xGa1OFYFeANFLhAKmSioWVtmF_i2st34,3849
|
64
|
+
brainstate-0.0.2.post20240825.dist-info/WHEEL,sha256=GUeE9LxUgRABPG7YM0jCNs9cBsAIx0YAkzCB88PMLgc,109
|
65
|
+
brainstate-0.0.2.post20240825.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
66
|
+
brainstate-0.0.2.post20240825.dist-info/RECORD,,
|
File without changes
|
{brainstate-0.0.2.post20240814.dist-info → brainstate-0.0.2.post20240825.dist-info}/top_level.txt
RENAMED
File without changes
|