brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/environ.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -13,7 +13,50 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
|
16
|
+
"""
|
17
|
+
Environment configuration and context management for BrainState.
|
18
|
+
|
19
|
+
This module provides comprehensive functionality for managing computational
|
20
|
+
environments, including platform selection, precision control, mode setting,
|
21
|
+
and context-based configuration management. It enables flexible configuration
|
22
|
+
of JAX-based computations with thread-safe context switching.
|
23
|
+
|
24
|
+
The module supports:
|
25
|
+
- Platform configuration (CPU, GPU, TPU)
|
26
|
+
- Precision control (8, 16, 32, 64 bit and bfloat16)
|
27
|
+
- Computation mode management
|
28
|
+
- Context-based temporary settings
|
29
|
+
- Default data type management
|
30
|
+
- Custom behavior registration
|
31
|
+
|
32
|
+
Examples
|
33
|
+
--------
|
34
|
+
Global environment configuration:
|
35
|
+
|
36
|
+
.. code-block:: python
|
37
|
+
|
38
|
+
>>> import brainstate as bs
|
39
|
+
>>> import brainstate.environ as env
|
40
|
+
>>>
|
41
|
+
>>> # Set global precision to 32-bit
|
42
|
+
>>> env.set(precision=32, dt=0.01, mode=bs.mixin.Training())
|
43
|
+
>>>
|
44
|
+
>>> # Get current settings
|
45
|
+
>>> print(env.get('precision')) # 32
|
46
|
+
>>> print(env.get('dt')) # 0.01
|
47
|
+
|
48
|
+
Context-based temporary settings:
|
49
|
+
|
50
|
+
.. code-block:: python
|
51
|
+
|
52
|
+
>>> import brainstate.environ as env
|
53
|
+
>>>
|
54
|
+
>>> # Temporarily change precision
|
55
|
+
>>> with env.context(precision=64, dt=0.001):
|
56
|
+
... high_precision_result = compute_something()
|
57
|
+
... print(env.get('precision')) # 64
|
58
|
+
>>> print(env.get('precision')) # Back to 32
|
59
|
+
"""
|
17
60
|
|
18
61
|
import contextlib
|
19
62
|
import dataclasses
|
@@ -21,543 +64,1432 @@ import functools
|
|
21
64
|
import os
|
22
65
|
import re
|
23
66
|
import threading
|
67
|
+
import warnings
|
24
68
|
from collections import defaultdict
|
25
|
-
from typing import Any, Callable, Dict, Hashable
|
69
|
+
from typing import Any, Callable, Dict, Hashable, Optional, Union, ContextManager, List
|
26
70
|
|
71
|
+
import brainunit as u
|
27
72
|
import numpy as np
|
28
73
|
from jax import config, devices, numpy as jnp
|
29
74
|
from jax.typing import DTypeLike
|
30
75
|
|
31
|
-
from .mixin import Mode
|
32
|
-
|
33
76
|
__all__ = [
|
34
|
-
#
|
35
|
-
'set',
|
36
|
-
|
37
|
-
'
|
38
|
-
|
39
|
-
'
|
40
|
-
|
41
|
-
|
77
|
+
# Core environment management
|
78
|
+
'set',
|
79
|
+
'get',
|
80
|
+
'all',
|
81
|
+
'pop',
|
82
|
+
'context',
|
83
|
+
'reset',
|
84
|
+
|
85
|
+
# Platform and device management
|
86
|
+
'set_platform',
|
87
|
+
'get_platform',
|
88
|
+
'set_host_device_count',
|
89
|
+
'get_host_device_count',
|
90
|
+
|
91
|
+
# Precision and data type management
|
92
|
+
'set_precision',
|
93
|
+
'get_precision',
|
94
|
+
'dftype',
|
95
|
+
'ditype',
|
96
|
+
'dutype',
|
97
|
+
'dctype',
|
98
|
+
|
99
|
+
# Mode and computation settings
|
100
|
+
'get_dt',
|
101
|
+
|
102
|
+
# Utility functions
|
103
|
+
'tolerance',
|
104
|
+
'register_default_behavior',
|
105
|
+
'unregister_default_behavior',
|
106
|
+
'list_registered_behaviors',
|
107
|
+
|
108
|
+
# Constants
|
109
|
+
'DEFAULT_PRECISION',
|
110
|
+
'SUPPORTED_PLATFORMS',
|
111
|
+
'SUPPORTED_PRECISIONS',
|
112
|
+
|
113
|
+
# Names
|
114
|
+
'I',
|
115
|
+
'T',
|
116
|
+
'DT',
|
117
|
+
'PRECISION',
|
118
|
+
'PLATFORM',
|
119
|
+
'HOST_DEVICE_COUNT',
|
120
|
+
'JIT_ERROR_CHECK',
|
121
|
+
'FIT',
|
42
122
|
]
|
43
123
|
|
44
|
-
#
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
FIT = 'fit' # whether to fit the model.
|
124
|
+
# Type definitions
|
125
|
+
# T = TypeVar('T')
|
126
|
+
PrecisionType = Union[int, str]
|
127
|
+
PlatformType = str
|
49
128
|
|
129
|
+
# Constants for environment keys
|
130
|
+
I = 'i' # Index of the current computation
|
131
|
+
T = 't' # Current time of the computation
|
132
|
+
DT = 'dt' # Time step for numerical integration
|
133
|
+
PRECISION = 'precision' # Numerical precision
|
134
|
+
PLATFORM = 'platform' # Computing platform
|
135
|
+
HOST_DEVICE_COUNT = 'host_device_count' # Number of host devices
|
136
|
+
JIT_ERROR_CHECK = 'jit_error_check' # JIT error checking flag
|
137
|
+
FIT = 'fit' # Model fitting flag
|
50
138
|
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
# current environment settings
|
56
|
-
contexts: defaultdict[Hashable, Any] = dataclasses.field(default_factory=lambda: defaultdict(list))
|
57
|
-
# environment functions
|
58
|
-
functions: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
|
139
|
+
# Default values
|
140
|
+
DEFAULT_PRECISION = 32
|
141
|
+
SUPPORTED_PLATFORMS = ('cpu', 'gpu', 'tpu')
|
142
|
+
SUPPORTED_PRECISIONS = (8, 16, 32, 64, 'bf16')
|
59
143
|
|
144
|
+
# Sentinel value for missing arguments
|
145
|
+
_NOT_PROVIDED = object()
|
60
146
|
|
61
|
-
DFAULT = DefaultContext()
|
62
|
-
_NOT_PROVIDE = object()
|
63
147
|
|
148
|
+
@dataclasses.dataclass
|
149
|
+
class EnvironmentState(threading.local):
|
150
|
+
"""
|
151
|
+
Thread-local storage for environment configuration.
|
64
152
|
|
65
|
-
|
66
|
-
|
67
|
-
r"""
|
68
|
-
Context-manager that sets a computing environment for brain dynamics computation.
|
153
|
+
This class maintains separate configuration states for different threads,
|
154
|
+
ensuring thread-safe environment management in concurrent applications.
|
69
155
|
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
156
|
+
Attributes
|
157
|
+
----------
|
158
|
+
settings : Dict[Hashable, Any]
|
159
|
+
Global default environment settings.
|
160
|
+
contexts : defaultdict[Hashable, List[Any]]
|
161
|
+
Stack of context-specific settings for nested contexts.
|
162
|
+
functions : Dict[Hashable, Callable]
|
163
|
+
Registered callback functions for environment changes.
|
164
|
+
locks : Dict[str, threading.Lock]
|
165
|
+
Thread locks for synchronized access to critical sections.
|
166
|
+
"""
|
167
|
+
settings: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
|
168
|
+
contexts: defaultdict[Hashable, List[Any]] = dataclasses.field(default_factory=lambda: defaultdict(list))
|
169
|
+
functions: Dict[Hashable, Callable] = dataclasses.field(default_factory=dict)
|
170
|
+
locks: Dict[str, threading.Lock] = dataclasses.field(default_factory=lambda: defaultdict(threading.Lock))
|
76
171
|
|
77
|
-
|
172
|
+
def __post_init__(self):
|
173
|
+
"""Initialize with default settings."""
|
174
|
+
# Set default precision if not already set
|
175
|
+
if PRECISION not in self.settings:
|
176
|
+
self.settings[PRECISION] = DEFAULT_PRECISION
|
78
177
|
|
79
|
-
>>> import brainstate as brainstate
|
80
|
-
>>> with brainstate.environ.context(dt=0.1) as env:
|
81
|
-
... dt = brainstate.environ.get('dt')
|
82
|
-
... print(env)
|
83
178
|
|
84
|
-
|
85
|
-
|
86
|
-
raise ValueError('\n'
|
87
|
-
'Cannot set platform in "context" environment. \n'
|
88
|
-
'You should set platform in the global environment by "set_platform()" or "set()".')
|
89
|
-
if 'host_device_count' in kwargs:
|
90
|
-
raise ValueError('Cannot set host_device_count in environment context. '
|
91
|
-
'Please use set_host_device_count() or set() for the global setting.')
|
179
|
+
# Global environment state
|
180
|
+
_ENV_STATE = EnvironmentState()
|
92
181
|
|
93
|
-
if 'precision' in kwargs:
|
94
|
-
last_precision = _get_precision()
|
95
|
-
_set_jax_precision(kwargs['precision'])
|
96
182
|
|
97
|
-
|
98
|
-
|
183
|
+
def reset() -> None:
|
184
|
+
"""
|
185
|
+
Reset the environment to default settings.
|
186
|
+
|
187
|
+
This function clears all custom settings and restores the environment
|
188
|
+
to its initial state. Useful for testing or when starting fresh.
|
189
|
+
|
190
|
+
Examples
|
191
|
+
--------
|
192
|
+
.. code-block:: python
|
193
|
+
|
194
|
+
>>> import brainstate.environ as env
|
195
|
+
>>>
|
196
|
+
>>> # Set custom values
|
197
|
+
>>> env.set(dt=0.1, custom_param='value')
|
198
|
+
>>> print(env.get('custom_param')) # 'value'
|
199
|
+
>>>
|
200
|
+
>>> # Reset to defaults
|
201
|
+
>>> env.reset()
|
202
|
+
>>> print(env.get('custom_param', default=None)) # None
|
203
|
+
|
204
|
+
Notes
|
205
|
+
-----
|
206
|
+
This operation cannot be undone. All custom settings will be lost.
|
207
|
+
"""
|
208
|
+
global _ENV_STATE
|
209
|
+
_ENV_STATE = EnvironmentState()
|
210
|
+
# Re-apply default precision
|
211
|
+
_set_jax_precision(DEFAULT_PRECISION)
|
99
212
|
|
100
|
-
|
101
|
-
|
213
|
+
warnings.warn(
|
214
|
+
"Environment has been reset to default settings. "
|
215
|
+
"All custom configurations have been cleared.",
|
216
|
+
UserWarning
|
217
|
+
)
|
102
218
|
|
103
|
-
# restore the environment functions
|
104
|
-
if k in DFAULT.functions:
|
105
|
-
DFAULT.functions[k](v)
|
106
219
|
|
107
|
-
|
108
|
-
|
109
|
-
|
220
|
+
@contextlib.contextmanager
|
221
|
+
def context(**kwargs) -> ContextManager[Dict[str, Any]]:
|
222
|
+
"""
|
223
|
+
Context manager for temporary environment settings.
|
110
224
|
|
111
|
-
|
225
|
+
This context manager allows you to temporarily modify environment settings
|
226
|
+
within a specific scope. Settings are automatically restored when exiting
|
227
|
+
the context, even if an exception occurs.
|
112
228
|
|
113
|
-
|
114
|
-
|
229
|
+
Parameters
|
230
|
+
----------
|
231
|
+
**kwargs
|
232
|
+
Environment settings to apply within the context.
|
233
|
+
Common parameters include:
|
234
|
+
|
235
|
+
- precision : int or str.
|
236
|
+
Numerical precision (8, 16, 32, 64, or 'bf16')
|
237
|
+
- dt : float.
|
238
|
+
Time step for numerical integration
|
239
|
+
- mode : :class:`Mode`.
|
240
|
+
Computation mode instance
|
241
|
+
- Any custom parameters registered via register_default_behavior
|
242
|
+
|
243
|
+
Yields
|
244
|
+
------
|
245
|
+
dict
|
246
|
+
Current environment settings within the context.
|
247
|
+
|
248
|
+
Raises
|
249
|
+
------
|
250
|
+
ValueError
|
251
|
+
If attempting to set platform or host_device_count in context
|
252
|
+
(these must be set globally).
|
253
|
+
TypeError
|
254
|
+
If invalid parameter types are provided.
|
255
|
+
|
256
|
+
Examples
|
257
|
+
--------
|
258
|
+
Basic usage with precision control:
|
259
|
+
|
260
|
+
.. code-block:: python
|
261
|
+
|
262
|
+
>>> import brainstate.environ as env
|
263
|
+
>>>
|
264
|
+
>>> # Set global precision
|
265
|
+
>>> env.set(precision=32)
|
266
|
+
>>>
|
267
|
+
>>> # Temporarily use higher precision
|
268
|
+
>>> with env.context(precision=64) as ctx:
|
269
|
+
... print(f"Precision in context: {env.get('precision')}") # 64
|
270
|
+
... print(f"Float type: {env.dftype()}") # float64
|
271
|
+
>>>
|
272
|
+
>>> print(f"Precision after context: {env.get('precision')}") # 32
|
273
|
+
|
274
|
+
Nested contexts:
|
275
|
+
|
276
|
+
.. code-block:: python
|
277
|
+
|
278
|
+
>>> import brainstate.environ as env
|
279
|
+
>>>
|
280
|
+
>>> with env.context(dt=0.1) as ctx1:
|
281
|
+
... print(f"dt = {env.get('dt')}") # 0.1
|
282
|
+
...
|
283
|
+
... with env.context(dt=0.01) as ctx2:
|
284
|
+
... print(f"dt = {env.get('dt')}") # 0.01
|
285
|
+
...
|
286
|
+
... print(f"dt = {env.get('dt')}") # 0.1
|
287
|
+
|
288
|
+
Error handling in context:
|
289
|
+
|
290
|
+
.. code-block:: python
|
291
|
+
|
292
|
+
>>> import brainstate.environ as env
|
293
|
+
>>>
|
294
|
+
>>> env.set(value=10)
|
295
|
+
>>> try:
|
296
|
+
... with env.context(value=20):
|
297
|
+
... print(env.get('value')) # 20
|
298
|
+
... raise ValueError("Something went wrong")
|
299
|
+
... except ValueError:
|
300
|
+
... pass
|
301
|
+
>>>
|
302
|
+
>>> print(env.get('value')) # 10 (restored)
|
303
|
+
|
304
|
+
Notes
|
305
|
+
-----
|
306
|
+
- Platform and host_device_count cannot be set in context
|
307
|
+
- Contexts can be nested arbitrarily deep
|
308
|
+
- Settings are restored in reverse order when exiting
|
309
|
+
- Thread-safe: each thread maintains its own context stack
|
310
|
+
"""
|
311
|
+
# Validate restricted parameters
|
312
|
+
if PLATFORM in kwargs:
|
313
|
+
raise ValueError(
|
314
|
+
f"Cannot set '{PLATFORM}' in context. "
|
315
|
+
f"Use set_platform() or set() for global configuration."
|
316
|
+
)
|
317
|
+
if HOST_DEVICE_COUNT in kwargs:
|
318
|
+
raise ValueError(
|
319
|
+
f"Cannot set '{HOST_DEVICE_COUNT}' in context. "
|
320
|
+
f"Use set_host_device_count() or set() for global configuration."
|
321
|
+
)
|
322
|
+
|
323
|
+
# Handle precision changes
|
324
|
+
original_precision = None
|
325
|
+
if PRECISION in kwargs:
|
326
|
+
original_precision = _get_precision()
|
327
|
+
_validate_precision(kwargs[PRECISION])
|
328
|
+
_set_jax_precision(kwargs[PRECISION])
|
115
329
|
|
116
|
-
|
117
|
-
|
118
|
-
|
330
|
+
try:
|
331
|
+
# Push new values onto context stacks
|
332
|
+
for key, value in kwargs.items():
|
333
|
+
with _ENV_STATE.locks[key]:
|
334
|
+
_ENV_STATE.contexts[key].append(value)
|
335
|
+
|
336
|
+
# Trigger registered callbacks
|
337
|
+
if key in _ENV_STATE.functions:
|
338
|
+
try:
|
339
|
+
_ENV_STATE.functions[key](value)
|
340
|
+
except Exception as e:
|
341
|
+
warnings.warn(
|
342
|
+
f"Callback for '{key}' raised an exception: {e}",
|
343
|
+
RuntimeWarning
|
344
|
+
)
|
345
|
+
|
346
|
+
# Yield current environment state
|
347
|
+
yield all()
|
119
348
|
|
120
|
-
|
121
|
-
|
349
|
+
finally:
|
350
|
+
# Restore previous values
|
351
|
+
for key in kwargs:
|
352
|
+
with _ENV_STATE.locks[key]:
|
353
|
+
if _ENV_STATE.contexts[key]:
|
354
|
+
_ENV_STATE.contexts[key].pop()
|
355
|
+
|
356
|
+
# Restore callbacks with previous value
|
357
|
+
if key in _ENV_STATE.functions:
|
358
|
+
try:
|
359
|
+
prev_value = get(key, default=None)
|
360
|
+
if prev_value is not None:
|
361
|
+
_ENV_STATE.functions[key](prev_value)
|
362
|
+
except Exception as e:
|
363
|
+
warnings.warn(
|
364
|
+
f"Callback restoration for '{key}' raised: {e}",
|
365
|
+
RuntimeWarning
|
366
|
+
)
|
367
|
+
|
368
|
+
# Restore precision if it was changed
|
369
|
+
if original_precision is not None:
|
370
|
+
_set_jax_precision(original_precision)
|
371
|
+
|
372
|
+
|
373
|
+
def get(key: str, default: Any = _NOT_PROVIDED, desc: Optional[str] = None) -> Any:
|
374
|
+
"""
|
375
|
+
Get a value from the current environment.
|
122
376
|
|
377
|
+
This function retrieves values from the environment, checking first in
|
378
|
+
the context stack, then in global settings. Special handling is provided
|
379
|
+
for platform and device count parameters.
|
123
380
|
|
124
|
-
|
125
|
-
|
126
|
-
|
381
|
+
Parameters
|
382
|
+
----------
|
383
|
+
key : str
|
384
|
+
The environment key to retrieve.
|
385
|
+
default : Any, optional
|
386
|
+
Default value to return if key is not found.
|
387
|
+
If not provided, raises KeyError for missing keys.
|
388
|
+
desc : str, optional
|
389
|
+
Description of the parameter for error messages.
|
127
390
|
|
128
391
|
Returns
|
129
392
|
-------
|
130
|
-
|
131
|
-
|
393
|
+
Any
|
394
|
+
The value associated with the key.
|
395
|
+
|
396
|
+
Raises
|
397
|
+
------
|
398
|
+
KeyError
|
399
|
+
If key is not found and no default is provided.
|
400
|
+
|
401
|
+
Examples
|
402
|
+
--------
|
403
|
+
Basic retrieval:
|
404
|
+
|
405
|
+
.. code-block:: python
|
406
|
+
|
407
|
+
>>> import brainstate.environ as env
|
408
|
+
>>>
|
409
|
+
>>> env.set(learning_rate=0.001)
|
410
|
+
>>> lr = env.get('learning_rate')
|
411
|
+
>>> print(f"Learning rate: {lr}") # 0.001
|
412
|
+
|
413
|
+
With default value:
|
414
|
+
|
415
|
+
.. code-block:: python
|
416
|
+
|
417
|
+
>>> import brainstate.environ as env
|
418
|
+
>>>
|
419
|
+
>>> # Get with default
|
420
|
+
>>> batch_size = env.get('batch_size', default=32)
|
421
|
+
>>> print(f"Batch size: {batch_size}") # 32
|
422
|
+
|
423
|
+
Context-aware retrieval:
|
424
|
+
|
425
|
+
.. code-block:: python
|
426
|
+
|
427
|
+
>>> import brainstate.environ as env
|
428
|
+
>>>
|
429
|
+
>>> env.set(temperature=1.0)
|
430
|
+
>>> print(env.get('temperature')) # 1.0
|
431
|
+
>>>
|
432
|
+
>>> with env.context(temperature=0.5):
|
433
|
+
... print(env.get('temperature')) # 0.5
|
434
|
+
>>>
|
435
|
+
>>> print(env.get('temperature')) # 1.0
|
436
|
+
|
437
|
+
Notes
|
438
|
+
-----
|
439
|
+
Special keys 'platform' and 'host_device_count' are handled separately
|
440
|
+
and retrieve system-level information.
|
132
441
|
"""
|
133
|
-
|
442
|
+
# Special cases for platform-specific parameters
|
443
|
+
if key == PLATFORM:
|
134
444
|
return get_platform()
|
135
|
-
|
136
|
-
if key == 'host_device_count':
|
445
|
+
if key == HOST_DEVICE_COUNT:
|
137
446
|
return get_host_device_count()
|
138
447
|
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
if
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
448
|
+
# Check context stack first (most recent value)
|
449
|
+
with _ENV_STATE.locks[key]:
|
450
|
+
if key in _ENV_STATE.contexts and _ENV_STATE.contexts[key]:
|
451
|
+
return _ENV_STATE.contexts[key][-1]
|
452
|
+
|
453
|
+
# Check global settings
|
454
|
+
if key in _ENV_STATE.settings:
|
455
|
+
return _ENV_STATE.settings[key]
|
456
|
+
|
457
|
+
# Handle missing key
|
458
|
+
if default is _NOT_PROVIDED:
|
459
|
+
error_msg = f"Key '{key}' not found in environment."
|
460
|
+
if desc:
|
461
|
+
error_msg += f" Description: {desc}"
|
462
|
+
error_msg += (
|
463
|
+
f"\nSet it using:\n"
|
464
|
+
f" - env.set({key}=value) for global setting\n"
|
465
|
+
f" - env.context({key}=value) for temporary setting"
|
466
|
+
)
|
467
|
+
raise KeyError(error_msg)
|
468
|
+
|
159
469
|
return default
|
160
470
|
|
161
471
|
|
162
|
-
def all() ->
|
472
|
+
def all() -> Dict[str, Any]:
|
163
473
|
"""
|
164
|
-
Get all
|
474
|
+
Get all current environment settings.
|
475
|
+
|
476
|
+
This function returns a dictionary containing all active environment
|
477
|
+
settings, with context values taking precedence over global settings.
|
165
478
|
|
166
479
|
Returns
|
167
480
|
-------
|
168
|
-
|
169
|
-
|
481
|
+
dict
|
482
|
+
Dictionary of all current environment settings.
|
483
|
+
|
484
|
+
Examples
|
485
|
+
--------
|
486
|
+
.. code-block:: python
|
487
|
+
|
488
|
+
>>> import brainstate.environ as env
|
489
|
+
>>>
|
490
|
+
>>> # Set various parameters
|
491
|
+
>>> env.set(precision=32, dt=0.01, debug=True)
|
492
|
+
>>>
|
493
|
+
>>> # Get all settings
|
494
|
+
>>> settings = env.all()
|
495
|
+
>>> print(settings)
|
496
|
+
{'precision': 32, 'dt': 0.01, 'debug': True}
|
497
|
+
|
498
|
+
>>> # Context overrides
|
499
|
+
>>> with env.context(precision=64, new_param='test'):
|
500
|
+
... settings = env.all()
|
501
|
+
... print(settings['precision']) # 64
|
502
|
+
... print(settings['new_param']) # 'test'
|
503
|
+
|
504
|
+
Notes
|
505
|
+
-----
|
506
|
+
The returned dictionary is a snapshot and modifying it does not
|
507
|
+
affect the environment settings.
|
508
|
+
"""
|
509
|
+
result = {}
|
510
|
+
|
511
|
+
# Add global settings
|
512
|
+
result.update(_ENV_STATE.settings)
|
513
|
+
|
514
|
+
# Override with context values (most recent)
|
515
|
+
for key, values in _ENV_STATE.contexts.items():
|
516
|
+
if values:
|
517
|
+
result[key] = values[-1]
|
518
|
+
|
519
|
+
return result
|
520
|
+
|
521
|
+
|
522
|
+
def pop(key: str, default: Any = _NOT_PROVIDED) -> Any:
|
170
523
|
"""
|
171
|
-
|
172
|
-
for k, v in DFAULT.contexts.items():
|
173
|
-
if v:
|
174
|
-
r[k] = v[-1]
|
175
|
-
for k, v in DFAULT.settings.items():
|
176
|
-
if k not in r:
|
177
|
-
r[k] = v
|
178
|
-
return r
|
524
|
+
Remove and return a value from the global environment.
|
179
525
|
|
526
|
+
This function removes a key from the global environment settings and
|
527
|
+
returns its value. If the key is not found, it returns the default
|
528
|
+
value if provided, or raises KeyError.
|
180
529
|
|
181
|
-
|
182
|
-
|
530
|
+
Note that this function only affects global settings, not context values.
|
531
|
+
Keys in active contexts are not affected.
|
532
|
+
|
533
|
+
Parameters
|
534
|
+
----------
|
535
|
+
key : str
|
536
|
+
The environment key to remove.
|
537
|
+
default : Any, optional
|
538
|
+
Default value to return if key is not found.
|
539
|
+
If not provided, raises KeyError for missing keys.
|
183
540
|
|
184
541
|
Returns
|
185
542
|
-------
|
186
|
-
|
187
|
-
|
543
|
+
Any
|
544
|
+
The value that was removed from the environment.
|
545
|
+
|
546
|
+
Raises
|
547
|
+
------
|
548
|
+
KeyError
|
549
|
+
If key is not found and no default is provided.
|
550
|
+
ValueError
|
551
|
+
If attempting to pop a key that is currently in a context.
|
552
|
+
|
553
|
+
Examples
|
554
|
+
--------
|
555
|
+
Basic usage:
|
556
|
+
|
557
|
+
.. code-block:: python
|
558
|
+
|
559
|
+
>>> import brainstate.environ as env
|
560
|
+
>>>
|
561
|
+
>>> # Set a value
|
562
|
+
>>> env.set(temp_param='temporary')
|
563
|
+
>>> print(env.get('temp_param')) # 'temporary'
|
564
|
+
>>>
|
565
|
+
>>> # Pop the value
|
566
|
+
>>> value = env.pop('temp_param')
|
567
|
+
>>> print(value) # 'temporary'
|
568
|
+
>>>
|
569
|
+
>>> # Value is now gone
|
570
|
+
>>> env.get('temp_param', default=None) # None
|
571
|
+
|
572
|
+
With default value:
|
573
|
+
|
574
|
+
.. code-block:: python
|
575
|
+
|
576
|
+
>>> import brainstate.environ as env
|
577
|
+
>>>
|
578
|
+
>>> # Pop non-existent key with default
|
579
|
+
>>> value = env.pop('nonexistent', default='default_value')
|
580
|
+
>>> print(value) # 'default_value'
|
581
|
+
|
582
|
+
Pop multiple values:
|
583
|
+
|
584
|
+
.. code-block:: python
|
585
|
+
|
586
|
+
>>> import brainstate.environ as env
|
587
|
+
>>>
|
588
|
+
>>> # Set multiple values
|
589
|
+
>>> env.set(param1='value1', param2='value2', param3='value3')
|
590
|
+
>>>
|
591
|
+
>>> # Pop them one by one
|
592
|
+
>>> v1 = env.pop('param1')
|
593
|
+
>>> v2 = env.pop('param2')
|
594
|
+
>>>
|
595
|
+
>>> # param3 still exists
|
596
|
+
>>> print(env.get('param3')) # 'value3'
|
597
|
+
|
598
|
+
Context protection:
|
599
|
+
|
600
|
+
.. code-block:: python
|
601
|
+
|
602
|
+
>>> import brainstate.environ as env
|
603
|
+
>>>
|
604
|
+
>>> env.set(protected='global_value')
|
605
|
+
>>>
|
606
|
+
>>> with env.context(protected='context_value'):
|
607
|
+
... # Cannot pop a key that's in active context
|
608
|
+
... try:
|
609
|
+
... env.pop('protected')
|
610
|
+
... except ValueError as e:
|
611
|
+
... print("Cannot pop key in active context")
|
612
|
+
|
613
|
+
Notes
|
614
|
+
-----
|
615
|
+
- This function only removes keys from global settings
|
616
|
+
- Keys that are currently overridden in active contexts cannot be popped
|
617
|
+
- Special keys like 'platform' and 'host_device_count' can be popped but
|
618
|
+
their system-level values remain accessible through get_platform() etc.
|
619
|
+
- Registered callbacks are NOT triggered when popping values
|
188
620
|
"""
|
189
|
-
|
621
|
+
# Check if key is currently in any active context
|
622
|
+
if key in _ENV_STATE.contexts and _ENV_STATE.contexts[key]:
|
623
|
+
raise ValueError(
|
624
|
+
f"Cannot pop key '{key}' while it is active in a context. "
|
625
|
+
f"The key is currently overridden in {len(_ENV_STATE.contexts[key])} context(s)."
|
626
|
+
)
|
627
|
+
|
628
|
+
# Check if key exists in global settings
|
629
|
+
if key in _ENV_STATE.settings:
|
630
|
+
# Remove and return the value
|
631
|
+
value = _ENV_STATE.settings.pop(key)
|
632
|
+
|
633
|
+
# Note: We don't trigger callbacks here as this is a removal operation
|
634
|
+
# If needed, users can register callbacks for removal separately
|
635
|
+
|
636
|
+
return value
|
637
|
+
|
638
|
+
# Key not found, handle default
|
639
|
+
if default is _NOT_PROVIDED:
|
640
|
+
raise KeyError(f"Key '{key}' not found in global environment settings.")
|
641
|
+
|
642
|
+
return default
|
643
|
+
|
190
644
|
|
645
|
+
def set(
|
646
|
+
platform: Optional[PlatformType] = None,
|
647
|
+
host_device_count: Optional[int] = None,
|
648
|
+
precision: Optional[PrecisionType] = None,
|
649
|
+
dt: Optional[float] = None,
|
650
|
+
**kwargs
|
651
|
+
) -> None:
|
652
|
+
"""
|
653
|
+
Set global environment configuration.
|
191
654
|
|
192
|
-
|
193
|
-
|
655
|
+
This function sets persistent global environment settings that remain
|
656
|
+
active until explicitly changed or the program terminates.
|
194
657
|
|
195
|
-
|
658
|
+
Parameters
|
196
659
|
----------
|
197
|
-
|
198
|
-
|
660
|
+
platform : str, optional
|
661
|
+
Computing platform ('cpu', 'gpu', or 'tpu').
|
662
|
+
host_device_count : int, optional
|
663
|
+
Number of host devices for parallel computation.
|
664
|
+
precision : int or str, optional
|
665
|
+
Numerical precision (8, 16, 32, 64, or 'bf16').
|
666
|
+
mode : Mode, optional
|
667
|
+
Computation mode instance.
|
668
|
+
dt : float, optional
|
669
|
+
Time step for numerical integration.
|
670
|
+
**kwargs
|
671
|
+
Additional custom environment parameters.
|
672
|
+
|
673
|
+
Raises
|
674
|
+
------
|
675
|
+
ValueError
|
676
|
+
If invalid platform or precision is specified.
|
677
|
+
TypeError
|
678
|
+
If mode is not a Mode instance.
|
679
|
+
|
680
|
+
Examples
|
681
|
+
--------
|
682
|
+
Basic configuration:
|
683
|
+
|
684
|
+
.. code-block:: python
|
685
|
+
|
686
|
+
>>> import brainstate as bs
|
687
|
+
>>> import brainstate.environ as env
|
688
|
+
>>>
|
689
|
+
>>> # Set multiple parameters
|
690
|
+
>>> env.set(
|
691
|
+
... precision=32,
|
692
|
+
... dt=0.01,
|
693
|
+
... mode=bs.mixin.Training(),
|
694
|
+
... debug=False
|
695
|
+
... )
|
696
|
+
>>>
|
697
|
+
>>> print(env.get('precision')) # 32
|
698
|
+
>>> print(env.get('dt')) # 0.01
|
699
|
+
|
700
|
+
Platform configuration:
|
701
|
+
|
702
|
+
.. code-block:: python
|
703
|
+
|
704
|
+
>>> import brainstate.environ as env
|
705
|
+
>>>
|
706
|
+
>>> # Configure for GPU computation
|
707
|
+
>>> env.set(platform='gpu', precision=16)
|
708
|
+
>>>
|
709
|
+
>>> # Configure for multi-core CPU
|
710
|
+
>>> env.set(platform='cpu', host_device_count=4)
|
711
|
+
|
712
|
+
Custom parameters:
|
713
|
+
|
714
|
+
.. code-block:: python
|
715
|
+
|
716
|
+
>>> import brainstate.environ as env
|
717
|
+
>>>
|
718
|
+
>>> # Set custom parameters
|
719
|
+
>>> env.set(
|
720
|
+
... experiment_name='test_001',
|
721
|
+
... random_seed=42,
|
722
|
+
... log_level='DEBUG'
|
723
|
+
... )
|
724
|
+
>>>
|
725
|
+
>>> # Retrieve custom parameters
|
726
|
+
>>> print(env.get('experiment_name')) # 'test_001'
|
727
|
+
|
728
|
+
Notes
|
729
|
+
-----
|
730
|
+
- Platform changes only take effect at program start
|
731
|
+
- Some JAX configurations require restart to take effect
|
732
|
+
- Custom parameters can be any hashable key-value pairs
|
199
733
|
"""
|
200
|
-
|
734
|
+
# Handle special parameters
|
735
|
+
if platform is not None:
|
736
|
+
set_platform(platform)
|
201
737
|
|
738
|
+
if host_device_count is not None:
|
739
|
+
set_host_device_count(host_device_count)
|
202
740
|
|
203
|
-
|
204
|
-
|
741
|
+
if precision is not None:
|
742
|
+
_validate_precision(precision)
|
743
|
+
_set_jax_precision(precision)
|
744
|
+
kwargs[PRECISION] = precision
|
205
745
|
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
"""
|
211
|
-
return devices()[0].platform
|
746
|
+
if dt is not None:
|
747
|
+
if not u.math.isscalar(dt):
|
748
|
+
raise TypeError(f"'{DT}' must be a scalar number, got {type(dt)}")
|
749
|
+
kwargs[DT] = dt
|
212
750
|
|
751
|
+
# Update global settings
|
752
|
+
_ENV_STATE.settings.update(kwargs)
|
213
753
|
|
214
|
-
|
754
|
+
# Trigger registered callbacks
|
755
|
+
for key, value in kwargs.items():
|
756
|
+
if key in _ENV_STATE.functions:
|
757
|
+
try:
|
758
|
+
_ENV_STATE.functions[key](value)
|
759
|
+
except Exception as e:
|
760
|
+
warnings.warn(
|
761
|
+
f"Callback for '{key}' raised an exception: {e}",
|
762
|
+
RuntimeWarning
|
763
|
+
)
|
764
|
+
|
765
|
+
|
766
|
+
def get_dt() -> float:
|
215
767
|
"""
|
216
|
-
Get the
|
768
|
+
Get the current numerical integration time step.
|
217
769
|
|
218
770
|
Returns
|
219
771
|
-------
|
220
|
-
|
221
|
-
|
772
|
+
float
|
773
|
+
The time step value.
|
774
|
+
|
775
|
+
Raises
|
776
|
+
------
|
777
|
+
KeyError
|
778
|
+
If dt is not set.
|
779
|
+
|
780
|
+
Examples
|
781
|
+
--------
|
782
|
+
.. code-block:: python
|
783
|
+
|
784
|
+
>>> import brainstate.environ as env
|
785
|
+
>>>
|
786
|
+
>>> env.set(dt=0.01)
|
787
|
+
>>> dt = env.get_dt()
|
788
|
+
>>> print(f"Time step: {dt} ms") # Time step: 0.01 ms
|
789
|
+
>>>
|
790
|
+
>>> # Use in computation
|
791
|
+
>>> with env.context(dt=0.001):
|
792
|
+
... fine_dt = env.get_dt()
|
793
|
+
... print(f"Fine time step: {fine_dt}") # 0.001
|
222
794
|
"""
|
223
|
-
|
224
|
-
match = re.search(r"--xla_force_host_platform_device_count=(\d+)", xla_flags)
|
225
|
-
return int(match.group(1)) if match else 1
|
795
|
+
return get(DT)
|
226
796
|
|
227
797
|
|
228
|
-
def
|
798
|
+
def get_platform() -> PlatformType:
|
229
799
|
"""
|
230
|
-
Get the
|
800
|
+
Get the current computing platform.
|
231
801
|
|
232
802
|
Returns
|
233
803
|
-------
|
234
|
-
|
235
|
-
|
804
|
+
str
|
805
|
+
Platform name ('cpu', 'gpu', or 'tpu').
|
806
|
+
|
807
|
+
Examples
|
808
|
+
--------
|
809
|
+
.. code-block:: python
|
810
|
+
|
811
|
+
>>> import brainstate.environ as env
|
812
|
+
>>>
|
813
|
+
>>> platform = env.get_platform()
|
814
|
+
>>> print(f"Running on: {platform}")
|
815
|
+
>>>
|
816
|
+
>>> if platform == 'gpu':
|
817
|
+
... print("GPU acceleration available")
|
818
|
+
... else:
|
819
|
+
... print(f"Using {platform.upper()}")
|
236
820
|
"""
|
237
|
-
return
|
821
|
+
return devices()[0].platform
|
238
822
|
|
239
823
|
|
240
|
-
def
|
824
|
+
def get_host_device_count() -> int:
|
241
825
|
"""
|
242
|
-
Get the
|
826
|
+
Get the number of host devices.
|
243
827
|
|
244
828
|
Returns
|
245
829
|
-------
|
246
|
-
|
247
|
-
|
830
|
+
int
|
831
|
+
Number of host devices configured.
|
832
|
+
|
833
|
+
Examples
|
834
|
+
--------
|
835
|
+
.. code-block:: python
|
836
|
+
|
837
|
+
>>> import brainstate.environ as env
|
838
|
+
>>>
|
839
|
+
>>> # Get device count
|
840
|
+
>>> n_devices = env.get_host_device_count()
|
841
|
+
>>> print(f"Host devices: {n_devices}")
|
842
|
+
>>>
|
843
|
+
>>> # Configure for parallel computation
|
844
|
+
>>> if n_devices > 1:
|
845
|
+
... print(f"Can use {n_devices} devices for parallel computation")
|
248
846
|
"""
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
if isinstance(precision, int):
|
253
|
-
return precision
|
254
|
-
if isinstance(precision, str):
|
255
|
-
return int(precision)
|
256
|
-
raise ValueError(f'Unsupported precision: {precision}')
|
847
|
+
xla_flags = os.getenv("XLA_FLAGS", "")
|
848
|
+
match = re.search(r"--xla_force_host_platform_device_count=(\d+)", xla_flags)
|
849
|
+
return int(match.group(1)) if match else 1
|
257
850
|
|
258
851
|
|
259
|
-
def
|
260
|
-
platform: str = None,
|
261
|
-
host_device_count: int = None,
|
262
|
-
precision: int | str = None,
|
263
|
-
mode: Mode = None,
|
264
|
-
**kwargs
|
265
|
-
):
|
852
|
+
def set_platform(platform: PlatformType) -> None:
|
266
853
|
"""
|
267
|
-
Set the
|
268
|
-
|
854
|
+
Set the computing platform.
|
269
855
|
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
856
|
+
Parameters
|
857
|
+
----------
|
858
|
+
platform : str
|
859
|
+
Platform to use ('cpu', 'gpu', or 'tpu').
|
860
|
+
|
861
|
+
Raises
|
862
|
+
------
|
863
|
+
ValueError
|
864
|
+
If platform is not supported.
|
865
|
+
|
866
|
+
Examples
|
867
|
+
--------
|
868
|
+
.. code-block:: python
|
869
|
+
|
870
|
+
>>> import brainstate.environ as env
|
871
|
+
>>>
|
872
|
+
>>> # Set to GPU
|
873
|
+
>>> env.set_platform('gpu')
|
874
|
+
>>>
|
875
|
+
>>> # Verify platform
|
876
|
+
>>> print(env.get_platform()) # 'gpu'
|
877
|
+
|
878
|
+
Notes
|
879
|
+
-----
|
880
|
+
Platform changes only take effect at program start. Changing platform
|
881
|
+
after JAX initialization may not have the expected effect.
|
277
882
|
"""
|
278
|
-
if platform
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
if precision is not None:
|
283
|
-
_set_jax_precision(precision)
|
284
|
-
kwargs['precision'] = precision
|
285
|
-
if mode is not None:
|
286
|
-
assert isinstance(mode, Mode), 'mode must be a Mode instance.'
|
287
|
-
kwargs['mode'] = mode
|
883
|
+
if platform not in SUPPORTED_PLATFORMS:
|
884
|
+
raise ValueError(
|
885
|
+
f"Platform must be one of {SUPPORTED_PLATFORMS}, got '{platform}'"
|
886
|
+
)
|
288
887
|
|
289
|
-
|
290
|
-
DFAULT.settings.update(kwargs)
|
888
|
+
config.update("jax_platform_name", platform)
|
291
889
|
|
292
|
-
#
|
293
|
-
|
294
|
-
|
295
|
-
DFAULT.functions[k](v)
|
890
|
+
# Trigger callbacks
|
891
|
+
if PLATFORM in _ENV_STATE.functions:
|
892
|
+
_ENV_STATE.functions[PLATFORM](platform)
|
296
893
|
|
297
894
|
|
298
|
-
def set_host_device_count(n):
|
895
|
+
def set_host_device_count(n: int) -> None:
|
299
896
|
"""
|
300
|
-
|
301
|
-
that there are `n` host (CPU) devices available to use. As a consequence, this
|
302
|
-
allows parallel mapping in JAX :func:`jax.pmap` to work in CPU platform.
|
897
|
+
Set the number of host (CPU) devices.
|
303
898
|
|
304
|
-
|
305
|
-
|
306
|
-
`XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
|
307
|
-
`[num_device]` is the desired number of CPU devices `n`.
|
899
|
+
This function configures XLA to treat CPU cores as separate devices,
|
900
|
+
enabling parallel computation with jax.pmap on CPU.
|
308
901
|
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
902
|
+
Parameters
|
903
|
+
----------
|
904
|
+
n : int
|
905
|
+
Number of host devices to configure.
|
906
|
+
|
907
|
+
Raises
|
908
|
+
------
|
909
|
+
ValueError
|
910
|
+
If n is not a positive integer.
|
911
|
+
|
912
|
+
Examples
|
913
|
+
--------
|
914
|
+
.. code-block:: python
|
915
|
+
|
916
|
+
>>> import brainstate.environ as env
|
917
|
+
>>> import jax
|
918
|
+
>>>
|
919
|
+
>>> # Configure 4 CPU devices
|
920
|
+
>>> env.set_host_device_count(4)
|
921
|
+
>>>
|
922
|
+
>>> # Use with pmap
|
923
|
+
>>> def parallel_fn(x):
|
924
|
+
... return x * 2
|
925
|
+
>>>
|
926
|
+
>>> # This will work with 4 devices
|
927
|
+
>>> pmapped_fn = jax.pmap(parallel_fn)
|
928
|
+
|
929
|
+
Warnings
|
930
|
+
--------
|
931
|
+
This setting only takes effect at program start. The effects of using
|
932
|
+
xla_force_host_platform_device_count are not fully understood and may
|
933
|
+
cause unexpected behavior.
|
316
934
|
"""
|
935
|
+
if not isinstance(n, int) or n < 1:
|
936
|
+
raise ValueError(f"Host device count must be a positive integer, got {n}")
|
937
|
+
|
938
|
+
# Update XLA flags
|
317
939
|
xla_flags = os.getenv("XLA_FLAGS", "")
|
318
|
-
xla_flags = re.sub(
|
319
|
-
|
940
|
+
xla_flags = re.sub(
|
941
|
+
r"--xla_force_host_platform_device_count=\S+",
|
942
|
+
"",
|
943
|
+
xla_flags
|
944
|
+
).split()
|
320
945
|
|
321
|
-
|
322
|
-
|
323
|
-
|
946
|
+
os.environ["XLA_FLAGS"] = " ".join(
|
947
|
+
[f"--xla_force_host_platform_device_count={n}"] + xla_flags
|
948
|
+
)
|
324
949
|
|
950
|
+
# Trigger callbacks
|
951
|
+
if HOST_DEVICE_COUNT in _ENV_STATE.functions:
|
952
|
+
_ENV_STATE.functions[HOST_DEVICE_COUNT](n)
|
325
953
|
|
326
|
-
def set_platform(platform: str):
|
327
|
-
"""
|
328
|
-
Changes platform to CPU, GPU, or TPU. This utility only takes
|
329
|
-
effect at the beginning of your program.
|
330
954
|
|
331
|
-
|
332
|
-
|
955
|
+
def set_precision(precision: PrecisionType) -> None:
|
956
|
+
"""
|
957
|
+
Set the global numerical precision.
|
333
958
|
|
334
|
-
|
335
|
-
|
959
|
+
Parameters
|
960
|
+
----------
|
961
|
+
precision : int or str
|
962
|
+
Precision to use (8, 16, 32, 64, or 'bf16').
|
963
|
+
|
964
|
+
Raises
|
965
|
+
------
|
966
|
+
ValueError
|
967
|
+
If precision is not supported.
|
968
|
+
|
969
|
+
Examples
|
970
|
+
--------
|
971
|
+
.. code-block:: python
|
972
|
+
|
973
|
+
>>> import brainstate.environ as env
|
974
|
+
>>> import jax.numpy as jnp
|
975
|
+
>>>
|
976
|
+
>>> # Set to 64-bit precision
|
977
|
+
>>> env.set_precision(64)
|
978
|
+
>>>
|
979
|
+
>>> # Arrays will use float64 by default
|
980
|
+
>>> x = jnp.array([1.0, 2.0, 3.0])
|
981
|
+
>>> print(x.dtype) # float64
|
982
|
+
>>>
|
983
|
+
>>> # Set to bfloat16 for efficiency
|
984
|
+
>>> env.set_precision('bf16')
|
336
985
|
"""
|
337
|
-
|
338
|
-
|
986
|
+
_validate_precision(precision)
|
987
|
+
_set_jax_precision(precision)
|
988
|
+
_ENV_STATE.settings[PRECISION] = precision
|
339
989
|
|
340
|
-
#
|
341
|
-
if
|
342
|
-
|
990
|
+
# Trigger callbacks
|
991
|
+
if PRECISION in _ENV_STATE.functions:
|
992
|
+
_ENV_STATE.functions[PRECISION](precision)
|
343
993
|
|
344
994
|
|
345
|
-
def
|
995
|
+
def get_precision() -> int:
|
346
996
|
"""
|
347
|
-
|
997
|
+
Get the current numerical precision as an integer.
|
348
998
|
|
349
|
-
|
350
|
-
|
999
|
+
Returns
|
1000
|
+
-------
|
1001
|
+
int
|
1002
|
+
Precision in bits (8, 16, 32, or 64).
|
1003
|
+
|
1004
|
+
Examples
|
1005
|
+
--------
|
1006
|
+
.. code-block:: python
|
1007
|
+
|
1008
|
+
>>> import brainstate.environ as env
|
1009
|
+
>>>
|
1010
|
+
>>> env.set_precision(32)
|
1011
|
+
>>> bits = env.get_precision()
|
1012
|
+
>>> print(f"Using {bits}-bit precision") # Using 32-bit precision
|
1013
|
+
>>>
|
1014
|
+
>>> # Special handling for bfloat16
|
1015
|
+
>>> env.set_precision('bf16')
|
1016
|
+
>>> print(env.get_precision()) # 16
|
1017
|
+
|
1018
|
+
Notes
|
1019
|
+
-----
|
1020
|
+
'bf16' (bfloat16) is reported as 16-bit precision.
|
351
1021
|
"""
|
352
|
-
|
353
|
-
|
1022
|
+
precision = get(PRECISION, default=DEFAULT_PRECISION)
|
1023
|
+
|
1024
|
+
if precision == 'bf16':
|
1025
|
+
return 16
|
1026
|
+
elif isinstance(precision, str):
|
1027
|
+
return int(precision)
|
1028
|
+
elif isinstance(precision, int):
|
1029
|
+
return precision
|
1030
|
+
else:
|
1031
|
+
raise ValueError(f"Invalid precision type: {type(precision)}")
|
1032
|
+
|
1033
|
+
|
1034
|
+
def _validate_precision(precision: PrecisionType) -> None:
|
1035
|
+
"""Validate precision value."""
|
1036
|
+
if precision not in SUPPORTED_PRECISIONS and str(precision) not in map(str, SUPPORTED_PRECISIONS):
|
1037
|
+
raise ValueError(
|
1038
|
+
f"Precision must be one of {SUPPORTED_PRECISIONS}, got {precision}"
|
1039
|
+
)
|
1040
|
+
|
1041
|
+
|
1042
|
+
def _get_precision() -> PrecisionType:
|
1043
|
+
"""Get raw precision value (including 'bf16')."""
|
1044
|
+
return get(PRECISION, default=DEFAULT_PRECISION)
|
1045
|
+
|
1046
|
+
|
1047
|
+
def _set_jax_precision(precision: PrecisionType) -> None:
|
1048
|
+
"""Configure JAX precision settings."""
|
1049
|
+
# Enable/disable 64-bit mode
|
1050
|
+
if precision in (64, '64'):
|
354
1051
|
config.update("jax_enable_x64", True)
|
355
1052
|
else:
|
356
1053
|
config.update("jax_enable_x64", False)
|
357
1054
|
|
358
1055
|
|
359
|
-
@functools.lru_cache()
|
360
|
-
def _get_uint(precision:
|
361
|
-
|
1056
|
+
@functools.lru_cache(maxsize=16)
|
1057
|
+
def _get_uint(precision: PrecisionType) -> DTypeLike:
|
1058
|
+
"""Get unsigned integer type for given precision."""
|
1059
|
+
if precision in (64, '64'):
|
362
1060
|
return np.uint64
|
363
|
-
elif precision in
|
1061
|
+
elif precision in (32, '32'):
|
364
1062
|
return np.uint32
|
365
|
-
elif precision in
|
1063
|
+
elif precision in (16, '16', 'bf16'):
|
366
1064
|
return np.uint16
|
367
|
-
elif precision in
|
1065
|
+
elif precision in (8, '8'):
|
368
1066
|
return np.uint8
|
369
1067
|
else:
|
370
|
-
raise ValueError(f
|
1068
|
+
raise ValueError(f"Unsupported precision: {precision}")
|
371
1069
|
|
372
1070
|
|
373
|
-
@functools.lru_cache()
|
374
|
-
def _get_int(precision:
|
375
|
-
|
1071
|
+
@functools.lru_cache(maxsize=16)
|
1072
|
+
def _get_int(precision: PrecisionType) -> DTypeLike:
|
1073
|
+
"""Get integer type for given precision."""
|
1074
|
+
if precision in (64, '64'):
|
376
1075
|
return np.int64
|
377
|
-
elif precision in
|
1076
|
+
elif precision in (32, '32'):
|
378
1077
|
return np.int32
|
379
|
-
elif precision in
|
1078
|
+
elif precision in (16, '16', 'bf16'):
|
380
1079
|
return np.int16
|
381
|
-
elif precision in
|
1080
|
+
elif precision in (8, '8'):
|
382
1081
|
return np.int8
|
383
1082
|
else:
|
384
|
-
raise ValueError(f
|
1083
|
+
raise ValueError(f"Unsupported precision: {precision}")
|
385
1084
|
|
386
1085
|
|
387
|
-
@functools.lru_cache()
|
388
|
-
def _get_float(precision:
|
389
|
-
|
1086
|
+
@functools.lru_cache(maxsize=16)
|
1087
|
+
def _get_float(precision: PrecisionType) -> DTypeLike:
|
1088
|
+
"""Get floating-point type for given precision."""
|
1089
|
+
if precision in (64, '64'):
|
390
1090
|
return np.float64
|
391
|
-
elif precision in
|
1091
|
+
elif precision in (32, '32'):
|
392
1092
|
return np.float32
|
393
|
-
elif precision in
|
1093
|
+
elif precision in (16, '16'):
|
394
1094
|
return np.float16
|
395
|
-
elif precision
|
1095
|
+
elif precision == 'bf16':
|
396
1096
|
return jnp.bfloat16
|
397
|
-
elif precision in
|
1097
|
+
elif precision in (8, '8'):
|
398
1098
|
return jnp.float8_e5m2
|
399
1099
|
else:
|
400
|
-
raise ValueError(f
|
1100
|
+
raise ValueError(f"Unsupported precision: {precision}")
|
401
1101
|
|
402
1102
|
|
403
|
-
@functools.lru_cache()
|
404
|
-
def _get_complex(precision:
|
405
|
-
|
1103
|
+
@functools.lru_cache(maxsize=16)
|
1104
|
+
def _get_complex(precision: PrecisionType) -> DTypeLike:
|
1105
|
+
"""Get complex type for given precision."""
|
1106
|
+
if precision in (64, '64'):
|
406
1107
|
return np.complex128
|
407
|
-
elif precision in
|
408
|
-
return np.complex64
|
409
|
-
elif precision in [16, '16', 'bf16']:
|
410
|
-
return np.complex64
|
411
|
-
elif precision in [8, '8']:
|
1108
|
+
elif precision in (32, '32', 16, '16', 'bf16', 8, '8'):
|
412
1109
|
return np.complex64
|
413
1110
|
else:
|
414
|
-
raise ValueError(f
|
1111
|
+
raise ValueError(f"Unsupported precision: {precision}")
|
415
1112
|
|
416
1113
|
|
417
1114
|
def dftype() -> DTypeLike:
|
418
1115
|
"""
|
419
|
-
|
1116
|
+
Get the default floating-point data type.
|
420
1117
|
|
421
|
-
This function returns the
|
422
|
-
|
423
|
-
you can use this function to get the default floating data type, and create the data by using ``dtype=dftype()``.
|
424
|
-
|
425
|
-
For example, if the precision is set to 32, the default floating data type is ``np.float32``.
|
426
|
-
|
427
|
-
>>> import brainstate as brainstate
|
428
|
-
>>> import numpy as np
|
429
|
-
>>> with brainstate.environ.context(precision=32):
|
430
|
-
... a = np.zeros(1, dtype=brainstate.environ.dftype())
|
431
|
-
>>> print(a.dtype)
|
1118
|
+
This function returns the appropriate floating-point type based on
|
1119
|
+
the current precision setting, allowing dynamic type selection.
|
432
1120
|
|
433
1121
|
Returns
|
434
1122
|
-------
|
435
|
-
|
436
|
-
|
1123
|
+
DTypeLike
|
1124
|
+
Default floating-point data type.
|
1125
|
+
|
1126
|
+
Examples
|
1127
|
+
--------
|
1128
|
+
.. code-block:: python
|
1129
|
+
|
1130
|
+
>>> import brainstate.environ as env
|
1131
|
+
>>> import jax.numpy as jnp
|
1132
|
+
>>>
|
1133
|
+
>>> # With 32-bit precision
|
1134
|
+
>>> env.set(precision=32)
|
1135
|
+
>>> x = jnp.zeros(10, dtype=env.dftype())
|
1136
|
+
>>> print(x.dtype) # float32
|
1137
|
+
>>>
|
1138
|
+
>>> # With 64-bit precision
|
1139
|
+
>>> with env.context(precision=64):
|
1140
|
+
... y = jnp.ones(5, dtype=env.dftype())
|
1141
|
+
... print(y.dtype) # float64
|
1142
|
+
>>>
|
1143
|
+
>>> # With bfloat16
|
1144
|
+
>>> env.set(precision='bf16')
|
1145
|
+
>>> z = jnp.array([1, 2, 3], dtype=env.dftype())
|
1146
|
+
>>> print(z.dtype) # bfloat16
|
1147
|
+
|
1148
|
+
See Also
|
1149
|
+
--------
|
1150
|
+
ditype : Default integer type
|
1151
|
+
dutype : Default unsigned integer type
|
1152
|
+
dctype : Default complex type
|
437
1153
|
"""
|
438
1154
|
return _get_float(_get_precision())
|
439
1155
|
|
440
1156
|
|
441
1157
|
def ditype() -> DTypeLike:
|
442
1158
|
"""
|
443
|
-
|
444
|
-
|
445
|
-
This function returns the default integer data type based on the current precision.
|
446
|
-
If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
|
447
|
-
you can use this function to get the default integer data type, and create the data by using ``dtype=ditype()``.
|
448
|
-
|
449
|
-
For example, if the precision is set to 32, the default integer data type is ``np.int32``.
|
1159
|
+
Get the default integer data type.
|
450
1160
|
|
451
|
-
|
452
|
-
|
453
|
-
>>> with brainstate.environ.context(precision=32):
|
454
|
-
... a = np.zeros(1, dtype=brainstate.environ.ditype())
|
455
|
-
>>> print(a.dtype)
|
456
|
-
int32
|
1161
|
+
This function returns the appropriate integer type based on
|
1162
|
+
the current precision setting.
|
457
1163
|
|
458
1164
|
Returns
|
459
1165
|
-------
|
460
|
-
|
461
|
-
|
1166
|
+
DTypeLike
|
1167
|
+
Default integer data type.
|
1168
|
+
|
1169
|
+
Examples
|
1170
|
+
--------
|
1171
|
+
.. code-block:: python
|
1172
|
+
|
1173
|
+
>>> import brainstate.environ as env
|
1174
|
+
>>> import jax.numpy as jnp
|
1175
|
+
>>>
|
1176
|
+
>>> # With 32-bit precision
|
1177
|
+
>>> env.set(precision=32)
|
1178
|
+
>>> indices = jnp.arange(10, dtype=env.ditype())
|
1179
|
+
>>> print(indices.dtype) # int32
|
1180
|
+
>>>
|
1181
|
+
>>> # With 64-bit precision
|
1182
|
+
>>> with env.context(precision=64):
|
1183
|
+
... big_indices = jnp.arange(1000, dtype=env.ditype())
|
1184
|
+
... print(big_indices.dtype) # int64
|
1185
|
+
|
1186
|
+
See Also
|
1187
|
+
--------
|
1188
|
+
dftype : Default floating-point type
|
1189
|
+
dutype : Default unsigned integer type
|
462
1190
|
"""
|
463
1191
|
return _get_int(_get_precision())
|
464
1192
|
|
465
1193
|
|
466
1194
|
def dutype() -> DTypeLike:
|
467
1195
|
"""
|
468
|
-
|
1196
|
+
Get the default unsigned integer data type.
|
469
1197
|
|
470
|
-
This function returns the
|
471
|
-
|
472
|
-
by ``brainstate.environ.set(precision)``, you can use this function to get the default
|
473
|
-
unsigned integer data type, and create the data by using ``dtype=dutype()``.
|
474
|
-
|
475
|
-
For example, if the precision is set to 32, the default unsigned integer data type is ``np.uint32``.
|
476
|
-
|
477
|
-
>>> import brainstate as brainstate
|
478
|
-
>>> import numpy as np
|
479
|
-
>>> with brainstate.environ.context(precision=32):
|
480
|
-
... a = np.zeros(1, dtype=brainstate.environ.dutype())
|
481
|
-
>>> print(a.dtype)
|
482
|
-
uint32
|
1198
|
+
This function returns the appropriate unsigned integer type based on
|
1199
|
+
the current precision setting.
|
483
1200
|
|
484
1201
|
Returns
|
485
1202
|
-------
|
486
|
-
|
487
|
-
|
1203
|
+
DTypeLike
|
1204
|
+
Default unsigned integer data type.
|
1205
|
+
|
1206
|
+
Examples
|
1207
|
+
--------
|
1208
|
+
.. code-block:: python
|
1209
|
+
|
1210
|
+
>>> import brainstate.environ as env
|
1211
|
+
>>> import jax.numpy as jnp
|
1212
|
+
>>>
|
1213
|
+
>>> # With 32-bit precision
|
1214
|
+
>>> env.set(precision=32)
|
1215
|
+
>>> counts = jnp.array([10, 20, 30], dtype=env.dutype())
|
1216
|
+
>>> print(counts.dtype) # uint32
|
1217
|
+
>>>
|
1218
|
+
>>> # With 16-bit precision
|
1219
|
+
>>> with env.context(precision=16):
|
1220
|
+
... small_counts = jnp.array([1, 2, 3], dtype=env.dutype())
|
1221
|
+
... print(small_counts.dtype) # uint16
|
1222
|
+
|
1223
|
+
See Also
|
1224
|
+
--------
|
1225
|
+
ditype : Default signed integer type
|
488
1226
|
"""
|
489
1227
|
return _get_uint(_get_precision())
|
490
1228
|
|
491
1229
|
|
492
1230
|
def dctype() -> DTypeLike:
|
493
1231
|
"""
|
494
|
-
|
495
|
-
|
496
|
-
This function returns the default complex data type based on the current precision.
|
497
|
-
If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
|
498
|
-
you can use this function to get the default complex data type, and create the data by using ``dtype=dctype()``.
|
1232
|
+
Get the default complex data type.
|
499
1233
|
|
500
|
-
|
501
|
-
|
502
|
-
>>> import brainstate as brainstate
|
503
|
-
>>> import numpy as np
|
504
|
-
>>> with brainstate.environ.context(precision=32):
|
505
|
-
... a = np.zeros(1, dtype=brainstate.environ.dctype())
|
506
|
-
>>> print(a.dtype)
|
507
|
-
complex64
|
1234
|
+
This function returns the appropriate complex type based on
|
1235
|
+
the current precision setting.
|
508
1236
|
|
509
1237
|
Returns
|
510
1238
|
-------
|
511
|
-
|
512
|
-
|
1239
|
+
DTypeLike
|
1240
|
+
Default complex data type.
|
1241
|
+
|
1242
|
+
Examples
|
1243
|
+
--------
|
1244
|
+
.. code-block:: python
|
1245
|
+
|
1246
|
+
>>> import brainstate.environ as env
|
1247
|
+
>>> import jax.numpy as jnp
|
1248
|
+
>>>
|
1249
|
+
>>> # With 32-bit precision
|
1250
|
+
>>> env.set(precision=32)
|
1251
|
+
>>> z = jnp.array([1+2j, 3+4j], dtype=env.dctype())
|
1252
|
+
>>> print(z.dtype) # complex64
|
1253
|
+
>>>
|
1254
|
+
>>> # With 64-bit precision
|
1255
|
+
>>> with env.context(precision=64):
|
1256
|
+
... w = jnp.array([5+6j], dtype=env.dctype())
|
1257
|
+
... print(w.dtype) # complex128
|
1258
|
+
|
1259
|
+
Notes
|
1260
|
+
-----
|
1261
|
+
Complex128 is only available with 64-bit precision.
|
1262
|
+
All other precisions use complex64.
|
513
1263
|
"""
|
514
1264
|
return _get_complex(_get_precision())
|
515
1265
|
|
516
1266
|
|
517
|
-
def tolerance():
|
518
|
-
|
1267
|
+
def tolerance() -> jnp.ndarray:
|
1268
|
+
"""
|
1269
|
+
Get numerical tolerance based on current precision.
|
1270
|
+
|
1271
|
+
This function returns an appropriate tolerance value for numerical
|
1272
|
+
comparisons based on the current precision setting.
|
1273
|
+
|
1274
|
+
Returns
|
1275
|
+
-------
|
1276
|
+
jnp.ndarray
|
1277
|
+
Tolerance value as a scalar array.
|
1278
|
+
|
1279
|
+
Examples
|
1280
|
+
--------
|
1281
|
+
.. code-block:: python
|
1282
|
+
|
1283
|
+
>>> import brainstate.environ as env
|
1284
|
+
>>> import jax.numpy as jnp
|
1285
|
+
>>>
|
1286
|
+
>>> # Different tolerances for different precisions
|
1287
|
+
>>> env.set(precision=64)
|
1288
|
+
>>> tol64 = env.tolerance()
|
1289
|
+
>>> print(f"64-bit tolerance: {tol64}") # 1e-12
|
1290
|
+
>>>
|
1291
|
+
>>> env.set(precision=32)
|
1292
|
+
>>> tol32 = env.tolerance()
|
1293
|
+
>>> print(f"32-bit tolerance: {tol32}") # 1e-5
|
1294
|
+
>>>
|
1295
|
+
>>> # Use in numerical comparisons
|
1296
|
+
>>> def are_close(a, b):
|
1297
|
+
... return jnp.abs(a - b) < env.tolerance()
|
1298
|
+
|
1299
|
+
Notes
|
1300
|
+
-----
|
1301
|
+
Tolerance values:
|
1302
|
+
- 64-bit: 1e-12
|
1303
|
+
- 32-bit: 1e-5
|
1304
|
+
- 16-bit and below: 1e-2
|
1305
|
+
"""
|
1306
|
+
precision = get_precision()
|
1307
|
+
|
1308
|
+
if precision == 64:
|
519
1309
|
return jnp.array(1e-12, dtype=np.float64)
|
520
|
-
elif
|
1310
|
+
elif precision == 32:
|
521
1311
|
return jnp.array(1e-5, dtype=np.float32)
|
522
1312
|
else:
|
523
1313
|
return jnp.array(1e-2, dtype=np.float16)
|
524
1314
|
|
525
1315
|
|
526
|
-
def register_default_behavior(
|
1316
|
+
def register_default_behavior(
|
1317
|
+
key: str,
|
1318
|
+
behavior: Callable[[Any], None],
|
1319
|
+
replace_if_exist: bool = False
|
1320
|
+
) -> None:
|
527
1321
|
"""
|
528
|
-
Register a
|
1322
|
+
Register a callback for environment parameter changes.
|
529
1323
|
|
530
|
-
|
1324
|
+
This function allows you to register custom behaviors that are
|
1325
|
+
triggered whenever a specific environment parameter is modified.
|
531
1326
|
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
1327
|
+
Parameters
|
1328
|
+
----------
|
1329
|
+
key : str
|
1330
|
+
Environment parameter key to monitor.
|
1331
|
+
behavior : Callable[[Any], None]
|
1332
|
+
Callback function that receives the new value.
|
1333
|
+
replace_if_exist : bool, default=False
|
1334
|
+
Whether to replace existing callback for this key.
|
1335
|
+
|
1336
|
+
Raises
|
1337
|
+
------
|
1338
|
+
TypeError
|
1339
|
+
If behavior is not callable.
|
1340
|
+
ValueError
|
1341
|
+
If key already has a registered behavior and replace_if_exist is False.
|
1342
|
+
|
1343
|
+
Examples
|
1344
|
+
--------
|
1345
|
+
Basic callback registration:
|
1346
|
+
|
1347
|
+
.. code-block:: python
|
1348
|
+
|
1349
|
+
>>> import brainstate.environ as env
|
1350
|
+
>>>
|
1351
|
+
>>> # Define a callback
|
1352
|
+
>>> def on_dt_change(new_dt):
|
1353
|
+
... print(f"Time step changed to: {new_dt}")
|
1354
|
+
>>>
|
1355
|
+
>>> # Register the callback
|
1356
|
+
>>> env.register_default_behavior('dt', on_dt_change)
|
1357
|
+
>>>
|
1358
|
+
>>> # Callback is triggered on changes
|
1359
|
+
>>> env.set(dt=0.01) # Prints: Time step changed to: 0.01
|
1360
|
+
>>>
|
1361
|
+
>>> with env.context(dt=0.001): # Prints: Time step changed to: 0.001
|
1362
|
+
... pass # Prints: Time step changed to: 0.01 (on exit)
|
1363
|
+
|
1364
|
+
Complex behavior with validation:
|
1365
|
+
|
1366
|
+
.. code-block:: python
|
1367
|
+
|
1368
|
+
>>> import brainstate.environ as env
|
1369
|
+
>>>
|
1370
|
+
>>> def validate_batch_size(size):
|
1371
|
+
... if not isinstance(size, int) or size <= 0:
|
1372
|
+
... raise ValueError(f"Invalid batch size: {size}")
|
1373
|
+
... if size > 1024:
|
1374
|
+
... print(f"Warning: Large batch size {size} may cause OOM")
|
1375
|
+
>>>
|
1376
|
+
>>> env.register_default_behavior('batch_size', validate_batch_size)
|
1377
|
+
>>>
|
1378
|
+
>>> # Valid setting
|
1379
|
+
>>> env.set(batch_size=32) # OK
|
1380
|
+
>>>
|
1381
|
+
>>> # Invalid setting
|
1382
|
+
>>> # env.set(batch_size=-1) # Raises ValueError
|
1383
|
+
|
1384
|
+
Replacing existing behavior:
|
1385
|
+
|
1386
|
+
.. code-block:: python
|
1387
|
+
|
1388
|
+
>>> import brainstate.environ as env
|
1389
|
+
>>>
|
1390
|
+
>>> def old_behavior(value):
|
1391
|
+
... print(f"Old: {value}")
|
1392
|
+
>>>
|
1393
|
+
>>> def new_behavior(value):
|
1394
|
+
... print(f"New: {value}")
|
1395
|
+
>>>
|
1396
|
+
>>> env.register_default_behavior('key', old_behavior)
|
1397
|
+
>>> env.register_default_behavior('key', new_behavior, replace_if_exist=True)
|
1398
|
+
>>>
|
1399
|
+
>>> env.set(key='test') # Prints: New: test
|
1400
|
+
|
1401
|
+
See Also
|
1402
|
+
--------
|
1403
|
+
unregister_default_behavior : Remove registered callbacks
|
1404
|
+
list_registered_behaviors : List all registered callbacks
|
1405
|
+
"""
|
1406
|
+
if not isinstance(key, str):
|
1407
|
+
raise TypeError(f"Key must be a string, got {type(key)}")
|
537
1408
|
|
538
|
-
|
539
|
-
|
540
|
-
`dt_behavior(0.1)`.
|
1409
|
+
if not callable(behavior):
|
1410
|
+
raise TypeError(f"Behavior must be callable, got {type(behavior)}")
|
541
1411
|
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
Set the default dt to 0.1.
|
1412
|
+
if key in _ENV_STATE.functions and not replace_if_exist:
|
1413
|
+
raise ValueError(
|
1414
|
+
f"Behavior for key '{key}' already registered. "
|
1415
|
+
f"Use replace_if_exist=True to override."
|
1416
|
+
)
|
548
1417
|
|
1418
|
+
_ENV_STATE.functions[key] = behavior
|
549
1419
|
|
550
|
-
Args:
|
551
|
-
key: str. The key to register.
|
552
|
-
behavior: Callable. The behavior to register. It should be a callable.
|
553
|
-
replace_if_exist: bool. Whether to replace the behavior if the key has been registered.
|
554
1420
|
|
1421
|
+
def unregister_default_behavior(key: str) -> bool:
|
1422
|
+
"""
|
1423
|
+
Remove a registered callback for an environment parameter.
|
1424
|
+
|
1425
|
+
Parameters
|
1426
|
+
----------
|
1427
|
+
key : str
|
1428
|
+
Environment parameter key.
|
1429
|
+
|
1430
|
+
Returns
|
1431
|
+
-------
|
1432
|
+
bool
|
1433
|
+
True if a callback was removed, False if no callback existed.
|
1434
|
+
|
1435
|
+
Examples
|
1436
|
+
--------
|
1437
|
+
.. code-block:: python
|
1438
|
+
|
1439
|
+
>>> import brainstate.environ as env
|
1440
|
+
>>>
|
1441
|
+
>>> # Register a callback
|
1442
|
+
>>> def callback(value):
|
1443
|
+
... print(f"Value: {value}")
|
1444
|
+
>>>
|
1445
|
+
>>> env.register_default_behavior('param', callback)
|
1446
|
+
>>>
|
1447
|
+
>>> # Remove the callback
|
1448
|
+
>>> removed = env.unregister_default_behavior('param')
|
1449
|
+
>>> print(f"Callback removed: {removed}") # True
|
1450
|
+
>>>
|
1451
|
+
>>> # No callback triggers now
|
1452
|
+
>>> env.set(param='test') # No output
|
1453
|
+
>>>
|
1454
|
+
>>> # Removing non-existent callback
|
1455
|
+
>>> removed = env.unregister_default_behavior('nonexistent')
|
1456
|
+
>>> print(f"Callback removed: {removed}") # False
|
1457
|
+
"""
|
1458
|
+
if key in _ENV_STATE.functions:
|
1459
|
+
del _ENV_STATE.functions[key]
|
1460
|
+
return True
|
1461
|
+
return False
|
1462
|
+
|
1463
|
+
|
1464
|
+
def list_registered_behaviors() -> List[str]:
|
1465
|
+
"""
|
1466
|
+
List all keys with registered callbacks.
|
1467
|
+
|
1468
|
+
Returns
|
1469
|
+
-------
|
1470
|
+
list of str
|
1471
|
+
Keys that have registered behavior callbacks.
|
1472
|
+
|
1473
|
+
Examples
|
1474
|
+
--------
|
1475
|
+
.. code-block:: python
|
1476
|
+
|
1477
|
+
>>> import brainstate.environ as env
|
1478
|
+
>>>
|
1479
|
+
>>> # Register some callbacks
|
1480
|
+
>>> env.register_default_behavior('param1', lambda x: None)
|
1481
|
+
>>> env.register_default_behavior('param2', lambda x: None)
|
1482
|
+
>>>
|
1483
|
+
>>> # List registered behaviors
|
1484
|
+
>>> behaviors = env.list_registered_behaviors()
|
1485
|
+
>>> print(f"Registered: {behaviors}") # ['param1', 'param2']
|
1486
|
+
>>>
|
1487
|
+
>>> # Check if specific behavior is registered
|
1488
|
+
>>> if 'dt' in behaviors:
|
1489
|
+
... print("dt has a registered callback")
|
555
1490
|
"""
|
556
|
-
|
557
|
-
assert callable(behavior), 'behavior must be a callable.'
|
558
|
-
if not replace_if_exist:
|
559
|
-
assert key not in DFAULT.functions, f'{key} has been registered.'
|
560
|
-
DFAULT.functions[key] = behavior
|
1491
|
+
return list(_ENV_STATE.functions.keys())
|
561
1492
|
|
562
1493
|
|
563
|
-
|
1494
|
+
# Initialize default precision on module load
|
1495
|
+
set(precision=DEFAULT_PRECISION)
|