brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2319 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +1652 -1652
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1624 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1433 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +137 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +633 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +154 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +477 -477
- brainstate/nn/_dynamics.py +1267 -1267
- brainstate/nn/_dynamics_test.py +67 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +384 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/_rand_funs.py +3938 -3938
- brainstate/random/_rand_funs_test.py +640 -640
- brainstate/random/_rand_seed.py +675 -675
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1617
- brainstate/random/_rand_state_test.py +551 -551
- brainstate/transform/__init__.py +59 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -145
- brainstate/transform/_eval_shape_test.py +38 -38
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -2016
- brainstate/transform/_make_jaxpr_test.py +1510 -1510
- brainstate/transform/_mapping.py +529 -529
- brainstate/transform/_mapping_test.py +194 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_random.py +171 -171
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate-0.2.0.dist-info/RECORD +0 -111
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/environ.py
CHANGED
@@ -1,1495 +1,1495 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
"""
|
17
|
-
Environment configuration and context management for BrainState.
|
18
|
-
|
19
|
-
This module provides comprehensive functionality for managing computational
|
20
|
-
environments, including platform selection, precision control, mode setting,
|
21
|
-
and context-based configuration management. It enables flexible configuration
|
22
|
-
of JAX-based computations with thread-safe context switching.
|
23
|
-
|
24
|
-
The module supports:
|
25
|
-
- Platform configuration (CPU, GPU, TPU)
|
26
|
-
- Precision control (8, 16, 32, 64 bit and bfloat16)
|
27
|
-
- Computation mode management
|
28
|
-
- Context-based temporary settings
|
29
|
-
- Default data type management
|
30
|
-
- Custom behavior registration
|
31
|
-
|
32
|
-
Examples
|
33
|
-
--------
|
34
|
-
Global environment configuration:
|
35
|
-
|
36
|
-
.. code-block:: python
|
37
|
-
|
38
|
-
>>> import brainstate as bs
|
39
|
-
>>> import brainstate.environ as env
|
40
|
-
>>>
|
41
|
-
>>> # Set global precision to 32-bit
|
42
|
-
>>> env.set(precision=32, dt=0.01, mode=bs.mixin.Training())
|
43
|
-
>>>
|
44
|
-
>>> # Get current settings
|
45
|
-
>>> print(env.get('precision')) # 32
|
46
|
-
>>> print(env.get('dt')) # 0.01
|
47
|
-
|
48
|
-
Context-based temporary settings:
|
49
|
-
|
50
|
-
.. code-block:: python
|
51
|
-
|
52
|
-
>>> import brainstate.environ as env
|
53
|
-
>>>
|
54
|
-
>>> # Temporarily change precision
|
55
|
-
>>> with env.context(precision=64, dt=0.001):
|
56
|
-
... high_precision_result = compute_something()
|
57
|
-
... print(env.get('precision')) # 64
|
58
|
-
>>> print(env.get('precision')) # Back to 32
|
59
|
-
"""
|
60
|
-
|
61
|
-
import contextlib
|
62
|
-
import dataclasses
|
63
|
-
import functools
|
64
|
-
import os
|
65
|
-
import re
|
66
|
-
import threading
|
67
|
-
import warnings
|
68
|
-
from collections import defaultdict
|
69
|
-
from typing import Any, Callable, Dict, Hashable, Optional, Union, ContextManager, List
|
70
|
-
|
71
|
-
import brainunit as u
|
72
|
-
import numpy as np
|
73
|
-
from jax import config, devices, numpy as jnp
|
74
|
-
from jax.typing import DTypeLike
|
75
|
-
|
76
|
-
__all__ = [
|
77
|
-
# Core environment management
|
78
|
-
'set',
|
79
|
-
'get',
|
80
|
-
'all',
|
81
|
-
'pop',
|
82
|
-
'context',
|
83
|
-
'reset',
|
84
|
-
|
85
|
-
# Platform and device management
|
86
|
-
'set_platform',
|
87
|
-
'get_platform',
|
88
|
-
'set_host_device_count',
|
89
|
-
'get_host_device_count',
|
90
|
-
|
91
|
-
# Precision and data type management
|
92
|
-
'set_precision',
|
93
|
-
'get_precision',
|
94
|
-
'dftype',
|
95
|
-
'ditype',
|
96
|
-
'dutype',
|
97
|
-
'dctype',
|
98
|
-
|
99
|
-
# Mode and computation settings
|
100
|
-
'get_dt',
|
101
|
-
|
102
|
-
# Utility functions
|
103
|
-
'tolerance',
|
104
|
-
'register_default_behavior',
|
105
|
-
'unregister_default_behavior',
|
106
|
-
'list_registered_behaviors',
|
107
|
-
|
108
|
-
# Constants
|
109
|
-
'DEFAULT_PRECISION',
|
110
|
-
'SUPPORTED_PLATFORMS',
|
111
|
-
'SUPPORTED_PRECISIONS',
|
112
|
-
|
113
|
-
# Names
|
114
|
-
'I',
|
115
|
-
'T',
|
116
|
-
'DT',
|
117
|
-
'PRECISION',
|
118
|
-
'PLATFORM',
|
119
|
-
'HOST_DEVICE_COUNT',
|
120
|
-
'JIT_ERROR_CHECK',
|
121
|
-
'FIT',
|
122
|
-
]
|
123
|
-
|
124
|
-
# Type definitions
|
125
|
-
# T = TypeVar('T')
|
126
|
-
PrecisionType = Union[int, str]
|
127
|
-
PlatformType = str
|
128
|
-
|
129
|
-
# Constants for environment keys
|
130
|
-
I = 'i' # Index of the current computation
|
131
|
-
T = 't' # Current time of the computation
|
132
|
-
DT = 'dt' # Time step for numerical integration
|
133
|
-
PRECISION = 'precision' # Numerical precision
|
134
|
-
PLATFORM = 'platform' # Computing platform
|
135
|
-
HOST_DEVICE_COUNT = 'host_device_count' # Number of host devices
|
136
|
-
JIT_ERROR_CHECK = 'jit_error_check' # JIT error checking flag
|
137
|
-
FIT = 'fit' # Model fitting flag
|
138
|
-
|
139
|
-
# Default values
|
140
|
-
DEFAULT_PRECISION = 32
|
141
|
-
SUPPORTED_PLATFORMS = ('cpu', 'gpu', 'tpu')
|
142
|
-
SUPPORTED_PRECISIONS = (8, 16, 32, 64, 'bf16')
|
143
|
-
|
144
|
-
# Sentinel value for missing arguments
|
145
|
-
_NOT_PROVIDED = object()
|
146
|
-
|
147
|
-
|
148
|
-
@dataclasses.dataclass
|
149
|
-
class EnvironmentState(threading.local):
|
150
|
-
"""
|
151
|
-
Thread-local storage for environment configuration.
|
152
|
-
|
153
|
-
This class maintains separate configuration states for different threads,
|
154
|
-
ensuring thread-safe environment management in concurrent applications.
|
155
|
-
|
156
|
-
Attributes
|
157
|
-
----------
|
158
|
-
settings : Dict[Hashable, Any]
|
159
|
-
Global default environment settings.
|
160
|
-
contexts : defaultdict[Hashable, List[Any]]
|
161
|
-
Stack of context-specific settings for nested contexts.
|
162
|
-
functions : Dict[Hashable, Callable]
|
163
|
-
Registered callback functions for environment changes.
|
164
|
-
locks : Dict[str, threading.Lock]
|
165
|
-
Thread locks for synchronized access to critical sections.
|
166
|
-
"""
|
167
|
-
settings: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
|
168
|
-
contexts: defaultdict[Hashable, List[Any]] = dataclasses.field(default_factory=lambda: defaultdict(list))
|
169
|
-
functions: Dict[Hashable, Callable] = dataclasses.field(default_factory=dict)
|
170
|
-
locks: Dict[str, threading.Lock] = dataclasses.field(default_factory=lambda: defaultdict(threading.Lock))
|
171
|
-
|
172
|
-
def __post_init__(self):
|
173
|
-
"""Initialize with default settings."""
|
174
|
-
# Set default precision if not already set
|
175
|
-
if PRECISION not in self.settings:
|
176
|
-
self.settings[PRECISION] = DEFAULT_PRECISION
|
177
|
-
|
178
|
-
|
179
|
-
# Global environment state
|
180
|
-
_ENV_STATE = EnvironmentState()
|
181
|
-
|
182
|
-
|
183
|
-
def reset() -> None:
|
184
|
-
"""
|
185
|
-
Reset the environment to default settings.
|
186
|
-
|
187
|
-
This function clears all custom settings and restores the environment
|
188
|
-
to its initial state. Useful for testing or when starting fresh.
|
189
|
-
|
190
|
-
Examples
|
191
|
-
--------
|
192
|
-
.. code-block:: python
|
193
|
-
|
194
|
-
>>> import brainstate.environ as env
|
195
|
-
>>>
|
196
|
-
>>> # Set custom values
|
197
|
-
>>> env.set(dt=0.1, custom_param='value')
|
198
|
-
>>> print(env.get('custom_param')) # 'value'
|
199
|
-
>>>
|
200
|
-
>>> # Reset to defaults
|
201
|
-
>>> env.reset()
|
202
|
-
>>> print(env.get('custom_param', default=None)) # None
|
203
|
-
|
204
|
-
Notes
|
205
|
-
-----
|
206
|
-
This operation cannot be undone. All custom settings will be lost.
|
207
|
-
"""
|
208
|
-
global _ENV_STATE
|
209
|
-
_ENV_STATE = EnvironmentState()
|
210
|
-
# Re-apply default precision
|
211
|
-
_set_jax_precision(DEFAULT_PRECISION)
|
212
|
-
|
213
|
-
warnings.warn(
|
214
|
-
"Environment has been reset to default settings. "
|
215
|
-
"All custom configurations have been cleared.",
|
216
|
-
UserWarning
|
217
|
-
)
|
218
|
-
|
219
|
-
|
220
|
-
@contextlib.contextmanager
|
221
|
-
def context(**kwargs) -> ContextManager[Dict[str, Any]]:
|
222
|
-
"""
|
223
|
-
Context manager for temporary environment settings.
|
224
|
-
|
225
|
-
This context manager allows you to temporarily modify environment settings
|
226
|
-
within a specific scope. Settings are automatically restored when exiting
|
227
|
-
the context, even if an exception occurs.
|
228
|
-
|
229
|
-
Parameters
|
230
|
-
----------
|
231
|
-
**kwargs
|
232
|
-
Environment settings to apply within the context.
|
233
|
-
Common parameters include:
|
234
|
-
|
235
|
-
- precision : int or str.
|
236
|
-
Numerical precision (8, 16, 32, 64, or 'bf16')
|
237
|
-
- dt : float.
|
238
|
-
Time step for numerical integration
|
239
|
-
- mode : :class:`Mode`.
|
240
|
-
Computation mode instance
|
241
|
-
- Any custom parameters registered via register_default_behavior
|
242
|
-
|
243
|
-
Yields
|
244
|
-
------
|
245
|
-
dict
|
246
|
-
Current environment settings within the context.
|
247
|
-
|
248
|
-
Raises
|
249
|
-
------
|
250
|
-
ValueError
|
251
|
-
If attempting to set platform or host_device_count in context
|
252
|
-
(these must be set globally).
|
253
|
-
TypeError
|
254
|
-
If invalid parameter types are provided.
|
255
|
-
|
256
|
-
Examples
|
257
|
-
--------
|
258
|
-
Basic usage with precision control:
|
259
|
-
|
260
|
-
.. code-block:: python
|
261
|
-
|
262
|
-
>>> import brainstate.environ as env
|
263
|
-
>>>
|
264
|
-
>>> # Set global precision
|
265
|
-
>>> env.set(precision=32)
|
266
|
-
>>>
|
267
|
-
>>> # Temporarily use higher precision
|
268
|
-
>>> with env.context(precision=64) as ctx:
|
269
|
-
... print(f"Precision in context: {env.get('precision')}") # 64
|
270
|
-
... print(f"Float type: {env.dftype()}") # float64
|
271
|
-
>>>
|
272
|
-
>>> print(f"Precision after context: {env.get('precision')}") # 32
|
273
|
-
|
274
|
-
Nested contexts:
|
275
|
-
|
276
|
-
.. code-block:: python
|
277
|
-
|
278
|
-
>>> import brainstate.environ as env
|
279
|
-
>>>
|
280
|
-
>>> with env.context(dt=0.1) as ctx1:
|
281
|
-
... print(f"dt = {env.get('dt')}") # 0.1
|
282
|
-
...
|
283
|
-
... with env.context(dt=0.01) as ctx2:
|
284
|
-
... print(f"dt = {env.get('dt')}") # 0.01
|
285
|
-
...
|
286
|
-
... print(f"dt = {env.get('dt')}") # 0.1
|
287
|
-
|
288
|
-
Error handling in context:
|
289
|
-
|
290
|
-
.. code-block:: python
|
291
|
-
|
292
|
-
>>> import brainstate.environ as env
|
293
|
-
>>>
|
294
|
-
>>> env.set(value=10)
|
295
|
-
>>> try:
|
296
|
-
... with env.context(value=20):
|
297
|
-
... print(env.get('value')) # 20
|
298
|
-
... raise ValueError("Something went wrong")
|
299
|
-
... except ValueError:
|
300
|
-
... pass
|
301
|
-
>>>
|
302
|
-
>>> print(env.get('value')) # 10 (restored)
|
303
|
-
|
304
|
-
Notes
|
305
|
-
-----
|
306
|
-
- Platform and host_device_count cannot be set in context
|
307
|
-
- Contexts can be nested arbitrarily deep
|
308
|
-
- Settings are restored in reverse order when exiting
|
309
|
-
- Thread-safe: each thread maintains its own context stack
|
310
|
-
"""
|
311
|
-
# Validate restricted parameters
|
312
|
-
if PLATFORM in kwargs:
|
313
|
-
raise ValueError(
|
314
|
-
f"Cannot set '{PLATFORM}' in context. "
|
315
|
-
f"Use set_platform() or set() for global configuration."
|
316
|
-
)
|
317
|
-
if HOST_DEVICE_COUNT in kwargs:
|
318
|
-
raise ValueError(
|
319
|
-
f"Cannot set '{HOST_DEVICE_COUNT}' in context. "
|
320
|
-
f"Use set_host_device_count() or set() for global configuration."
|
321
|
-
)
|
322
|
-
|
323
|
-
# Handle precision changes
|
324
|
-
original_precision = None
|
325
|
-
if PRECISION in kwargs:
|
326
|
-
original_precision = _get_precision()
|
327
|
-
_validate_precision(kwargs[PRECISION])
|
328
|
-
_set_jax_precision(kwargs[PRECISION])
|
329
|
-
|
330
|
-
try:
|
331
|
-
# Push new values onto context stacks
|
332
|
-
for key, value in kwargs.items():
|
333
|
-
with _ENV_STATE.locks[key]:
|
334
|
-
_ENV_STATE.contexts[key].append(value)
|
335
|
-
|
336
|
-
# Trigger registered callbacks
|
337
|
-
if key in _ENV_STATE.functions:
|
338
|
-
try:
|
339
|
-
_ENV_STATE.functions[key](value)
|
340
|
-
except Exception as e:
|
341
|
-
warnings.warn(
|
342
|
-
f"Callback for '{key}' raised an exception: {e}",
|
343
|
-
RuntimeWarning
|
344
|
-
)
|
345
|
-
|
346
|
-
# Yield current environment state
|
347
|
-
yield all()
|
348
|
-
|
349
|
-
finally:
|
350
|
-
# Restore previous values
|
351
|
-
for key in kwargs:
|
352
|
-
with _ENV_STATE.locks[key]:
|
353
|
-
if _ENV_STATE.contexts[key]:
|
354
|
-
_ENV_STATE.contexts[key].pop()
|
355
|
-
|
356
|
-
# Restore callbacks with previous value
|
357
|
-
if key in _ENV_STATE.functions:
|
358
|
-
try:
|
359
|
-
prev_value = get(key, default=None)
|
360
|
-
if prev_value is not None:
|
361
|
-
_ENV_STATE.functions[key](prev_value)
|
362
|
-
except Exception as e:
|
363
|
-
warnings.warn(
|
364
|
-
f"Callback restoration for '{key}' raised: {e}",
|
365
|
-
RuntimeWarning
|
366
|
-
)
|
367
|
-
|
368
|
-
# Restore precision if it was changed
|
369
|
-
if original_precision is not None:
|
370
|
-
_set_jax_precision(original_precision)
|
371
|
-
|
372
|
-
|
373
|
-
def get(key: str, default: Any = _NOT_PROVIDED, desc: Optional[str] = None) -> Any:
|
374
|
-
"""
|
375
|
-
Get a value from the current environment.
|
376
|
-
|
377
|
-
This function retrieves values from the environment, checking first in
|
378
|
-
the context stack, then in global settings. Special handling is provided
|
379
|
-
for platform and device count parameters.
|
380
|
-
|
381
|
-
Parameters
|
382
|
-
----------
|
383
|
-
key : str
|
384
|
-
The environment key to retrieve.
|
385
|
-
default : Any, optional
|
386
|
-
Default value to return if key is not found.
|
387
|
-
If not provided, raises KeyError for missing keys.
|
388
|
-
desc : str, optional
|
389
|
-
Description of the parameter for error messages.
|
390
|
-
|
391
|
-
Returns
|
392
|
-
-------
|
393
|
-
Any
|
394
|
-
The value associated with the key.
|
395
|
-
|
396
|
-
Raises
|
397
|
-
------
|
398
|
-
KeyError
|
399
|
-
If key is not found and no default is provided.
|
400
|
-
|
401
|
-
Examples
|
402
|
-
--------
|
403
|
-
Basic retrieval:
|
404
|
-
|
405
|
-
.. code-block:: python
|
406
|
-
|
407
|
-
>>> import brainstate.environ as env
|
408
|
-
>>>
|
409
|
-
>>> env.set(learning_rate=0.001)
|
410
|
-
>>> lr = env.get('learning_rate')
|
411
|
-
>>> print(f"Learning rate: {lr}") # 0.001
|
412
|
-
|
413
|
-
With default value:
|
414
|
-
|
415
|
-
.. code-block:: python
|
416
|
-
|
417
|
-
>>> import brainstate.environ as env
|
418
|
-
>>>
|
419
|
-
>>> # Get with default
|
420
|
-
>>> batch_size = env.get('batch_size', default=32)
|
421
|
-
>>> print(f"Batch size: {batch_size}") # 32
|
422
|
-
|
423
|
-
Context-aware retrieval:
|
424
|
-
|
425
|
-
.. code-block:: python
|
426
|
-
|
427
|
-
>>> import brainstate.environ as env
|
428
|
-
>>>
|
429
|
-
>>> env.set(temperature=1.0)
|
430
|
-
>>> print(env.get('temperature')) # 1.0
|
431
|
-
>>>
|
432
|
-
>>> with env.context(temperature=0.5):
|
433
|
-
... print(env.get('temperature')) # 0.5
|
434
|
-
>>>
|
435
|
-
>>> print(env.get('temperature')) # 1.0
|
436
|
-
|
437
|
-
Notes
|
438
|
-
-----
|
439
|
-
Special keys 'platform' and 'host_device_count' are handled separately
|
440
|
-
and retrieve system-level information.
|
441
|
-
"""
|
442
|
-
# Special cases for platform-specific parameters
|
443
|
-
if key == PLATFORM:
|
444
|
-
return get_platform()
|
445
|
-
if key == HOST_DEVICE_COUNT:
|
446
|
-
return get_host_device_count()
|
447
|
-
|
448
|
-
# Check context stack first (most recent value)
|
449
|
-
with _ENV_STATE.locks[key]:
|
450
|
-
if key in _ENV_STATE.contexts and _ENV_STATE.contexts[key]:
|
451
|
-
return _ENV_STATE.contexts[key][-1]
|
452
|
-
|
453
|
-
# Check global settings
|
454
|
-
if key in _ENV_STATE.settings:
|
455
|
-
return _ENV_STATE.settings[key]
|
456
|
-
|
457
|
-
# Handle missing key
|
458
|
-
if default is _NOT_PROVIDED:
|
459
|
-
error_msg = f"Key '{key}' not found in environment."
|
460
|
-
if desc:
|
461
|
-
error_msg += f" Description: {desc}"
|
462
|
-
error_msg += (
|
463
|
-
f"\nSet it using:\n"
|
464
|
-
f" - env.set({key}=value) for global setting\n"
|
465
|
-
f" - env.context({key}=value) for temporary setting"
|
466
|
-
)
|
467
|
-
raise KeyError(error_msg)
|
468
|
-
|
469
|
-
return default
|
470
|
-
|
471
|
-
|
472
|
-
def all() -> Dict[str, Any]:
|
473
|
-
"""
|
474
|
-
Get all current environment settings.
|
475
|
-
|
476
|
-
This function returns a dictionary containing all active environment
|
477
|
-
settings, with context values taking precedence over global settings.
|
478
|
-
|
479
|
-
Returns
|
480
|
-
-------
|
481
|
-
dict
|
482
|
-
Dictionary of all current environment settings.
|
483
|
-
|
484
|
-
Examples
|
485
|
-
--------
|
486
|
-
.. code-block:: python
|
487
|
-
|
488
|
-
>>> import brainstate.environ as env
|
489
|
-
>>>
|
490
|
-
>>> # Set various parameters
|
491
|
-
>>> env.set(precision=32, dt=0.01, debug=True)
|
492
|
-
>>>
|
493
|
-
>>> # Get all settings
|
494
|
-
>>> settings = env.all()
|
495
|
-
>>> print(settings)
|
496
|
-
{'precision': 32, 'dt': 0.01, 'debug': True}
|
497
|
-
|
498
|
-
>>> # Context overrides
|
499
|
-
>>> with env.context(precision=64, new_param='test'):
|
500
|
-
... settings = env.all()
|
501
|
-
... print(settings['precision']) # 64
|
502
|
-
... print(settings['new_param']) # 'test'
|
503
|
-
|
504
|
-
Notes
|
505
|
-
-----
|
506
|
-
The returned dictionary is a snapshot and modifying it does not
|
507
|
-
affect the environment settings.
|
508
|
-
"""
|
509
|
-
result = {}
|
510
|
-
|
511
|
-
# Add global settings
|
512
|
-
result.update(_ENV_STATE.settings)
|
513
|
-
|
514
|
-
# Override with context values (most recent)
|
515
|
-
for key, values in _ENV_STATE.contexts.items():
|
516
|
-
if values:
|
517
|
-
result[key] = values[-1]
|
518
|
-
|
519
|
-
return result
|
520
|
-
|
521
|
-
|
522
|
-
def pop(key: str, default: Any = _NOT_PROVIDED) -> Any:
|
523
|
-
"""
|
524
|
-
Remove and return a value from the global environment.
|
525
|
-
|
526
|
-
This function removes a key from the global environment settings and
|
527
|
-
returns its value. If the key is not found, it returns the default
|
528
|
-
value if provided, or raises KeyError.
|
529
|
-
|
530
|
-
Note that this function only affects global settings, not context values.
|
531
|
-
Keys in active contexts are not affected.
|
532
|
-
|
533
|
-
Parameters
|
534
|
-
----------
|
535
|
-
key : str
|
536
|
-
The environment key to remove.
|
537
|
-
default : Any, optional
|
538
|
-
Default value to return if key is not found.
|
539
|
-
If not provided, raises KeyError for missing keys.
|
540
|
-
|
541
|
-
Returns
|
542
|
-
-------
|
543
|
-
Any
|
544
|
-
The value that was removed from the environment.
|
545
|
-
|
546
|
-
Raises
|
547
|
-
------
|
548
|
-
KeyError
|
549
|
-
If key is not found and no default is provided.
|
550
|
-
ValueError
|
551
|
-
If attempting to pop a key that is currently in a context.
|
552
|
-
|
553
|
-
Examples
|
554
|
-
--------
|
555
|
-
Basic usage:
|
556
|
-
|
557
|
-
.. code-block:: python
|
558
|
-
|
559
|
-
>>> import brainstate.environ as env
|
560
|
-
>>>
|
561
|
-
>>> # Set a value
|
562
|
-
>>> env.set(temp_param='temporary')
|
563
|
-
>>> print(env.get('temp_param')) # 'temporary'
|
564
|
-
>>>
|
565
|
-
>>> # Pop the value
|
566
|
-
>>> value = env.pop('temp_param')
|
567
|
-
>>> print(value) # 'temporary'
|
568
|
-
>>>
|
569
|
-
>>> # Value is now gone
|
570
|
-
>>> env.get('temp_param', default=None) # None
|
571
|
-
|
572
|
-
With default value:
|
573
|
-
|
574
|
-
.. code-block:: python
|
575
|
-
|
576
|
-
>>> import brainstate.environ as env
|
577
|
-
>>>
|
578
|
-
>>> # Pop non-existent key with default
|
579
|
-
>>> value = env.pop('nonexistent', default='default_value')
|
580
|
-
>>> print(value) # 'default_value'
|
581
|
-
|
582
|
-
Pop multiple values:
|
583
|
-
|
584
|
-
.. code-block:: python
|
585
|
-
|
586
|
-
>>> import brainstate.environ as env
|
587
|
-
>>>
|
588
|
-
>>> # Set multiple values
|
589
|
-
>>> env.set(param1='value1', param2='value2', param3='value3')
|
590
|
-
>>>
|
591
|
-
>>> # Pop them one by one
|
592
|
-
>>> v1 = env.pop('param1')
|
593
|
-
>>> v2 = env.pop('param2')
|
594
|
-
>>>
|
595
|
-
>>> # param3 still exists
|
596
|
-
>>> print(env.get('param3')) # 'value3'
|
597
|
-
|
598
|
-
Context protection:
|
599
|
-
|
600
|
-
.. code-block:: python
|
601
|
-
|
602
|
-
>>> import brainstate.environ as env
|
603
|
-
>>>
|
604
|
-
>>> env.set(protected='global_value')
|
605
|
-
>>>
|
606
|
-
>>> with env.context(protected='context_value'):
|
607
|
-
... # Cannot pop a key that's in active context
|
608
|
-
... try:
|
609
|
-
... env.pop('protected')
|
610
|
-
... except ValueError as e:
|
611
|
-
... print("Cannot pop key in active context")
|
612
|
-
|
613
|
-
Notes
|
614
|
-
-----
|
615
|
-
- This function only removes keys from global settings
|
616
|
-
- Keys that are currently overridden in active contexts cannot be popped
|
617
|
-
- Special keys like 'platform' and 'host_device_count' can be popped but
|
618
|
-
their system-level values remain accessible through get_platform() etc.
|
619
|
-
- Registered callbacks are NOT triggered when popping values
|
620
|
-
"""
|
621
|
-
# Check if key is currently in any active context
|
622
|
-
if key in _ENV_STATE.contexts and _ENV_STATE.contexts[key]:
|
623
|
-
raise ValueError(
|
624
|
-
f"Cannot pop key '{key}' while it is active in a context. "
|
625
|
-
f"The key is currently overridden in {len(_ENV_STATE.contexts[key])} context(s)."
|
626
|
-
)
|
627
|
-
|
628
|
-
# Check if key exists in global settings
|
629
|
-
if key in _ENV_STATE.settings:
|
630
|
-
# Remove and return the value
|
631
|
-
value = _ENV_STATE.settings.pop(key)
|
632
|
-
|
633
|
-
# Note: We don't trigger callbacks here as this is a removal operation
|
634
|
-
# If needed, users can register callbacks for removal separately
|
635
|
-
|
636
|
-
return value
|
637
|
-
|
638
|
-
# Key not found, handle default
|
639
|
-
if default is _NOT_PROVIDED:
|
640
|
-
raise KeyError(f"Key '{key}' not found in global environment settings.")
|
641
|
-
|
642
|
-
return default
|
643
|
-
|
644
|
-
|
645
|
-
def set(
|
646
|
-
platform: Optional[PlatformType] = None,
|
647
|
-
host_device_count: Optional[int] = None,
|
648
|
-
precision: Optional[PrecisionType] = None,
|
649
|
-
dt: Optional[float] = None,
|
650
|
-
**kwargs
|
651
|
-
) -> None:
|
652
|
-
"""
|
653
|
-
Set global environment configuration.
|
654
|
-
|
655
|
-
This function sets persistent global environment settings that remain
|
656
|
-
active until explicitly changed or the program terminates.
|
657
|
-
|
658
|
-
Parameters
|
659
|
-
----------
|
660
|
-
platform : str, optional
|
661
|
-
Computing platform ('cpu', 'gpu', or 'tpu').
|
662
|
-
host_device_count : int, optional
|
663
|
-
Number of host devices for parallel computation.
|
664
|
-
precision : int or str, optional
|
665
|
-
Numerical precision (8, 16, 32, 64, or 'bf16').
|
666
|
-
mode : Mode, optional
|
667
|
-
Computation mode instance.
|
668
|
-
dt : float, optional
|
669
|
-
Time step for numerical integration.
|
670
|
-
**kwargs
|
671
|
-
Additional custom environment parameters.
|
672
|
-
|
673
|
-
Raises
|
674
|
-
------
|
675
|
-
ValueError
|
676
|
-
If invalid platform or precision is specified.
|
677
|
-
TypeError
|
678
|
-
If mode is not a Mode instance.
|
679
|
-
|
680
|
-
Examples
|
681
|
-
--------
|
682
|
-
Basic configuration:
|
683
|
-
|
684
|
-
.. code-block:: python
|
685
|
-
|
686
|
-
>>> import brainstate as bs
|
687
|
-
>>> import brainstate.environ as env
|
688
|
-
>>>
|
689
|
-
>>> # Set multiple parameters
|
690
|
-
>>> env.set(
|
691
|
-
... precision=32,
|
692
|
-
... dt=0.01,
|
693
|
-
... mode=bs.mixin.Training(),
|
694
|
-
... debug=False
|
695
|
-
... )
|
696
|
-
>>>
|
697
|
-
>>> print(env.get('precision')) # 32
|
698
|
-
>>> print(env.get('dt')) # 0.01
|
699
|
-
|
700
|
-
Platform configuration:
|
701
|
-
|
702
|
-
.. code-block:: python
|
703
|
-
|
704
|
-
>>> import brainstate.environ as env
|
705
|
-
>>>
|
706
|
-
>>> # Configure for GPU computation
|
707
|
-
>>> env.set(platform='gpu', precision=16)
|
708
|
-
>>>
|
709
|
-
>>> # Configure for multi-core CPU
|
710
|
-
>>> env.set(platform='cpu', host_device_count=4)
|
711
|
-
|
712
|
-
Custom parameters:
|
713
|
-
|
714
|
-
.. code-block:: python
|
715
|
-
|
716
|
-
>>> import brainstate.environ as env
|
717
|
-
>>>
|
718
|
-
>>> # Set custom parameters
|
719
|
-
>>> env.set(
|
720
|
-
... experiment_name='test_001',
|
721
|
-
... random_seed=42,
|
722
|
-
... log_level='DEBUG'
|
723
|
-
... )
|
724
|
-
>>>
|
725
|
-
>>> # Retrieve custom parameters
|
726
|
-
>>> print(env.get('experiment_name')) # 'test_001'
|
727
|
-
|
728
|
-
Notes
|
729
|
-
-----
|
730
|
-
- Platform changes only take effect at program start
|
731
|
-
- Some JAX configurations require restart to take effect
|
732
|
-
- Custom parameters can be any hashable key-value pairs
|
733
|
-
"""
|
734
|
-
# Handle special parameters
|
735
|
-
if platform is not None:
|
736
|
-
set_platform(platform)
|
737
|
-
|
738
|
-
if host_device_count is not None:
|
739
|
-
set_host_device_count(host_device_count)
|
740
|
-
|
741
|
-
if precision is not None:
|
742
|
-
_validate_precision(precision)
|
743
|
-
_set_jax_precision(precision)
|
744
|
-
kwargs[PRECISION] = precision
|
745
|
-
|
746
|
-
if dt is not None:
|
747
|
-
if not u.math.isscalar(dt):
|
748
|
-
raise TypeError(f"'{DT}' must be a scalar number, got {type(dt)}")
|
749
|
-
kwargs[DT] = dt
|
750
|
-
|
751
|
-
# Update global settings
|
752
|
-
_ENV_STATE.settings.update(kwargs)
|
753
|
-
|
754
|
-
# Trigger registered callbacks
|
755
|
-
for key, value in kwargs.items():
|
756
|
-
if key in _ENV_STATE.functions:
|
757
|
-
try:
|
758
|
-
_ENV_STATE.functions[key](value)
|
759
|
-
except Exception as e:
|
760
|
-
warnings.warn(
|
761
|
-
f"Callback for '{key}' raised an exception: {e}",
|
762
|
-
RuntimeWarning
|
763
|
-
)
|
764
|
-
|
765
|
-
|
766
|
-
def get_dt() -> float:
|
767
|
-
"""
|
768
|
-
Get the current numerical integration time step.
|
769
|
-
|
770
|
-
Returns
|
771
|
-
-------
|
772
|
-
float
|
773
|
-
The time step value.
|
774
|
-
|
775
|
-
Raises
|
776
|
-
------
|
777
|
-
KeyError
|
778
|
-
If dt is not set.
|
779
|
-
|
780
|
-
Examples
|
781
|
-
--------
|
782
|
-
.. code-block:: python
|
783
|
-
|
784
|
-
>>> import brainstate.environ as env
|
785
|
-
>>>
|
786
|
-
>>> env.set(dt=0.01)
|
787
|
-
>>> dt = env.get_dt()
|
788
|
-
>>> print(f"Time step: {dt} ms") # Time step: 0.01 ms
|
789
|
-
>>>
|
790
|
-
>>> # Use in computation
|
791
|
-
>>> with env.context(dt=0.001):
|
792
|
-
... fine_dt = env.get_dt()
|
793
|
-
... print(f"Fine time step: {fine_dt}") # 0.001
|
794
|
-
"""
|
795
|
-
return get(DT)
|
796
|
-
|
797
|
-
|
798
|
-
def get_platform() -> PlatformType:
    """
    Report the platform of the first device returned by JAX.

    Returns
    -------
    str
        The ``platform`` attribute of ``devices()[0]``
        (documented values: 'cpu', 'gpu', or 'tpu').

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> backend = env.get_platform()
        >>> if backend == 'gpu':
        ...     print("GPU acceleration available")
        ... else:
        ...     print(f"Using {backend.upper()}")
    """
    first_device = devices()[0]
    return first_device.platform
|
822
|
-
|
823
|
-
|
824
|
-
def get_host_device_count() -> int:
    """
    Read the host device count requested through ``XLA_FLAGS``.

    The ``XLA_FLAGS`` environment variable is searched for
    ``--xla_force_host_platform_device_count=<digits>``.

    Returns
    -------
    int
        The number given by the first matching flag, or ``1`` when the
        variable is unset or does not contain the flag.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> n_devices = env.get_host_device_count()
        >>> if n_devices > 1:
        ...     print(f"Can use {n_devices} devices for parallel computation")
    """
    flags = os.getenv("XLA_FLAGS", "")
    found = re.search(r"--xla_force_host_platform_device_count=(\d+)", flags)
    if not found:
        return 1
    return int(found.group(1))
|
850
|
-
|
851
|
-
|
852
|
-
def set_platform(platform: PlatformType) -> None:
    """
    Select the computing platform for JAX.

    The name is checked against ``SUPPORTED_PLATFORMS``, written to the
    ``jax_platform_name`` configuration option, and finally passed to the
    callback registered under the ``PLATFORM`` key, if one exists.

    Parameters
    ----------
    platform : str
        Platform to use ('cpu', 'gpu', or 'tpu').

    Raises
    ------
    ValueError
        If ``platform`` is not listed in ``SUPPORTED_PLATFORMS``.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> env.set_platform('gpu')
        >>> print(env.get_platform())  # 'gpu'

    Notes
    -----
    Platform changes only take effect at program start. Changing the
    platform after JAX initialization may not have the expected effect.

    An exception raised by the registered callback is not caught here.
    """
    if platform not in SUPPORTED_PLATFORMS:
        raise ValueError(
            f"Platform must be one of {SUPPORTED_PLATFORMS}, got '{platform}'"
        )

    config.update("jax_platform_name", platform)

    # Notify the listener registered for platform changes, if any.
    listeners = _ENV_STATE.functions
    if PLATFORM in listeners:
        listeners[PLATFORM](platform)
|
893
|
-
|
894
|
-
|
895
|
-
def set_host_device_count(n: int) -> None:
    """
    Set the number of host (CPU) devices.

    The ``--xla_force_host_platform_device_count`` entry of the ``XLA_FLAGS``
    environment variable is replaced (or added) so that XLA treats CPU cores
    as separate devices, enabling parallel computation with ``jax.pmap`` on
    CPU. All other flags already present in ``XLA_FLAGS`` are kept. The
    callback registered under the ``HOST_DEVICE_COUNT`` key, if any, is then
    called with ``n``.

    Parameters
    ----------
    n : int
        Number of host devices to configure. Must be a positive ``int``;
        ``bool`` values are rejected.

    Raises
    ------
    ValueError
        If ``n`` is not a positive integer.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>> import jax
        >>>
        >>> # Configure 4 CPU devices
        >>> env.set_host_device_count(4)
        >>>
        >>> def parallel_fn(x):
        ...     return x * 2
        >>>
        >>> pmapped_fn = jax.pmap(parallel_fn)

    Warnings
    --------
    This setting only takes effect at program start. The effects of using
    xla_force_host_platform_device_count are not fully understood and may
    cause unexpected behavior.
    """
    # bool is a subclass of int, so without the explicit check ``True`` would
    # pass validation and be written as "...device_count=True".
    if isinstance(n, bool) or not isinstance(n, int) or n < 1:
        raise ValueError(f"Host device count must be a positive integer, got {n}")

    # Strip any previous occurrence of the flag; keep every other flag.
    xla_flags = os.getenv("XLA_FLAGS", "")
    other_flags = re.sub(
        r"--xla_force_host_platform_device_count=\S+",
        "",
        xla_flags
    ).split()

    os.environ["XLA_FLAGS"] = " ".join(
        [f"--xla_force_host_platform_device_count={n}"] + other_flags
    )

    # Trigger callbacks
    if HOST_DEVICE_COUNT in _ENV_STATE.functions:
        _ENV_STATE.functions[HOST_DEVICE_COUNT](n)
|
953
|
-
|
954
|
-
|
955
|
-
def set_precision(precision: PrecisionType) -> None:
    """
    Change the global numerical precision.

    The value is validated, applied to JAX (64-bit mode is enabled only for
    a precision of 64), stored under the ``PRECISION`` key of the global
    settings, and finally passed to the callback registered for
    ``PRECISION``, if one exists.

    Parameters
    ----------
    precision : int or str
        Precision to use (8, 16, 32, 64, or 'bf16').

    Raises
    ------
    ValueError
        If ``precision`` is not supported.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>> import jax.numpy as jnp
        >>>
        >>> env.set_precision(64)
        >>> x = jnp.array([1.0, 2.0, 3.0])
        >>> print(x.dtype)  # float64
        >>>
        >>> # bfloat16 for efficiency
        >>> env.set_precision('bf16')

    Notes
    -----
    An exception raised by the registered callback is not caught here.
    """
    _validate_precision(precision)
    _set_jax_precision(precision)

    state = _ENV_STATE
    state.settings[PRECISION] = precision

    # Notify the listener registered for precision changes, if any.
    if PRECISION in state.functions:
        state.functions[PRECISION](precision)
|
993
|
-
|
994
|
-
|
995
|
-
def get_precision() -> int:
    """
    Return the current numerical precision as a bit count.

    The raw setting is read through :func:`get` (falling back to
    ``DEFAULT_PRECISION``) and normalized to an integer.

    Returns
    -------
    int
        Precision in bits (8, 16, 32, or 64).

    Raises
    ------
    ValueError
        If the stored precision is neither a string nor an integer.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> env.set_precision(32)
        >>> print(env.get_precision())  # 32
        >>>
        >>> env.set_precision('bf16')
        >>> print(env.get_precision())  # 16

    Notes
    -----
    'bf16' (bfloat16) is reported as 16-bit precision.
    """
    raw = get(PRECISION, default=DEFAULT_PRECISION)

    if raw == 'bf16':
        return 16
    if isinstance(raw, str):
        # Numeric strings such as '32' are converted to their integer value.
        return int(raw)
    if isinstance(raw, int):
        return raw
    raise ValueError(f"Invalid precision type: {type(raw)}")
|
1032
|
-
|
1033
|
-
|
1034
|
-
def _validate_precision(precision: PrecisionType) -> None:
    """
    Check that ``precision`` is a supported value.

    A value is accepted when it is a member of ``SUPPORTED_PRECISIONS`` or
    when its string form equals the string form of one of the members
    (so ``'32'`` is accepted as well as ``32``).

    Raises
    ------
    ValueError
        If the value matches none of the supported precisions.
    """
    if precision in SUPPORTED_PRECISIONS:
        return

    accepted_spellings = [str(item) for item in SUPPORTED_PRECISIONS]
    if str(precision) in accepted_spellings:
        return

    raise ValueError(
        f"Precision must be one of {SUPPORTED_PRECISIONS}, got {precision}"
    )
|
1040
|
-
|
1041
|
-
|
1042
|
-
def _get_precision() -> PrecisionType:
    """
    Return the precision setting exactly as stored.

    Unlike :func:`get_precision`, the value is not converted to an integer,
    so ``'bf16'`` is returned unchanged. ``DEFAULT_PRECISION`` is passed to
    :func:`get` as the default.
    """
    raw_precision = get(PRECISION, default=DEFAULT_PRECISION)
    return raw_precision
|
1045
|
-
|
1046
|
-
|
1047
|
-
def _set_jax_precision(precision: PrecisionType) -> None:
    """
    Apply a precision setting to the JAX configuration.

    The ``jax_enable_x64`` option is switched on when ``precision`` is
    ``64`` or ``'64'`` and switched off for every other value.
    """
    wants_x64 = precision in (64, '64')
    config.update("jax_enable_x64", wants_x64)
|
1054
|
-
|
1055
|
-
|
1056
|
-
@functools.lru_cache(maxsize=16)
def _get_uint(precision: PrecisionType) -> DTypeLike:
    """
    Map a precision setting to the NumPy unsigned integer type of that width.

    Integer precisions and their string spellings are both accepted, and
    ``'bf16'`` is treated as 16 bits. Results are memoized by ``lru_cache``.

    Raises
    ------
    ValueError
        If ``precision`` is not one of the recognized values.
    """
    by_precision = {
        64: np.uint64, '64': np.uint64,
        32: np.uint32, '32': np.uint32,
        16: np.uint16, '16': np.uint16, 'bf16': np.uint16,
        8: np.uint8, '8': np.uint8,
    }
    dtype = by_precision.get(precision)
    if dtype is None:
        raise ValueError(f"Unsupported precision: {precision}")
    return dtype
|
1069
|
-
|
1070
|
-
|
1071
|
-
@functools.lru_cache(maxsize=16)
def _get_int(precision: PrecisionType) -> DTypeLike:
    """
    Map a precision setting to the NumPy signed integer type of that width.

    Integer precisions and their string spellings are both accepted, and
    ``'bf16'`` is treated as 16 bits. Results are memoized by ``lru_cache``.

    Raises
    ------
    ValueError
        If ``precision`` is not one of the recognized values.
    """
    by_precision = {
        64: np.int64, '64': np.int64,
        32: np.int32, '32': np.int32,
        16: np.int16, '16': np.int16, 'bf16': np.int16,
        8: np.int8, '8': np.int8,
    }
    dtype = by_precision.get(precision)
    if dtype is None:
        raise ValueError(f"Unsupported precision: {precision}")
    return dtype
|
1084
|
-
|
1085
|
-
|
1086
|
-
@functools.lru_cache(maxsize=16)
def _get_float(precision: PrecisionType) -> DTypeLike:
    """
    Map a precision setting to a floating-point type.

    64, 32 and 16 bit precisions (as integers or strings) map to the NumPy
    types of that width, ``'bf16'`` maps to ``jnp.bfloat16`` and 8 bit maps
    to ``jnp.float8_e5m2``. Results are memoized by ``lru_cache``.

    Raises
    ------
    ValueError
        If ``precision`` is not one of the recognized values.
    """
    numpy_floats = {
        64: np.float64, '64': np.float64,
        32: np.float32, '32': np.float32,
        16: np.float16, '16': np.float16,
    }
    dtype = numpy_floats.get(precision)
    if dtype is not None:
        return dtype

    # The JAX-only types are looked up lazily, only when actually requested.
    if precision == 'bf16':
        return jnp.bfloat16
    if precision in (8, '8'):
        return jnp.float8_e5m2
    raise ValueError(f"Unsupported precision: {precision}")
|
1101
|
-
|
1102
|
-
|
1103
|
-
@functools.lru_cache(maxsize=16)
def _get_complex(precision: PrecisionType) -> DTypeLike:
    """
    Map a precision setting to a NumPy complex type.

    64-bit precision maps to ``np.complex128``; every other recognized
    precision (32, 16, 'bf16', 8 and their string spellings) maps to
    ``np.complex64``. Results are memoized by ``lru_cache``.

    Raises
    ------
    ValueError
        If ``precision`` is not one of the recognized values.
    """
    by_precision = {64: np.complex128, '64': np.complex128}
    for narrower in (32, '32', 16, '16', 'bf16', 8, '8'):
        by_precision[narrower] = np.complex64

    dtype = by_precision.get(precision)
    if dtype is None:
        raise ValueError(f"Unsupported precision: {precision}")
    return dtype
|
1112
|
-
|
1113
|
-
|
1114
|
-
def dftype() -> DTypeLike:
    """
    Return the default floating-point data type.

    The type is derived from the precision that is currently in effect, so
    the result follows changes made through ``set`` or ``context``.

    Returns
    -------
    DTypeLike
        Default floating-point data type.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>> import jax.numpy as jnp
        >>>
        >>> env.set(precision=32)
        >>> x = jnp.zeros(10, dtype=env.dftype())
        >>> print(x.dtype)  # float32
        >>>
        >>> with env.context(precision=64):
        ...     y = jnp.ones(5, dtype=env.dftype())
        ...     print(y.dtype)  # float64
        >>>
        >>> env.set(precision='bf16')
        >>> z = jnp.array([1, 2, 3], dtype=env.dftype())
        >>> print(z.dtype)  # bfloat16

    See Also
    --------
    ditype : Default integer type
    dutype : Default unsigned integer type
    dctype : Default complex type
    """
    active_precision = _get_precision()
    return _get_float(active_precision)
|
1155
|
-
|
1156
|
-
|
1157
|
-
def ditype() -> DTypeLike:
    """
    Return the default signed integer data type.

    The type is derived from the precision that is currently in effect.

    Returns
    -------
    DTypeLike
        Default integer data type.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>> import jax.numpy as jnp
        >>>
        >>> env.set(precision=32)
        >>> indices = jnp.arange(10, dtype=env.ditype())
        >>> print(indices.dtype)  # int32
        >>>
        >>> with env.context(precision=64):
        ...     big_indices = jnp.arange(1000, dtype=env.ditype())
        ...     print(big_indices.dtype)  # int64

    See Also
    --------
    dftype : Default floating-point type
    dutype : Default unsigned integer type
    """
    active_precision = _get_precision()
    return _get_int(active_precision)
|
1192
|
-
|
1193
|
-
|
1194
|
-
def dutype() -> DTypeLike:
    """
    Return the default unsigned integer data type.

    The type is derived from the precision that is currently in effect.

    Returns
    -------
    DTypeLike
        Default unsigned integer data type.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>> import jax.numpy as jnp
        >>>
        >>> env.set(precision=32)
        >>> counts = jnp.array([10, 20, 30], dtype=env.dutype())
        >>> print(counts.dtype)  # uint32
        >>>
        >>> with env.context(precision=16):
        ...     small_counts = jnp.array([1, 2, 3], dtype=env.dutype())
        ...     print(small_counts.dtype)  # uint16

    See Also
    --------
    ditype : Default signed integer type
    """
    active_precision = _get_precision()
    return _get_uint(active_precision)
|
1228
|
-
|
1229
|
-
|
1230
|
-
def dctype() -> DTypeLike:
    """
    Return the default complex data type.

    The type is derived from the precision that is currently in effect.

    Returns
    -------
    DTypeLike
        Default complex data type.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>> import jax.numpy as jnp
        >>>
        >>> env.set(precision=32)
        >>> z = jnp.array([1+2j, 3+4j], dtype=env.dctype())
        >>> print(z.dtype)  # complex64
        >>>
        >>> with env.context(precision=64):
        ...     w = jnp.array([5+6j], dtype=env.dctype())
        ...     print(w.dtype)  # complex128

    Notes
    -----
    Complex128 is only available with 64-bit precision.
    All other precisions use complex64.
    """
    active_precision = _get_precision()
    return _get_complex(active_precision)
|
1265
|
-
|
1266
|
-
|
1267
|
-
def tolerance() -> jnp.ndarray:
    """
    Return a numerical tolerance suited to the current precision.

    Returns
    -------
    jnp.ndarray
        Tolerance value as a scalar array.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>> import jax.numpy as jnp
        >>>
        >>> env.set(precision=64)
        >>> print(env.tolerance())  # 1e-12
        >>>
        >>> env.set(precision=32)
        >>> print(env.tolerance())  # 1e-5
        >>>
        >>> def are_close(a, b):
        ...     return jnp.abs(a - b) < env.tolerance()

    Notes
    -----
    Tolerance values and the dtype of the returned array:

    - 64-bit: 1e-12 (float64)
    - 32-bit: 1e-5 (float32)
    - 16-bit and below: 1e-2 (float16)
    """
    bits = get_precision()

    # (value, dtype) per bit width; anything else uses the coarse fallback.
    known = {
        64: (1e-12, np.float64),
        32: (1e-5, np.float32),
    }
    value, dtype = known.get(bits, (1e-2, np.float16))
    return jnp.array(value, dtype=dtype)
|
1314
|
-
|
1315
|
-
|
1316
|
-
def register_default_behavior(
    key: str,
    behavior: Callable[[Any], None],
    replace_if_exist: bool = False
) -> None:
    """
    Register a callback for changes of an environment parameter.

    The callback is stored under ``key`` and is called with the new value
    when that parameter is changed, for example through ``set``.

    Parameters
    ----------
    key : str
        Environment parameter key to monitor.
    behavior : Callable[[Any], None]
        Callback function that receives the new value.
    replace_if_exist : bool, default=False
        Whether to replace an existing callback for this key.

    Raises
    ------
    TypeError
        If ``key`` is not a string or ``behavior`` is not callable.
    ValueError
        If ``key`` already has a registered behavior and
        ``replace_if_exist`` is False.

    Examples
    --------
    Basic callback registration:

    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> def on_dt_change(new_dt):
        ...     print(f"Time step changed to: {new_dt}")
        >>>
        >>> env.register_default_behavior('dt', on_dt_change)
        >>> env.set(dt=0.01)  # Prints: Time step changed to: 0.01

    Replacing an existing behavior:

    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> def old_behavior(value):
        ...     print(f"Old: {value}")
        >>>
        >>> def new_behavior(value):
        ...     print(f"New: {value}")
        >>>
        >>> env.register_default_behavior('key', old_behavior)
        >>> env.register_default_behavior('key', new_behavior, replace_if_exist=True)
        >>> env.set(key='test')  # Prints: New: test

    Notes
    -----
    When a callback is triggered by ``set``, an exception it raises is
    reported as a ``RuntimeWarning`` instead of being propagated, and the
    new value has already been stored at that point. A callback is
    therefore not a reliable way to reject a value passed to ``set``.

    See Also
    --------
    unregister_default_behavior : Remove registered callbacks
    list_registered_behaviors : List all registered callbacks
    """
    if not isinstance(key, str):
        raise TypeError(f"Key must be a string, got {type(key)}")
    if not callable(behavior):
        raise TypeError(f"Behavior must be callable, got {type(behavior)}")

    registry = _ENV_STATE.functions
    already_registered = key in registry
    if already_registered and not replace_if_exist:
        raise ValueError(
            f"Behavior for key '{key}' already registered. "
            f"Use replace_if_exist=True to override."
        )

    registry[key] = behavior
|
1419
|
-
|
1420
|
-
|
1421
|
-
def unregister_default_behavior(key: str) -> bool:
    """
    Remove the callback registered for an environment parameter.

    Parameters
    ----------
    key : str
        Environment parameter key.

    Returns
    -------
    bool
        True if a callback was removed, False if no callback existed.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> def callback(value):
        ...     print(f"Value: {value}")
        >>>
        >>> env.register_default_behavior('param', callback)
        >>> print(env.unregister_default_behavior('param'))  # True
        >>>
        >>> env.set(param='test')  # No output
        >>>
        >>> print(env.unregister_default_behavior('nonexistent'))  # False
    """
    try:
        del _ENV_STATE.functions[key]
    except KeyError:
        # Nothing was registered under this key.
        return False
    return True
|
1462
|
-
|
1463
|
-
|
1464
|
-
def list_registered_behaviors() -> List[str]:
    """
    Return the keys that currently have a registered callback.

    Returns
    -------
    list of str
        A new list with the keys of all registered behavior callbacks.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> env.register_default_behavior('param1', lambda x: None)
        >>> env.register_default_behavior('param2', lambda x: None)
        >>>
        >>> behaviors = env.list_registered_behaviors()
        >>> print(f"Registered: {behaviors}")  # ['param1', 'param2']
        >>>
        >>> if 'dt' in behaviors:
        ...     print("dt has a registered callback")
    """
    registered_keys = [name for name in _ENV_STATE.functions]
    return registered_keys
|
1492
|
-
|
1493
|
-
|
1494
|
-
# Initialize default precision on module load
|
1495
|
-
# Runs at import time: the default precision goes through set(), so it is
# validated, applied to the JAX configuration and stored in the settings.
set(precision=DEFAULT_PRECISION)
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""
|
17
|
+
Environment configuration and context management for BrainState.
|
18
|
+
|
19
|
+
This module provides comprehensive functionality for managing computational
|
20
|
+
environments, including platform selection, precision control, mode setting,
|
21
|
+
and context-based configuration management. It enables flexible configuration
|
22
|
+
of JAX-based computations with thread-safe context switching.
|
23
|
+
|
24
|
+
The module supports:
|
25
|
+
- Platform configuration (CPU, GPU, TPU)
|
26
|
+
- Precision control (8, 16, 32, 64 bit and bfloat16)
|
27
|
+
- Computation mode management
|
28
|
+
- Context-based temporary settings
|
29
|
+
- Default data type management
|
30
|
+
- Custom behavior registration
|
31
|
+
|
32
|
+
Examples
|
33
|
+
--------
|
34
|
+
Global environment configuration:
|
35
|
+
|
36
|
+
.. code-block:: python
|
37
|
+
|
38
|
+
>>> import brainstate as bs
|
39
|
+
>>> import brainstate.environ as env
|
40
|
+
>>>
|
41
|
+
>>> # Set global precision to 32-bit
|
42
|
+
>>> env.set(precision=32, dt=0.01, mode=bs.mixin.Training())
|
43
|
+
>>>
|
44
|
+
>>> # Get current settings
|
45
|
+
>>> print(env.get('precision')) # 32
|
46
|
+
>>> print(env.get('dt')) # 0.01
|
47
|
+
|
48
|
+
Context-based temporary settings:
|
49
|
+
|
50
|
+
.. code-block:: python
|
51
|
+
|
52
|
+
>>> import brainstate.environ as env
|
53
|
+
>>>
|
54
|
+
>>> # Temporarily change precision
|
55
|
+
>>> with env.context(precision=64, dt=0.001):
|
56
|
+
... high_precision_result = compute_something()
|
57
|
+
... print(env.get('precision')) # 64
|
58
|
+
>>> print(env.get('precision')) # Back to 32
|
59
|
+
"""
|
60
|
+
|
61
|
+
import contextlib
|
62
|
+
import dataclasses
|
63
|
+
import functools
|
64
|
+
import os
|
65
|
+
import re
|
66
|
+
import threading
|
67
|
+
import warnings
|
68
|
+
from collections import defaultdict
|
69
|
+
from typing import Any, Callable, Dict, Hashable, Optional, Union, ContextManager, List
|
70
|
+
|
71
|
+
import brainunit as u
|
72
|
+
import numpy as np
|
73
|
+
from jax import config, devices, numpy as jnp
|
74
|
+
from jax.typing import DTypeLike
|
75
|
+
|
76
|
+
# Public API of this module, grouped by purpose.
__all__ = [
    # Core environment management
    'set', 'get', 'all', 'pop', 'context', 'reset',

    # Platform and device management
    'set_platform', 'get_platform',
    'set_host_device_count', 'get_host_device_count',

    # Precision and data type management
    'set_precision', 'get_precision',
    'dftype', 'ditype', 'dutype', 'dctype',

    # Mode and computation settings
    'get_dt',

    # Utility functions
    'tolerance',
    'register_default_behavior',
    'unregister_default_behavior',
    'list_registered_behaviors',

    # Constants
    'DEFAULT_PRECISION', 'SUPPORTED_PLATFORMS', 'SUPPORTED_PRECISIONS',

    # Names of the built-in environment keys
    'I', 'T', 'DT', 'PRECISION', 'PLATFORM',
    'HOST_DEVICE_COUNT', 'JIT_ERROR_CHECK', 'FIT',
]
|
123
|
+
|
124
|
+
# Type aliases.
# NOTE: the name ``T`` is taken by the 't' environment key below, so it is
# not available for a TypeVar in this module.
PrecisionType = Union[int, str]
PlatformType = str

# Keys under which the built-in environment settings are stored.
# Index of the current computation
I = 'i'
# Current time of the computation
T = 't'
# Time step for numerical integration
DT = 'dt'
# Numerical precision
PRECISION = 'precision'
# Computing platform
PLATFORM = 'platform'
# Number of host devices
HOST_DEVICE_COUNT = 'host_device_count'
# JIT error checking flag
JIT_ERROR_CHECK = 'jit_error_check'
# Model fitting flag
FIT = 'fit'

# Defaults and accepted values.
DEFAULT_PRECISION = 32
SUPPORTED_PLATFORMS = ('cpu', 'gpu', 'tpu')
SUPPORTED_PRECISIONS = (8, 16, 32, 64, 'bf16')

# Unique marker used to tell "argument not passed" apart from ``None``.
_NOT_PROVIDED = object()
|
146
|
+
|
147
|
+
|
148
|
+
@dataclasses.dataclass
class EnvironmentState(threading.local):
    """
    Thread-local container for the environment configuration.

    The class derives from ``threading.local``, so every thread that touches
    an instance sees its own attribute values, each initialized from the
    field defaults below.

    Attributes
    ----------
    settings : Dict[Hashable, Any]
        Global default environment settings. ``PRECISION`` is filled in
        with ``DEFAULT_PRECISION`` when it is missing.
    contexts : defaultdict[Hashable, List[Any]]
        Stack of context-specific settings for nested contexts.
    functions : Dict[Hashable, Callable]
        Registered callback functions for environment changes.
    locks : Dict[str, threading.Lock]
        Locks created on first access of a key.

    Notes
    -----
    NOTE(review): ``locks`` is a per-thread attribute like the others, so
    the lock objects it hands out are not shared between threads. Confirm
    this is intended before relying on them for cross-thread
    synchronization.
    """
    settings: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
    contexts: defaultdict[Hashable, List[Any]] = dataclasses.field(
        default_factory=functools.partial(defaultdict, list)
    )
    functions: Dict[Hashable, Callable] = dataclasses.field(default_factory=dict)
    locks: Dict[str, threading.Lock] = dataclasses.field(
        default_factory=functools.partial(defaultdict, threading.Lock)
    )

    def __post_init__(self):
        """Make sure a precision entry exists in ``settings``."""
        # Keeps a precision passed in by the caller, otherwise uses the default.
        self.settings.setdefault(PRECISION, DEFAULT_PRECISION)
|
177
|
+
|
178
|
+
|
179
|
+
# Global environment state
|
180
|
+
# Module-level instance that the functions in this module read and update;
# reset() replaces it with a fresh EnvironmentState.
_ENV_STATE = EnvironmentState()
|
181
|
+
|
182
|
+
|
183
|
+
def reset() -> None:
    """
    Restore the environment to its default settings.

    The module-level state object is replaced by a fresh
    ``EnvironmentState``, which discards stored settings, context stacks
    and registered callbacks. JAX is then switched back to
    ``DEFAULT_PRECISION`` and a ``UserWarning`` is issued to make the
    reset visible.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> env.set(dt=0.1, custom_param='value')
        >>> print(env.get('custom_param'))  # 'value'
        >>>
        >>> env.reset()
        >>> print(env.get('custom_param', default=None))  # None

    Notes
    -----
    This operation cannot be undone. All custom settings will be lost.
    """
    global _ENV_STATE

    # Start from a brand-new state object instead of clearing the old one.
    _ENV_STATE = EnvironmentState()

    # Bring the JAX configuration back in line with the default precision.
    _set_jax_precision(DEFAULT_PRECISION)

    notice = (
        "Environment has been reset to default settings. "
        "All custom configurations have been cleared."
    )
    warnings.warn(notice, UserWarning)
|
218
|
+
|
219
|
+
|
220
|
+
@contextlib.contextmanager
def context(**kwargs) -> ContextManager[Dict[str, Any]]:
    """
    Temporarily override environment settings inside a ``with`` block.

    Each keyword is pushed onto that key's context stack on entry and popped
    again on exit, including when the body raises. Registered callbacks are
    invoked with the new value on entry and, on exit, with whatever value
    ``get`` resolves afterwards (skipped when that value is ``None``).
    Exceptions raised by callbacks are reported as ``RuntimeWarning`` rather
    than propagated.

    Parameters
    ----------
    **kwargs
        Settings to apply for the duration of the block, for example
        ``precision`` (validated and also applied to JAX), ``dt``, ``mode``
        or any custom key.

    Yields
    ------
    dict
        The result of ``all()`` taken right after the overrides were applied.

    Raises
    ------
    ValueError
        If ``platform`` or ``host_device_count`` is passed (these can only be
        configured globally), or if ``precision`` fails validation.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> env.set(precision=32)
        >>> with env.context(precision=64):
        ...     print(env.get('precision'))  # 64
        >>> print(env.get('precision'))  # 32

    Overrides nest; the innermost value wins:

    .. code-block:: python

        >>> with env.context(dt=0.1):
        ...     with env.context(dt=0.01):
        ...         print(env.get('dt'))  # 0.01
        ...     print(env.get('dt'))  # 0.1

    Notes
    -----
    - Pushes and pops of each key's stack are guarded by that key's lock.
    - Callbacks run outside the lock, because ``get`` acquires the same lock.
    """
    # Process-wide options cannot be scoped to a block.
    for restricted, setter in (
        (PLATFORM, 'set_platform'),
        (HOST_DEVICE_COUNT, 'set_host_device_count'),
    ):
        if restricted in kwargs:
            raise ValueError(
                f"Cannot set '{restricted}' in context. "
                f"Use {setter}() or set() for global configuration."
            )

    # Remember the precision in force so JAX can be switched back on exit.
    saved_precision = None
    if PRECISION in kwargs:
        saved_precision = _get_precision()
        _validate_precision(kwargs[PRECISION])
        _set_jax_precision(kwargs[PRECISION])

    try:
        for name, new_value in kwargs.items():
            with _ENV_STATE.locks[name]:
                _ENV_STATE.contexts[name].append(new_value)

            if name not in _ENV_STATE.functions:
                continue

            # Best effort: a failing callback must not abort the context.
            try:
                _ENV_STATE.functions[name](new_value)
            except Exception as exc:
                warnings.warn(
                    f"Callback for '{name}' raised an exception: {exc}",
                    RuntimeWarning
                )

        # Hand the caller a snapshot of the effective settings.
        yield all()

    finally:
        for name in kwargs:
            with _ENV_STATE.locks[name]:
                stack = _ENV_STATE.contexts[name]
                if stack:
                    stack.pop()

            if name not in _ENV_STATE.functions:
                continue

            # Re-notify the callback with the value that is visible again.
            try:
                restored = get(name, default=None)
                if restored is not None:
                    _ENV_STATE.functions[name](restored)
            except Exception as exc:
                warnings.warn(
                    f"Callback restoration for '{name}' raised: {exc}",
                    RuntimeWarning
                )

        if saved_precision is not None:
            _set_jax_precision(saved_precision)
|
371
|
+
|
372
|
+
|
373
|
+
def get(key: str, default: Any = _NOT_PROVIDED, desc: Optional[str] = None) -> Any:
    """
    Look up a value in the current environment.

    Resolution order: the two system-level keys are answered by
    ``get_platform()`` / ``get_host_device_count()``; otherwise the most
    recent value on the key's context stack wins, then the global settings,
    and finally ``default``.

    Parameters
    ----------
    key : str
        Name of the setting to retrieve.
    default : Any, optional
        Value returned when the key is set neither in a context nor
        globally. When omitted, a missing key raises ``KeyError``.
    desc : str, optional
        Human-readable description appended to the error message.

    Returns
    -------
    Any
        The resolved value.

    Raises
    ------
    KeyError
        If the key is not found and no ``default`` was given.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> env.set(temperature=1.0)
        >>> with env.context(temperature=0.5):
        ...     print(env.get('temperature'))  # 0.5
        >>> print(env.get('temperature'))  # 1.0
        >>> print(env.get('batch_size', default=32))  # 32
    """
    # System-level keys bypass the stored settings entirely.
    if key == PLATFORM:
        return get_platform()
    if key == HOST_DEVICE_COUNT:
        return get_host_device_count()

    # Innermost active context takes precedence.
    with _ENV_STATE.locks[key]:
        stack = _ENV_STATE.contexts.get(key)
        if stack:
            return stack[-1]

    # Fall back to the global settings.
    try:
        return _ENV_STATE.settings[key]
    except KeyError:
        pass

    if default is not _NOT_PROVIDED:
        return default

    message = f"Key '{key}' not found in environment."
    if desc:
        message += f" Description: {desc}"
    message += (
        f"\nSet it using:\n"
        f" - env.set({key}=value) for global setting\n"
        f" - env.context({key}=value) for temporary setting"
    )
    raise KeyError(message)
|
470
|
+
|
471
|
+
|
472
|
+
def all() -> Dict[str, Any]:
    """
    Return a snapshot of every setting that is currently in effect.

    The result starts from the global settings; any key with a non-empty
    context stack is then overridden by the top of that stack.

    Returns
    -------
    dict
        A new dictionary. Mutating it does not change the environment.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> env.set(precision=32, dt=0.01)
        >>> with env.context(precision=64, new_param='test'):
        ...     snapshot = env.all()
        ...     print(snapshot['precision'])  # 64
        ...     print(snapshot['new_param'])  # 'test'
    """
    # Copy so callers never get a handle on the live settings mapping.
    snapshot = dict(_ENV_STATE.settings)

    # Layer the innermost context value of each active key on top.
    snapshot.update(
        (name, stack[-1])
        for name, stack in _ENV_STATE.contexts.items()
        if stack
    )
    return snapshot
|
520
|
+
|
521
|
+
|
522
|
+
def pop(key: str, default: Any = _NOT_PROVIDED) -> Any:
    """
    Remove a key from the global settings and return its value.

    Only the global settings are touched. A key that is currently overridden
    by at least one active context cannot be popped. Registered callbacks are
    not invoked.

    Parameters
    ----------
    key : str
        Name of the setting to remove.
    default : Any, optional
        Value returned when the key is absent from the global settings.
        When omitted, a missing key raises ``KeyError``.

    Returns
    -------
    Any
        The removed value, or ``default`` when the key was not set.

    Raises
    ------
    ValueError
        If the key has a non-empty context stack.
    KeyError
        If the key is absent and no ``default`` was given.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> env.set(temp_param='temporary')
        >>> print(env.pop('temp_param'))  # 'temporary'
        >>> print(env.pop('nonexistent', default='fallback'))  # 'fallback'
        >>>
        >>> env.set(protected='global_value')
        >>> with env.context(protected='context_value'):
        ...     env.pop('protected')  # raises ValueError
    """
    # Refuse to remove a key while some context still overrides it.
    active = _ENV_STATE.contexts.get(key)
    if active:
        raise ValueError(
            f"Cannot pop key '{key}' while it is active in a context. "
            f"The key is currently overridden in {len(active)} context(s)."
        )

    # Removal deliberately does not fire callbacks.
    try:
        return _ENV_STATE.settings.pop(key)
    except KeyError:
        pass

    if default is _NOT_PROVIDED:
        raise KeyError(f"Key '{key}' not found in global environment settings.")

    return default
|
643
|
+
|
644
|
+
|
645
|
+
def set(
    platform: Optional[PlatformType] = None,
    host_device_count: Optional[int] = None,
    precision: Optional[PrecisionType] = None,
    dt: Optional[float] = None,
    **kwargs
) -> None:
    """
    Apply persistent, global environment settings.

    ``platform`` and ``host_device_count`` are forwarded to
    ``set_platform`` / ``set_host_device_count``. ``precision`` is validated,
    applied to JAX and stored; ``dt`` is checked to be a scalar and stored.
    All remaining keyword arguments are stored as-is. After the settings are
    updated, the registered callback of every stored key is called with the
    new value; a failing callback produces a ``RuntimeWarning``.

    Parameters
    ----------
    platform : str, optional
        Computing platform, validated by ``set_platform``.
    host_device_count : int, optional
        Number of host devices, validated by ``set_host_device_count``.
    precision : int or str, optional
        Numerical precision (8, 16, 32, 64, or 'bf16').
    dt : float, optional
        Time step for numerical integration.
    **kwargs
        Additional settings (for example ``mode``) stored without validation.

    Raises
    ------
    ValueError
        If the platform, host device count or precision is rejected.
    TypeError
        If ``dt`` is not a scalar.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> env.set(precision=32, dt=0.01, experiment_name='test_001')
        >>> print(env.get('precision'))  # 32
        >>> print(env.get('experiment_name'))  # 'test_001'
    """
    # Parameters with dedicated setters are delegated first.
    if platform is not None:
        set_platform(platform)

    if host_device_count is not None:
        set_host_device_count(host_device_count)

    if precision is not None:
        _validate_precision(precision)
        _set_jax_precision(precision)
        kwargs[PRECISION] = precision

    if dt is not None:
        if not u.math.isscalar(dt):
            raise TypeError(f"'{DT}' must be a scalar number, got {type(dt)}")
        kwargs[DT] = dt

    # Store everything in one step, then notify listeners.
    _ENV_STATE.settings.update(kwargs)

    for name, new_value in kwargs.items():
        if name not in _ENV_STATE.functions:
            continue
        try:
            _ENV_STATE.functions[name](new_value)
        except Exception as exc:
            warnings.warn(
                f"Callback for '{name}' raised an exception: {exc}",
                RuntimeWarning
            )
|
764
|
+
|
765
|
+
|
766
|
+
def get_dt() -> float:
    """
    Return the numerical integration time step currently in effect.

    This is ``get(DT)``, so an active ``context(dt=...)`` takes precedence
    over the global setting.

    Returns
    -------
    float
        The time step value.

    Raises
    ------
    KeyError
        If no time step has been set.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> env.set(dt=0.01)
        >>> print(env.get_dt())  # 0.01
        >>> with env.context(dt=0.001):
        ...     print(env.get_dt())  # 0.001
    """
    step = get(DT)
    return step
|
796
|
+
|
797
|
+
|
798
|
+
def get_platform() -> PlatformType:
    """
    Return the platform name of the first available device.

    Returns
    -------
    str
        The ``platform`` attribute of the first entry returned by
        ``devices()``.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> print(f"Running on: {env.get_platform()}")
    """
    first_device = devices()[0]
    return first_device.platform
|
822
|
+
|
823
|
+
|
824
|
+
def get_host_device_count() -> int:
    """
    Return the host device count requested through ``XLA_FLAGS``.

    The ``XLA_FLAGS`` environment variable is searched for
    ``--xla_force_host_platform_device_count=<digits>``; the first match
    is used.

    Returns
    -------
    int
        The configured count, or ``1`` when the flag (or the variable
        itself) is absent.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> n_devices = env.get_host_device_count()
        >>> print(f"Host devices: {n_devices}")
    """
    flags = os.environ.get("XLA_FLAGS", "")
    found = re.search(r"--xla_force_host_platform_device_count=(\d+)", flags)
    if found is None:
        # Flag not present: XLA exposes a single host device.
        return 1
    return int(found.group(1))
|
850
|
+
|
851
|
+
|
852
|
+
def set_platform(platform: PlatformType) -> None:
    """
    Select the computing platform used by JAX.

    The name is checked against ``SUPPORTED_PLATFORMS``, written to the
    ``jax_platform_name`` configuration option and, when a callback is
    registered for the platform key, passed to that callback.

    Parameters
    ----------
    platform : str
        Platform to use; must be a member of ``SUPPORTED_PLATFORMS``.

    Raises
    ------
    ValueError
        If ``platform`` is not supported.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> env.set_platform('cpu')

    Notes
    -----
    Changing the platform after JAX has initialised may not have the
    expected effect; configure it at program start.
    """
    if platform not in SUPPORTED_PLATFORMS:
        raise ValueError(
            f"Platform must be one of {SUPPORTED_PLATFORMS}, got '{platform}'"
        )

    config.update("jax_platform_name", platform)

    # Nothing more to do unless someone listens for platform changes.
    if PLATFORM not in _ENV_STATE.functions:
        return
    _ENV_STATE.functions[PLATFORM](platform)
|
893
|
+
|
894
|
+
|
895
|
+
def set_host_device_count(n: int) -> None:
    """
    Set the number of host (CPU) devices exposed by XLA.

    Any existing ``--xla_force_host_platform_device_count=...`` entry is
    removed from the ``XLA_FLAGS`` environment variable and a fresh one for
    ``n`` is put in front of the remaining flags. If a callback is registered
    for the host-device-count key it is then called with ``n``.

    Parameters
    ----------
    n : int
        Number of host devices to configure; must be a positive ``int``.
        ``bool`` values are rejected even though ``bool`` subclasses ``int``.

    Raises
    ------
    ValueError
        If ``n`` is not a positive integer.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> env.set_host_device_count(4)
        >>> print(env.get_host_device_count())  # 4

    Warnings
    --------
    This setting only takes effect at program start. The effects of using
    xla_force_host_platform_device_count are not fully understood and may
    cause unexpected behavior.
    """
    # ``True`` would pass a plain ``isinstance(n, int)`` check and end up in
    # XLA_FLAGS as "...device_count=True", so booleans are excluded explicitly.
    if isinstance(n, bool) or not isinstance(n, int) or n < 1:
        raise ValueError(f"Host device count must be a positive integer, got {n}")

    # Drop any previous occurrence of the flag and keep the other flags.
    current_flags = os.getenv("XLA_FLAGS", "")
    other_flags = re.sub(
        r"--xla_force_host_platform_device_count=\S+",
        "",
        current_flags
    ).split()

    os.environ["XLA_FLAGS"] = " ".join(
        [f"--xla_force_host_platform_device_count={n}"] + other_flags
    )

    # Trigger callbacks
    if HOST_DEVICE_COUNT in _ENV_STATE.functions:
        _ENV_STATE.functions[HOST_DEVICE_COUNT](n)
|
953
|
+
|
954
|
+
|
955
|
+
def set_precision(precision: PrecisionType) -> None:
    """
    Set the global numerical precision.

    The value is validated, applied to JAX's 64-bit switch and stored in the
    global settings. A callback registered for the precision key, if any, is
    then called with the new value.

    Parameters
    ----------
    precision : int or str
        Precision to use (8, 16, 32, 64, or 'bf16').

    Raises
    ------
    ValueError
        If ``precision`` is not supported.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> env.set_precision(64)
        >>> print(env.dftype())  # float64
        >>> env.set_precision('bf16')
    """
    _validate_precision(precision)
    _set_jax_precision(precision)
    _ENV_STATE.settings[PRECISION] = precision

    # Notify a listener only when one is registered.
    if PRECISION not in _ENV_STATE.functions:
        return
    _ENV_STATE.functions[PRECISION](precision)
|
993
|
+
|
994
|
+
|
995
|
+
def get_precision() -> int:
    """
    Return the current numerical precision as a bit width.

    Returns
    -------
    int
        ``16`` for ``'bf16'``; the integer value of any other string
        setting; the setting itself when it is already an ``int``.

    Raises
    ------
    ValueError
        If the stored precision is neither a string nor an integer.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> env.set_precision(32)
        >>> print(env.get_precision())  # 32
        >>> env.set_precision('bf16')
        >>> print(env.get_precision())  # 16
    """
    raw = get(PRECISION, default=DEFAULT_PRECISION)

    # bfloat16 is reported by its storage width.
    if raw == 'bf16':
        return 16
    if isinstance(raw, str):
        return int(raw)
    if isinstance(raw, int):
        return raw
    raise ValueError(f"Invalid precision type: {type(raw)}")
|
1032
|
+
|
1033
|
+
|
1034
|
+
def _validate_precision(precision: PrecisionType) -> None:
    """
    Raise ``ValueError`` unless ``precision`` is a supported value.

    A value is accepted when it is a member of ``SUPPORTED_PRECISIONS`` or
    when its string form equals the string form of one of the members.
    """
    if precision in SUPPORTED_PRECISIONS:
        return
    # Also accept spellings such as '32' for 32.
    if str(precision) in map(str, SUPPORTED_PRECISIONS):
        return
    raise ValueError(
        f"Precision must be one of {SUPPORTED_PRECISIONS}, got {precision}"
    )
|
1040
|
+
|
1041
|
+
|
1042
|
+
def _get_precision() -> PrecisionType:
    """Return the precision setting exactly as stored (e.g. ``'bf16'``), or the default."""
    stored = get(PRECISION, default=DEFAULT_PRECISION)
    return stored
|
1045
|
+
|
1046
|
+
|
1047
|
+
def _set_jax_precision(precision: PrecisionType) -> None:
|
1048
|
+
"""Configure JAX precision settings."""
|
1049
|
+
# Enable/disable 64-bit mode
|
1050
|
+
if precision in (64, '64'):
|
1051
|
+
config.update("jax_enable_x64", True)
|
1052
|
+
else:
|
1053
|
+
config.update("jax_enable_x64", False)
|
1054
|
+
|
1055
|
+
|
1056
|
+
@functools.lru_cache(maxsize=16)
|
1057
|
+
def _get_uint(precision: PrecisionType) -> DTypeLike:
|
1058
|
+
"""Get unsigned integer type for given precision."""
|
1059
|
+
if precision in (64, '64'):
|
1060
|
+
return np.uint64
|
1061
|
+
elif precision in (32, '32'):
|
1062
|
+
return np.uint32
|
1063
|
+
elif precision in (16, '16', 'bf16'):
|
1064
|
+
return np.uint16
|
1065
|
+
elif precision in (8, '8'):
|
1066
|
+
return np.uint8
|
1067
|
+
else:
|
1068
|
+
raise ValueError(f"Unsupported precision: {precision}")
|
1069
|
+
|
1070
|
+
|
1071
|
+
@functools.lru_cache(maxsize=16)
|
1072
|
+
def _get_int(precision: PrecisionType) -> DTypeLike:
|
1073
|
+
"""Get integer type for given precision."""
|
1074
|
+
if precision in (64, '64'):
|
1075
|
+
return np.int64
|
1076
|
+
elif precision in (32, '32'):
|
1077
|
+
return np.int32
|
1078
|
+
elif precision in (16, '16', 'bf16'):
|
1079
|
+
return np.int16
|
1080
|
+
elif precision in (8, '8'):
|
1081
|
+
return np.int8
|
1082
|
+
else:
|
1083
|
+
raise ValueError(f"Unsupported precision: {precision}")
|
1084
|
+
|
1085
|
+
|
1086
|
+
@functools.lru_cache(maxsize=16)
|
1087
|
+
def _get_float(precision: PrecisionType) -> DTypeLike:
|
1088
|
+
"""Get floating-point type for given precision."""
|
1089
|
+
if precision in (64, '64'):
|
1090
|
+
return np.float64
|
1091
|
+
elif precision in (32, '32'):
|
1092
|
+
return np.float32
|
1093
|
+
elif precision in (16, '16'):
|
1094
|
+
return np.float16
|
1095
|
+
elif precision == 'bf16':
|
1096
|
+
return jnp.bfloat16
|
1097
|
+
elif precision in (8, '8'):
|
1098
|
+
return jnp.float8_e5m2
|
1099
|
+
else:
|
1100
|
+
raise ValueError(f"Unsupported precision: {precision}")
|
1101
|
+
|
1102
|
+
|
1103
|
+
@functools.lru_cache(maxsize=16)
|
1104
|
+
def _get_complex(precision: PrecisionType) -> DTypeLike:
|
1105
|
+
"""Get complex type for given precision."""
|
1106
|
+
if precision in (64, '64'):
|
1107
|
+
return np.complex128
|
1108
|
+
elif precision in (32, '32', 16, '16', 'bf16', 8, '8'):
|
1109
|
+
return np.complex64
|
1110
|
+
else:
|
1111
|
+
raise ValueError(f"Unsupported precision: {precision}")
|
1112
|
+
|
1113
|
+
|
1114
|
+
def dftype() -> DTypeLike:
    """
    Return the default floating-point data type.

    The type is derived from the precision currently in effect, so it
    follows both ``set(precision=...)`` and ``context(precision=...)``.

    Returns
    -------
    DTypeLike
        Default floating-point data type.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>> import jax.numpy as jnp
        >>>
        >>> env.set(precision=32)
        >>> print(jnp.zeros(10, dtype=env.dftype()).dtype)  # float32
        >>> with env.context(precision=64):
        ...     print(jnp.ones(5, dtype=env.dftype()).dtype)  # float64

    See Also
    --------
    ditype : Default integer type
    dutype : Default unsigned integer type
    dctype : Default complex type
    """
    current = _get_precision()
    return _get_float(current)
|
1155
|
+
|
1156
|
+
|
1157
|
+
def ditype() -> DTypeLike:
    """
    Return the default signed-integer data type.

    The type is derived from the precision currently in effect.

    Returns
    -------
    DTypeLike
        Default integer data type.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>> import jax.numpy as jnp
        >>>
        >>> env.set(precision=32)
        >>> print(jnp.arange(10, dtype=env.ditype()).dtype)  # int32
        >>> with env.context(precision=64):
        ...     print(jnp.arange(1000, dtype=env.ditype()).dtype)  # int64

    See Also
    --------
    dftype : Default floating-point type
    dutype : Default unsigned integer type
    """
    current = _get_precision()
    return _get_int(current)
|
1192
|
+
|
1193
|
+
|
1194
|
+
def dutype() -> DTypeLike:
    """
    Return the unsigned integer dtype that matches the active precision.

    The precision currently configured in the environment is read and
    translated into the corresponding unsigned integer data type.

    Returns
    -------
    DTypeLike
        Default unsigned integer data type.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>> import jax.numpy as jnp
        >>>
        >>> # 32-bit precision
        >>> env.set(precision=32)
        >>> counts = jnp.array([10, 20, 30], dtype=env.dutype())
        >>> print(counts.dtype)  # uint32
        >>>
        >>> # 16-bit precision, scoped to a context
        >>> with env.context(precision=16):
        ...     small_counts = jnp.array([1, 2, 3], dtype=env.dutype())
        ...     print(small_counts.dtype)  # uint16

    See Also
    --------
    ditype : Default signed integer type
    """
    current_precision = _get_precision()
    return _get_uint(current_precision)
|
1228
|
+
|
1229
|
+
|
1230
|
+
def dctype() -> DTypeLike:
    """
    Return the complex dtype that matches the active precision.

    The precision currently configured in the environment is read and
    translated into the corresponding complex data type.

    Returns
    -------
    DTypeLike
        Default complex data type.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>> import jax.numpy as jnp
        >>>
        >>> # 32-bit precision
        >>> env.set(precision=32)
        >>> z = jnp.array([1+2j, 3+4j], dtype=env.dctype())
        >>> print(z.dtype)  # complex64
        >>>
        >>> # 64-bit precision, scoped to a context
        >>> with env.context(precision=64):
        ...     w = jnp.array([5+6j], dtype=env.dctype())
        ...     print(w.dtype)  # complex128

    Notes
    -----
    Complex128 is only available with 64-bit precision.
    All other precisions use complex64.
    """
    current_precision = _get_precision()
    return _get_complex(current_precision)
|
1265
|
+
|
1266
|
+
|
1267
|
+
def tolerance() -> jnp.ndarray:
    """
    Return a numerical tolerance suited to the current precision.

    The value is intended for numerical comparisons and is chosen from the
    precision reported by ``get_precision()``.

    Returns
    -------
    jnp.ndarray
        Tolerance value as a scalar array.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>> import jax.numpy as jnp
        >>>
        >>> # The tolerance follows the configured precision
        >>> env.set(precision=64)
        >>> tol64 = env.tolerance()
        >>> print(f"64-bit tolerance: {tol64}")  # 1e-12
        >>>
        >>> env.set(precision=32)
        >>> tol32 = env.tolerance()
        >>> print(f"32-bit tolerance: {tol32}")  # 1e-5
        >>>
        >>> # Typical use in a comparison helper
        >>> def are_close(a, b):
        ...     return jnp.abs(a - b) < env.tolerance()

    Notes
    -----
    Tolerance values:

    - 64-bit: 1e-12 (stored as float64)
    - 32-bit: 1e-5 (stored as float32)
    - any other precision (16-bit and below): 1e-2 (stored as float16)
    """
    bits = get_precision()

    # Pick the (value, dtype) pair first, then build the scalar array once.
    if bits == 64:
        value, dtype = 1e-12, np.float64
    elif bits == 32:
        value, dtype = 1e-5, np.float32
    else:
        # Fallback for every precision that is neither 64 nor 32.
        value, dtype = 1e-2, np.float16

    return jnp.array(value, dtype=dtype)
|
1314
|
+
|
1315
|
+
|
1316
|
+
def register_default_behavior(
    key: str,
    behavior: Callable[[Any], None],
    replace_if_exist: bool = False
) -> None:
    """
    Attach a callback to an environment parameter.

    The callback is stored under ``key`` in the module's behavior registry
    so that it can be invoked whenever that environment parameter is
    modified.

    Parameters
    ----------
    key : str
        Environment parameter key to monitor.
    behavior : Callable[[Any], None]
        Callback function that receives the new value.
    replace_if_exist : bool, default=False
        Whether to replace an existing callback for this key.

    Raises
    ------
    TypeError
        If ``key`` is not a string, or if ``behavior`` is not callable.
    ValueError
        If ``key`` already has a registered behavior and
        ``replace_if_exist`` is False.

    Examples
    --------
    Basic callback registration:

    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> def on_dt_change(new_dt):
        ...     print(f"Time step changed to: {new_dt}")
        >>>
        >>> env.register_default_behavior('dt', on_dt_change)
        >>>
        >>> # The callback fires on changes
        >>> env.set(dt=0.01)  # Prints: Time step changed to: 0.01
        >>>
        >>> with env.context(dt=0.001):  # Prints: Time step changed to: 0.001
        ...     pass  # Prints: Time step changed to: 0.01 (on exit)

    Callback that validates the new value:

    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> def validate_batch_size(size):
        ...     if not isinstance(size, int) or size <= 0:
        ...         raise ValueError(f"Invalid batch size: {size}")
        ...     if size > 1024:
        ...         print(f"Warning: Large batch size {size} may cause OOM")
        >>>
        >>> env.register_default_behavior('batch_size', validate_batch_size)
        >>>
        >>> env.set(batch_size=32)  # OK
        >>> # env.set(batch_size=-1)  # Raises ValueError

    Replacing an existing behavior:

    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> def old_behavior(value):
        ...     print(f"Old: {value}")
        >>>
        >>> def new_behavior(value):
        ...     print(f"New: {value}")
        >>>
        >>> env.register_default_behavior('key', old_behavior)
        >>> env.register_default_behavior('key', new_behavior, replace_if_exist=True)
        >>>
        >>> env.set(key='test')  # Prints: New: test

    See Also
    --------
    unregister_default_behavior : Remove registered callbacks
    list_registered_behaviors : List all registered callbacks
    """
    # Validate the arguments before touching the registry.
    if not isinstance(key, str):
        raise TypeError(f"Key must be a string, got {type(key)}")
    if not callable(behavior):
        raise TypeError(f"Behavior must be callable, got {type(behavior)}")

    registry = _ENV_STATE.functions

    # Refuse to silently overwrite a callback unless explicitly allowed.
    if not replace_if_exist and key in registry:
        raise ValueError(
            f"Behavior for key '{key}' already registered. "
            f"Use replace_if_exist=True to override."
        )

    registry[key] = behavior
|
1419
|
+
|
1420
|
+
|
1421
|
+
def unregister_default_behavior(key: str) -> bool:
    """
    Drop the callback registered for an environment parameter.

    Parameters
    ----------
    key : str
        Environment parameter key.

    Returns
    -------
    bool
        True if a callback was removed, False if no callback existed.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> def callback(value):
        ...     print(f"Value: {value}")
        >>>
        >>> env.register_default_behavior('param', callback)
        >>>
        >>> # Remove the callback
        >>> removed = env.unregister_default_behavior('param')
        >>> print(f"Callback removed: {removed}")  # True
        >>>
        >>> # Nothing is triggered any more
        >>> env.set(param='test')  # No output
        >>>
        >>> # Removing a callback that was never registered
        >>> removed = env.unregister_default_behavior('nonexistent')
        >>> print(f"Callback removed: {removed}")  # False
    """
    registry = _ENV_STATE.functions
    was_registered = key in registry
    if was_registered:
        del registry[key]
    # The membership flag doubles as the "something was removed" result.
    return was_registered
|
1462
|
+
|
1463
|
+
|
1464
|
+
def list_registered_behaviors() -> List[str]:
    """
    Return the keys that currently have a registered callback.

    Returns
    -------
    list of str
        Keys that have registered behavior callbacks. A new list is built
        on every call.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> # Register a couple of callbacks
        >>> env.register_default_behavior('param1', lambda x: None)
        >>> env.register_default_behavior('param2', lambda x: None)
        >>>
        >>> # Inspect what is registered
        >>> behaviors = env.list_registered_behaviors()
        >>> print(f"Registered: {behaviors}")  # ['param1', 'param2']
        >>>
        >>> # Check for one specific key
        >>> if 'dt' in behaviors:
        ...     print("dt has a registered callback")
    """
    registry = _ENV_STATE.functions
    # Iterating the mapping yields its keys.
    return list(registry)
|
1492
|
+
|
1493
|
+
|
1494
|
+
# Initialize default precision on module load: this statement runs at import
# time and applies DEFAULT_PRECISION as the `precision` setting.
# NOTE(review): `set` is called with a keyword argument, so it is presumably
# this module's own environment setter (the `env.set(...)` used throughout the
# docstring examples) rather than the builtin `set` type — confirm against its
# definition earlier in the file.
set(precision=DEFAULT_PRECISION)
|