brainstate 0.0.2.post20240824__py2.py3-none-any.whl → 0.0.2.post20240825__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
brainstate/environ.py CHANGED
@@ -6,7 +6,7 @@ import functools
6
6
  import os
7
7
  import re
8
8
  from collections import defaultdict
9
- from typing import Any
9
+ from typing import Any, Callable
10
10
 
11
11
  import numpy as np
12
12
  from jax import config, devices, numpy as jnp
@@ -20,7 +20,7 @@ __all__ = [
20
20
  'set_host_device_count', 'set_platform',
21
21
  'get_host_device_count', 'get_platform',
22
22
  'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
23
- 'tolerance',
23
+ 'tolerance', 'register_default_behavior',
24
24
  'dftype', 'ditype', 'dutype', 'dctype',
25
25
  ]
26
26
 
@@ -31,8 +31,9 @@ JIT_ERROR_CHECK = 'jit_error_check' # whether to record the current computation
31
31
  FIT = 'fit' # whether to fit the model.
32
32
 
33
33
  _NOT_PROVIDE = object()
34
- _environment_defaults = dict()
35
- _environment_contexts = defaultdict(list)
34
+ _environment_defaults = dict() # default environment settings
35
+ _environment_contexts = defaultdict(list) # current environment settings
36
+ _environment_functions = dict() # environment functions
36
37
 
37
38
 
38
39
  @contextlib.contextmanager
@@ -61,19 +62,34 @@ def context(**kwargs):
61
62
  if 'host_device_count' in kwargs:
62
63
  raise ValueError('Cannot set host_device_count in environment context. '
63
64
  'Please use set_host_device_count() or set() for the global setting.')
65
+
64
66
  if 'precision' in kwargs:
65
67
  last_precision = get_precision()
66
68
  _set_jax_precision(kwargs['precision'])
67
69
 
68
70
  try:
69
- # update the current environment
70
71
  for k, v in kwargs.items():
72
+
73
+ # update the current environment
71
74
  _environment_contexts[k].append(v)
75
+
76
+ # restore the environment functions
77
+ if k in _environment_functions:
78
+ _environment_functions[k](v)
79
+
72
80
  # yield the current all environment information
73
81
  yield all()
74
82
  finally:
83
+
75
84
  for k, v in kwargs.items():
85
+
86
+ # restore the current environment
76
87
  _environment_contexts[k].pop()
88
+
89
+ # restore the environment functions
90
+ if k in _environment_functions:
91
+ _environment_functions[k](get(k))
92
+
77
93
  if 'precision' in kwargs:
78
94
  _set_jax_precision(last_precision)
79
95
 
@@ -232,8 +248,15 @@ def set(
232
248
  if mode is not None:
233
249
  assert isinstance(mode, Mode), 'mode must be a Mode instance.'
234
250
  kwargs['mode'] = mode
251
+
252
+ # set default environment
235
253
  _environment_defaults.update(kwargs)
236
254
 
255
+ # update the environment functions
256
+ for k, v in kwargs.items():
257
+ if k in _environment_functions:
258
+ _environment_functions[k](v)
259
+
237
260
 
238
261
  def set_host_device_count(n):
239
262
  """
@@ -258,6 +281,10 @@ def set_host_device_count(n):
258
281
  xla_flags = re.sub(r"--xla_force_host_platform_device_count=\S+", "", xla_flags).split()
259
282
  os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(n)] + xla_flags)
260
283
 
284
+ # update the environment functions
285
+ if 'host_device_count' in _environment_functions:
286
+ _environment_functions['host_device_count'](n)
287
+
261
288
 
262
289
  def set_platform(platform: str):
263
290
  """
@@ -267,6 +294,10 @@ def set_platform(platform: str):
267
294
  assert platform in ['cpu', 'gpu', 'tpu']
268
295
  config.update("jax_platform_name", platform)
269
296
 
297
+ # update the environment functions
298
+ if 'platform' in _environment_functions:
299
+ _environment_functions['platform'](platform)
300
+
270
301
 
271
302
  def _set_jax_precision(precision: int):
272
303
  """
@@ -372,4 +403,22 @@ def tolerance():
372
403
  return jnp.array(1e-2, dtype=np.float16)
373
404
 
374
405
 
406
+ def register_default_behavior(key: str, behavior: Callable, replace_if_exist: bool = False):
407
+ """
408
+ Register a default behavior for a specific key.
409
+
410
+ Args:
411
+ key: str. The key to register.
412
+ behavior: Callable. The behavior to register. It should be a callable.
413
+ replace_if_exist: bool. Whether to replace the behavior if the key has been registered.
414
+
415
+ """
416
+ assert isinstance(key, str), 'key must be a string.'
417
+ assert callable(behavior), 'behavior must be a callable.'
418
+ if not replace_if_exist:
419
+ assert key not in _environment_functions, f'{key} has been registered.'
420
+ _environment_functions[key] = behavior
421
+
422
+
375
423
  set(dt=0.1, precision=32, mode=Mode(), mem_scaling=IdMemScaling())
424
+
@@ -1,14 +1,14 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.0.2.post20240824
3
+ Version: 0.0.2.post20240825
4
4
  Summary: A State-based Transformation System for Brain Dynamics Programming.
5
5
  Home-page: https://github.com/brainpy/brainstate
6
6
  Author: BDP
7
- Author-email: BrainPy Team <chao.brain@qq.com>
7
+ Author-email: BrainState Developers <chao.brain@qq.com>
8
8
  License: Apache-2.0 license
9
9
  Project-URL: homepage, http://github.com/brainpy
10
10
  Project-URL: repository, http://github.com/brainpy/brainstate
11
- Keywords: brainpy,brain simulation,brain-inspired computing
11
+ Keywords: computational neuroscience,brain-inspired computation,brain dynamics programming
12
12
  Classifier: Natural Language :: English
13
13
  Classifier: Operating System :: OS Independent
14
14
  Classifier: Development Status :: 4 - Beta
@@ -31,7 +31,7 @@ License-File: LICENSE
31
31
  Requires-Dist: jax
32
32
  Requires-Dist: jaxlib
33
33
  Requires-Dist: numpy
34
- Requires-Dist: brainunit
34
+ Requires-Dist: brainunit >=0.0.2
35
35
  Provides-Extra: cpu
36
36
  Requires-Dist: jaxlib ; extra == 'cpu'
37
37
  Provides-Extra: cuda12
@@ -5,7 +5,7 @@ brainstate/_random_for_unit.py,sha256=1rHr7gfH_bYrJfpxbDhQUk_j00Yosx-GzyZCXrLxsd
5
5
  brainstate/_state.py,sha256=C0widCOj_ca6zfqh95jzFXf_G5vi0hJyuQ5GIqEqOUs,12102
6
6
  brainstate/_state_test.py,sha256=HDdipndRLhEHWEdTmyT1ayEBkbv6qJKykfCWKI6yJ_E,1253
7
7
  brainstate/_utils.py,sha256=RLorgGJkt2BhbX4C-ygd-PPG0wfcGCghjSP93sRvzqM,833
8
- brainstate/environ.py,sha256=LwRwnFaTbv8l7nHRIbSV46WzcN7pGLQFhT_xDUox2yA,10240
8
+ brainstate/environ.py,sha256=k0p1oyi9jbsPfuvqrPL-_zgSd7VW3LRs0LboxlaaIfc,11806
9
9
  brainstate/mixin.py,sha256=OumTTSVyYSbtudjfS_MRThsBaeVJ_0JggeMClY7xtBA,10758
10
10
  brainstate/mixin_test.py,sha256=-Ej9oUOu8O1M4oy37SVMj7xNRYhHHyAHwrjS_aISayo,2923
11
11
  brainstate/random.py,sha256=BqEBYVD9TGe8dSzp8U0suK0O4r6Ox59GCq0mwfUndVQ,188073
@@ -59,8 +59,8 @@ brainstate/transform/_jit_test.py,sha256=5ltT7izh_OS9dcHnRymmVhq01QomjwZGdA8XzwJ
59
59
  brainstate/transform/_make_jaxpr.py,sha256=ZkrOZu4_0xcILuPUA3RFEkorJ-xbDuDtXorJI_qVThE,30450
60
60
  brainstate/transform/_make_jaxpr_test.py,sha256=K3vRUBroDTCCx0lnmhgHtgrlWvWglJO2f1K2phTvU70,3819
61
61
  brainstate/transform/_progress_bar.py,sha256=VGoRZPRBmB8ELNwLc6c7S8QhUUTvn0FY46IbBm9cuYM,3502
62
- brainstate-0.0.2.post20240824.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
63
- brainstate-0.0.2.post20240824.dist-info/METADATA,sha256=XmDSiVoXh250MvzBm2tJGfcNpRQ4FUrQTJsCvphGSSA,3801
64
- brainstate-0.0.2.post20240824.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
65
- brainstate-0.0.2.post20240824.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
66
- brainstate-0.0.2.post20240824.dist-info/RECORD,,
62
+ brainstate-0.0.2.post20240825.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
63
+ brainstate-0.0.2.post20240825.dist-info/METADATA,sha256=COfpxoCL7w1xGa1OFYFeANFLhAKmSioWVtmF_i2st34,3849
64
+ brainstate-0.0.2.post20240825.dist-info/WHEEL,sha256=GUeE9LxUgRABPG7YM0jCNs9cBsAIx0YAkzCB88PMLgc,109
65
+ brainstate-0.0.2.post20240825.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
66
+ brainstate-0.0.2.post20240825.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.38.4)
2
+ Generator: setuptools (73.0.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py2-none-any
5
5
  Tag: py3-none-any