brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/environ.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -13,7 +13,50 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- # -*- coding: utf-8 -*-
16
+ """
17
+ Environment configuration and context management for BrainState.
18
+
19
+ This module provides comprehensive functionality for managing computational
20
+ environments, including platform selection, precision control, mode setting,
21
+ and context-based configuration management. It enables flexible configuration
22
+ of JAX-based computations with thread-safe context switching.
23
+
24
+ The module supports:
25
+ - Platform configuration (CPU, GPU, TPU)
26
+ - Precision control (8, 16, 32, 64 bit and bfloat16)
27
+ - Computation mode management
28
+ - Context-based temporary settings
29
+ - Default data type management
30
+ - Custom behavior registration
31
+
32
+ Examples
33
+ --------
34
+ Global environment configuration:
35
+
36
+ .. code-block:: python
37
+
38
+ >>> import brainstate as bs
39
+ >>> import brainstate.environ as env
40
+ >>>
41
+ >>> # Set global precision to 32-bit
42
+ >>> env.set(precision=32, dt=0.01, mode=bs.mixin.Training())
43
+ >>>
44
+ >>> # Get current settings
45
+ >>> print(env.get('precision')) # 32
46
+ >>> print(env.get('dt')) # 0.01
47
+
48
+ Context-based temporary settings:
49
+
50
+ .. code-block:: python
51
+
52
+ >>> import brainstate.environ as env
53
+ >>>
54
+ >>> # Temporarily change precision
55
+ >>> with env.context(precision=64, dt=0.001):
56
+ ... high_precision_result = compute_something()
57
+ ... print(env.get('precision')) # 64
58
+ >>> print(env.get('precision')) # Back to 32
59
+ """
17
60
 
18
61
  import contextlib
19
62
  import dataclasses
@@ -21,543 +64,1432 @@ import functools
21
64
  import os
22
65
  import re
23
66
  import threading
67
+ import warnings
24
68
  from collections import defaultdict
25
- from typing import Any, Callable, Dict, Hashable
69
+ from typing import Any, Callable, Dict, Hashable, Optional, Union, ContextManager, List
26
70
 
71
+ import brainunit as u
27
72
  import numpy as np
28
73
  from jax import config, devices, numpy as jnp
29
74
  from jax.typing import DTypeLike
30
75
 
31
- from .mixin import Mode
32
-
33
76
  __all__ = [
34
- # functions for environment settings
35
- 'set', 'context', 'get', 'all', 'set_host_device_count', 'set_platform',
36
- # functions for getting default behaviors
37
- 'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_precision',
38
- # functions for default data types
39
- 'dftype', 'ditype', 'dutype', 'dctype',
40
- # others
41
- 'tolerance', 'register_default_behavior',
77
+ # Core environment management
78
+ 'set',
79
+ 'get',
80
+ 'all',
81
+ 'pop',
82
+ 'context',
83
+ 'reset',
84
+
85
+ # Platform and device management
86
+ 'set_platform',
87
+ 'get_platform',
88
+ 'set_host_device_count',
89
+ 'get_host_device_count',
90
+
91
+ # Precision and data type management
92
+ 'set_precision',
93
+ 'get_precision',
94
+ 'dftype',
95
+ 'ditype',
96
+ 'dutype',
97
+ 'dctype',
98
+
99
+ # Mode and computation settings
100
+ 'get_dt',
101
+
102
+ # Utility functions
103
+ 'tolerance',
104
+ 'register_default_behavior',
105
+ 'unregister_default_behavior',
106
+ 'list_registered_behaviors',
107
+
108
+ # Constants
109
+ 'DEFAULT_PRECISION',
110
+ 'SUPPORTED_PLATFORMS',
111
+ 'SUPPORTED_PRECISIONS',
112
+
113
+ # Names
114
+ 'I',
115
+ 'T',
116
+ 'DT',
117
+ 'PRECISION',
118
+ 'PLATFORM',
119
+ 'HOST_DEVICE_COUNT',
120
+ 'JIT_ERROR_CHECK',
121
+ 'FIT',
42
122
  ]
43
123
 
44
- # Default, there are several shared arguments in the global context.
45
- I = 'i' # the index of the current computation.
46
- T = 't' # the current time of the current computation.
47
- JIT_ERROR_CHECK = 'jit_error_check' # whether to record the current computation.
48
- FIT = 'fit' # whether to fit the model.
124
+ # Type definitions
125
+ # T = TypeVar('T')
126
+ PrecisionType = Union[int, str]
127
+ PlatformType = str
49
128
 
129
+ # Constants for environment keys
130
+ I = 'i' # Index of the current computation
131
+ T = 't' # Current time of the computation
132
+ DT = 'dt' # Time step for numerical integration
133
+ PRECISION = 'precision' # Numerical precision
134
+ PLATFORM = 'platform' # Computing platform
135
+ HOST_DEVICE_COUNT = 'host_device_count' # Number of host devices
136
+ JIT_ERROR_CHECK = 'jit_error_check' # JIT error checking flag
137
+ FIT = 'fit' # Model fitting flag
50
138
 
51
- @dataclasses.dataclass
52
- class DefaultContext(threading.local):
53
- # default environment settings
54
- settings: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
55
- # current environment settings
56
- contexts: defaultdict[Hashable, Any] = dataclasses.field(default_factory=lambda: defaultdict(list))
57
- # environment functions
58
- functions: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
139
+ # Default values
140
+ DEFAULT_PRECISION = 32
141
+ SUPPORTED_PLATFORMS = ('cpu', 'gpu', 'tpu')
142
+ SUPPORTED_PRECISIONS = (8, 16, 32, 64, 'bf16')
59
143
 
144
+ # Sentinel value for missing arguments
145
+ _NOT_PROVIDED = object()
60
146
 
61
- DFAULT = DefaultContext()
62
- _NOT_PROVIDE = object()
63
147
 
148
+ @dataclasses.dataclass
149
+ class EnvironmentState(threading.local):
150
+ """
151
+ Thread-local storage for environment configuration.
64
152
 
65
- @contextlib.contextmanager
66
- def context(**kwargs):
67
- r"""
68
- Context-manager that sets a computing environment for brain dynamics computation.
153
+ This class maintains separate configuration states for different threads,
154
+ ensuring thread-safe environment management in concurrent applications.
69
155
 
70
- In BrainPy, there are several basic computation settings when constructing models,
71
- including ``mode`` for controlling model computing behavior, ``dt`` for numerical
72
- integration, ``int_`` for integer precision, and ``float_`` for floating precision.
73
- :py:class:`~.environment`` provides a context for model construction and
74
- computation. In this temporal environment, models are constructed with the given
75
- ``mode``, ``dt``, ``int_``, etc., environment settings.
156
+ Attributes
157
+ ----------
158
+ settings : Dict[Hashable, Any]
159
+ Global default environment settings.
160
+ contexts : defaultdict[Hashable, List[Any]]
161
+ Stack of context-specific settings for nested contexts.
162
+ functions : Dict[Hashable, Callable]
163
+ Registered callback functions for environment changes.
164
+ locks : Dict[str, threading.Lock]
165
+ Thread locks for synchronized access to critical sections.
166
+ """
167
+ settings: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
168
+ contexts: defaultdict[Hashable, List[Any]] = dataclasses.field(default_factory=lambda: defaultdict(list))
169
+ functions: Dict[Hashable, Callable] = dataclasses.field(default_factory=dict)
170
+ locks: Dict[str, threading.Lock] = dataclasses.field(default_factory=lambda: defaultdict(threading.Lock))
76
171
 
77
- For instance::
172
+ def __post_init__(self):
173
+ """Initialize with default settings."""
174
+ # Set default precision if not already set
175
+ if PRECISION not in self.settings:
176
+ self.settings[PRECISION] = DEFAULT_PRECISION
78
177
 
79
- >>> import brainstate as brainstate
80
- >>> with brainstate.environ.context(dt=0.1) as env:
81
- ... dt = brainstate.environ.get('dt')
82
- ... print(env)
83
178
 
84
- """
85
- if 'platform' in kwargs:
86
- raise ValueError('\n'
87
- 'Cannot set platform in "context" environment. \n'
88
- 'You should set platform in the global environment by "set_platform()" or "set()".')
89
- if 'host_device_count' in kwargs:
90
- raise ValueError('Cannot set host_device_count in environment context. '
91
- 'Please use set_host_device_count() or set() for the global setting.')
179
+ # Global environment state
180
+ _ENV_STATE = EnvironmentState()
92
181
 
93
- if 'precision' in kwargs:
94
- last_precision = _get_precision()
95
- _set_jax_precision(kwargs['precision'])
96
182
 
97
- try:
98
- for k, v in kwargs.items():
183
+ def reset() -> None:
184
+ """
185
+ Reset the environment to default settings.
186
+
187
+ This function clears all custom settings and restores the environment
188
+ to its initial state. Useful for testing or when starting fresh.
189
+
190
+ Examples
191
+ --------
192
+ .. code-block:: python
193
+
194
+ >>> import brainstate.environ as env
195
+ >>>
196
+ >>> # Set custom values
197
+ >>> env.set(dt=0.1, custom_param='value')
198
+ >>> print(env.get('custom_param')) # 'value'
199
+ >>>
200
+ >>> # Reset to defaults
201
+ >>> env.reset()
202
+ >>> print(env.get('custom_param', default=None)) # None
203
+
204
+ Notes
205
+ -----
206
+ This operation cannot be undone. All custom settings will be lost.
207
+ """
208
+ global _ENV_STATE
209
+ _ENV_STATE = EnvironmentState()
210
+ # Re-apply default precision
211
+ _set_jax_precision(DEFAULT_PRECISION)
99
212
 
100
- # update the current environment
101
- DFAULT.contexts[k].append(v)
213
+ warnings.warn(
214
+ "Environment has been reset to default settings. "
215
+ "All custom configurations have been cleared.",
216
+ UserWarning
217
+ )
102
218
 
103
- # restore the environment functions
104
- if k in DFAULT.functions:
105
- DFAULT.functions[k](v)
106
219
 
107
- # yield the current all environment information
108
- yield all()
109
- finally:
220
+ @contextlib.contextmanager
221
+ def context(**kwargs) -> ContextManager[Dict[str, Any]]:
222
+ """
223
+ Context manager for temporary environment settings.
110
224
 
111
- for k, v in kwargs.items():
225
+ This context manager allows you to temporarily modify environment settings
226
+ within a specific scope. Settings are automatically restored when exiting
227
+ the context, even if an exception occurs.
112
228
 
113
- # restore the current environment
114
- DFAULT.contexts[k].pop()
229
+ Parameters
230
+ ----------
231
+ **kwargs
232
+ Environment settings to apply within the context.
233
+ Common parameters include:
234
+
235
+ - precision : int or str.
236
+ Numerical precision (8, 16, 32, 64, or 'bf16')
237
+ - dt : float.
238
+ Time step for numerical integration
239
+ - mode : :class:`Mode`.
240
+ Computation mode instance
241
+ - Any custom parameters registered via register_default_behavior
242
+
243
+ Yields
244
+ ------
245
+ dict
246
+ Current environment settings within the context.
247
+
248
+ Raises
249
+ ------
250
+ ValueError
251
+ If attempting to set platform or host_device_count in context
252
+ (these must be set globally).
253
+ TypeError
254
+ If invalid parameter types are provided.
255
+
256
+ Examples
257
+ --------
258
+ Basic usage with precision control:
259
+
260
+ .. code-block:: python
261
+
262
+ >>> import brainstate.environ as env
263
+ >>>
264
+ >>> # Set global precision
265
+ >>> env.set(precision=32)
266
+ >>>
267
+ >>> # Temporarily use higher precision
268
+ >>> with env.context(precision=64) as ctx:
269
+ ... print(f"Precision in context: {env.get('precision')}") # 64
270
+ ... print(f"Float type: {env.dftype()}") # float64
271
+ >>>
272
+ >>> print(f"Precision after context: {env.get('precision')}") # 32
273
+
274
+ Nested contexts:
275
+
276
+ .. code-block:: python
277
+
278
+ >>> import brainstate.environ as env
279
+ >>>
280
+ >>> with env.context(dt=0.1) as ctx1:
281
+ ... print(f"dt = {env.get('dt')}") # 0.1
282
+ ...
283
+ ... with env.context(dt=0.01) as ctx2:
284
+ ... print(f"dt = {env.get('dt')}") # 0.01
285
+ ...
286
+ ... print(f"dt = {env.get('dt')}") # 0.1
287
+
288
+ Error handling in context:
289
+
290
+ .. code-block:: python
291
+
292
+ >>> import brainstate.environ as env
293
+ >>>
294
+ >>> env.set(value=10)
295
+ >>> try:
296
+ ... with env.context(value=20):
297
+ ... print(env.get('value')) # 20
298
+ ... raise ValueError("Something went wrong")
299
+ ... except ValueError:
300
+ ... pass
301
+ >>>
302
+ >>> print(env.get('value')) # 10 (restored)
303
+
304
+ Notes
305
+ -----
306
+ - Platform and host_device_count cannot be set in context
307
+ - Contexts can be nested arbitrarily deep
308
+ - Settings are restored in reverse order when exiting
309
+ - Thread-safe: each thread maintains its own context stack
310
+ """
311
+ # Validate restricted parameters
312
+ if PLATFORM in kwargs:
313
+ raise ValueError(
314
+ f"Cannot set '{PLATFORM}' in context. "
315
+ f"Use set_platform() or set() for global configuration."
316
+ )
317
+ if HOST_DEVICE_COUNT in kwargs:
318
+ raise ValueError(
319
+ f"Cannot set '{HOST_DEVICE_COUNT}' in context. "
320
+ f"Use set_host_device_count() or set() for global configuration."
321
+ )
322
+
323
+ # Handle precision changes
324
+ original_precision = None
325
+ if PRECISION in kwargs:
326
+ original_precision = _get_precision()
327
+ _validate_precision(kwargs[PRECISION])
328
+ _set_jax_precision(kwargs[PRECISION])
115
329
 
116
- # restore the environment functions
117
- if k in DFAULT.functions:
118
- DFAULT.functions[k](get(k))
330
+ try:
331
+ # Push new values onto context stacks
332
+ for key, value in kwargs.items():
333
+ with _ENV_STATE.locks[key]:
334
+ _ENV_STATE.contexts[key].append(value)
335
+
336
+ # Trigger registered callbacks
337
+ if key in _ENV_STATE.functions:
338
+ try:
339
+ _ENV_STATE.functions[key](value)
340
+ except Exception as e:
341
+ warnings.warn(
342
+ f"Callback for '{key}' raised an exception: {e}",
343
+ RuntimeWarning
344
+ )
345
+
346
+ # Yield current environment state
347
+ yield all()
119
348
 
120
- if 'precision' in kwargs:
121
- _set_jax_precision(last_precision)
349
+ finally:
350
+ # Restore previous values
351
+ for key in kwargs:
352
+ with _ENV_STATE.locks[key]:
353
+ if _ENV_STATE.contexts[key]:
354
+ _ENV_STATE.contexts[key].pop()
355
+
356
+ # Restore callbacks with previous value
357
+ if key in _ENV_STATE.functions:
358
+ try:
359
+ prev_value = get(key, default=None)
360
+ if prev_value is not None:
361
+ _ENV_STATE.functions[key](prev_value)
362
+ except Exception as e:
363
+ warnings.warn(
364
+ f"Callback restoration for '{key}' raised: {e}",
365
+ RuntimeWarning
366
+ )
367
+
368
+ # Restore precision if it was changed
369
+ if original_precision is not None:
370
+ _set_jax_precision(original_precision)
371
+
372
+
373
+ def get(key: str, default: Any = _NOT_PROVIDED, desc: Optional[str] = None) -> Any:
374
+ """
375
+ Get a value from the current environment.
122
376
 
377
+ This function retrieves values from the environment, checking first in
378
+ the context stack, then in global settings. Special handling is provided
379
+ for platform and device count parameters.
123
380
 
124
- def get(key: str, default: Any = _NOT_PROVIDE, desc: str = None):
125
- """
126
- Get one of the default computation environment.
381
+ Parameters
382
+ ----------
383
+ key : str
384
+ The environment key to retrieve.
385
+ default : Any, optional
386
+ Default value to return if key is not found.
387
+ If not provided, raises KeyError for missing keys.
388
+ desc : str, optional
389
+ Description of the parameter for error messages.
127
390
 
128
391
  Returns
129
392
  -------
130
- item: Any
131
- The default computation environment.
393
+ Any
394
+ The value associated with the key.
395
+
396
+ Raises
397
+ ------
398
+ KeyError
399
+ If key is not found and no default is provided.
400
+
401
+ Examples
402
+ --------
403
+ Basic retrieval:
404
+
405
+ .. code-block:: python
406
+
407
+ >>> import brainstate.environ as env
408
+ >>>
409
+ >>> env.set(learning_rate=0.001)
410
+ >>> lr = env.get('learning_rate')
411
+ >>> print(f"Learning rate: {lr}") # 0.001
412
+
413
+ With default value:
414
+
415
+ .. code-block:: python
416
+
417
+ >>> import brainstate.environ as env
418
+ >>>
419
+ >>> # Get with default
420
+ >>> batch_size = env.get('batch_size', default=32)
421
+ >>> print(f"Batch size: {batch_size}") # 32
422
+
423
+ Context-aware retrieval:
424
+
425
+ .. code-block:: python
426
+
427
+ >>> import brainstate.environ as env
428
+ >>>
429
+ >>> env.set(temperature=1.0)
430
+ >>> print(env.get('temperature')) # 1.0
431
+ >>>
432
+ >>> with env.context(temperature=0.5):
433
+ ... print(env.get('temperature')) # 0.5
434
+ >>>
435
+ >>> print(env.get('temperature')) # 1.0
436
+
437
+ Notes
438
+ -----
439
+ Special keys 'platform' and 'host_device_count' are handled separately
440
+ and retrieve system-level information.
132
441
  """
133
- if key == 'platform':
442
+ # Special cases for platform-specific parameters
443
+ if key == PLATFORM:
134
444
  return get_platform()
135
-
136
- if key == 'host_device_count':
445
+ if key == HOST_DEVICE_COUNT:
137
446
  return get_host_device_count()
138
447
 
139
- if key in DFAULT.contexts:
140
- if len(DFAULT.contexts[key]) > 0:
141
- return DFAULT.contexts[key][-1]
142
- if key in DFAULT.settings:
143
- return DFAULT.settings[key]
144
-
145
- if default is _NOT_PROVIDE:
146
- if desc is not None:
147
- raise KeyError(
148
- f"'{key}' is not found in the context. \n"
149
- f"You can set it by `brainstate.share.context({key}=value)` "
150
- f"locally or `brainstate.share.set({key}=value)` globally. \n"
151
- f"Description: {desc}"
152
- )
153
- else:
154
- raise KeyError(
155
- f"'{key}' is not found in the context. \n"
156
- f"You can set it by `brainstate.share.context({key}=value)` "
157
- f"locally or `brainstate.share.set({key}=value)` globally."
158
- )
448
+ # Check context stack first (most recent value)
449
+ with _ENV_STATE.locks[key]:
450
+ if key in _ENV_STATE.contexts and _ENV_STATE.contexts[key]:
451
+ return _ENV_STATE.contexts[key][-1]
452
+
453
+ # Check global settings
454
+ if key in _ENV_STATE.settings:
455
+ return _ENV_STATE.settings[key]
456
+
457
+ # Handle missing key
458
+ if default is _NOT_PROVIDED:
459
+ error_msg = f"Key '{key}' not found in environment."
460
+ if desc:
461
+ error_msg += f" Description: {desc}"
462
+ error_msg += (
463
+ f"\nSet it using:\n"
464
+ f" - env.set({key}=value) for global setting\n"
465
+ f" - env.context({key}=value) for temporary setting"
466
+ )
467
+ raise KeyError(error_msg)
468
+
159
469
  return default
160
470
 
161
471
 
162
- def all() -> dict:
472
+ def all() -> Dict[str, Any]:
163
473
  """
164
- Get all the current default computation environment.
474
+ Get all current environment settings.
475
+
476
+ This function returns a dictionary containing all active environment
477
+ settings, with context values taking precedence over global settings.
165
478
 
166
479
  Returns
167
480
  -------
168
- r: dict
169
- The current default computation environment.
481
+ dict
482
+ Dictionary of all current environment settings.
483
+
484
+ Examples
485
+ --------
486
+ .. code-block:: python
487
+
488
+ >>> import brainstate.environ as env
489
+ >>>
490
+ >>> # Set various parameters
491
+ >>> env.set(precision=32, dt=0.01, debug=True)
492
+ >>>
493
+ >>> # Get all settings
494
+ >>> settings = env.all()
495
+ >>> print(settings)
496
+ {'precision': 32, 'dt': 0.01, 'debug': True}
497
+
498
+ >>> # Context overrides
499
+ >>> with env.context(precision=64, new_param='test'):
500
+ ... settings = env.all()
501
+ ... print(settings['precision']) # 64
502
+ ... print(settings['new_param']) # 'test'
503
+
504
+ Notes
505
+ -----
506
+ The returned dictionary is a snapshot and modifying it does not
507
+ affect the environment settings.
508
+ """
509
+ result = {}
510
+
511
+ # Add global settings
512
+ result.update(_ENV_STATE.settings)
513
+
514
+ # Override with context values (most recent)
515
+ for key, values in _ENV_STATE.contexts.items():
516
+ if values:
517
+ result[key] = values[-1]
518
+
519
+ return result
520
+
521
+
522
+ def pop(key: str, default: Any = _NOT_PROVIDED) -> Any:
170
523
  """
171
- r = dict()
172
- for k, v in DFAULT.contexts.items():
173
- if v:
174
- r[k] = v[-1]
175
- for k, v in DFAULT.settings.items():
176
- if k not in r:
177
- r[k] = v
178
- return r
524
+ Remove and return a value from the global environment.
179
525
 
526
+ This function removes a key from the global environment settings and
527
+ returns its value. If the key is not found, it returns the default
528
+ value if provided, or raises KeyError.
180
529
 
181
- def get_dt():
182
- """Get the numerical integrator precision.
530
+ Note that this function only affects global settings, not context values.
531
+ Keys in active contexts are not affected.
532
+
533
+ Parameters
534
+ ----------
535
+ key : str
536
+ The environment key to remove.
537
+ default : Any, optional
538
+ Default value to return if key is not found.
539
+ If not provided, raises KeyError for missing keys.
183
540
 
184
541
  Returns
185
542
  -------
186
- dt : float
187
- Numerical integration precision.
543
+ Any
544
+ The value that was removed from the environment.
545
+
546
+ Raises
547
+ ------
548
+ KeyError
549
+ If key is not found and no default is provided.
550
+ ValueError
551
+ If attempting to pop a key that is currently in a context.
552
+
553
+ Examples
554
+ --------
555
+ Basic usage:
556
+
557
+ .. code-block:: python
558
+
559
+ >>> import brainstate.environ as env
560
+ >>>
561
+ >>> # Set a value
562
+ >>> env.set(temp_param='temporary')
563
+ >>> print(env.get('temp_param')) # 'temporary'
564
+ >>>
565
+ >>> # Pop the value
566
+ >>> value = env.pop('temp_param')
567
+ >>> print(value) # 'temporary'
568
+ >>>
569
+ >>> # Value is now gone
570
+ >>> env.get('temp_param', default=None) # None
571
+
572
+ With default value:
573
+
574
+ .. code-block:: python
575
+
576
+ >>> import brainstate.environ as env
577
+ >>>
578
+ >>> # Pop non-existent key with default
579
+ >>> value = env.pop('nonexistent', default='default_value')
580
+ >>> print(value) # 'default_value'
581
+
582
+ Pop multiple values:
583
+
584
+ .. code-block:: python
585
+
586
+ >>> import brainstate.environ as env
587
+ >>>
588
+ >>> # Set multiple values
589
+ >>> env.set(param1='value1', param2='value2', param3='value3')
590
+ >>>
591
+ >>> # Pop them one by one
592
+ >>> v1 = env.pop('param1')
593
+ >>> v2 = env.pop('param2')
594
+ >>>
595
+ >>> # param3 still exists
596
+ >>> print(env.get('param3')) # 'value3'
597
+
598
+ Context protection:
599
+
600
+ .. code-block:: python
601
+
602
+ >>> import brainstate.environ as env
603
+ >>>
604
+ >>> env.set(protected='global_value')
605
+ >>>
606
+ >>> with env.context(protected='context_value'):
607
+ ... # Cannot pop a key that's in active context
608
+ ... try:
609
+ ... env.pop('protected')
610
+ ... except ValueError as e:
611
+ ... print("Cannot pop key in active context")
612
+
613
+ Notes
614
+ -----
615
+ - This function only removes keys from global settings
616
+ - Keys that are currently overridden in active contexts cannot be popped
617
+ - Special keys like 'platform' and 'host_device_count' can be popped but
618
+ their system-level values remain accessible through get_platform() etc.
619
+ - Registered callbacks are NOT triggered when popping values
188
620
  """
189
- return get('dt')
621
+ # Check if key is currently in any active context
622
+ if key in _ENV_STATE.contexts and _ENV_STATE.contexts[key]:
623
+ raise ValueError(
624
+ f"Cannot pop key '{key}' while it is active in a context. "
625
+ f"The key is currently overridden in {len(_ENV_STATE.contexts[key])} context(s)."
626
+ )
627
+
628
+ # Check if key exists in global settings
629
+ if key in _ENV_STATE.settings:
630
+ # Remove and return the value
631
+ value = _ENV_STATE.settings.pop(key)
632
+
633
+ # Note: We don't trigger callbacks here as this is a removal operation
634
+ # If needed, users can register callbacks for removal separately
635
+
636
+ return value
637
+
638
+ # Key not found, handle default
639
+ if default is _NOT_PROVIDED:
640
+ raise KeyError(f"Key '{key}' not found in global environment settings.")
641
+
642
+ return default
643
+
190
644
 
645
+ def set(
646
+ platform: Optional[PlatformType] = None,
647
+ host_device_count: Optional[int] = None,
648
+ precision: Optional[PrecisionType] = None,
649
+ dt: Optional[float] = None,
650
+ **kwargs
651
+ ) -> None:
652
+ """
653
+ Set global environment configuration.
191
654
 
192
- def get_mode() -> Mode:
193
- """Get the default computing mode.
655
+ This function sets persistent global environment settings that remain
656
+ active until explicitly changed or the program terminates.
194
657
 
195
- References
658
+ Parameters
196
659
  ----------
197
- mode: Mode
198
- The default computing mode.
660
+ platform : str, optional
661
+ Computing platform ('cpu', 'gpu', or 'tpu').
662
+ host_device_count : int, optional
663
+ Number of host devices for parallel computation.
664
+ precision : int or str, optional
665
+ Numerical precision (8, 16, 32, 64, or 'bf16').
666
+ mode : Mode, optional
667
+ Computation mode instance.
668
+ dt : float, optional
669
+ Time step for numerical integration.
670
+ **kwargs
671
+ Additional custom environment parameters.
672
+
673
+ Raises
674
+ ------
675
+ ValueError
676
+ If invalid platform or precision is specified.
677
+ TypeError
678
+ If mode is not a Mode instance.
679
+
680
+ Examples
681
+ --------
682
+ Basic configuration:
683
+
684
+ .. code-block:: python
685
+
686
+ >>> import brainstate as bs
687
+ >>> import brainstate.environ as env
688
+ >>>
689
+ >>> # Set multiple parameters
690
+ >>> env.set(
691
+ ... precision=32,
692
+ ... dt=0.01,
693
+ ... mode=bs.mixin.Training(),
694
+ ... debug=False
695
+ ... )
696
+ >>>
697
+ >>> print(env.get('precision')) # 32
698
+ >>> print(env.get('dt')) # 0.01
699
+
700
+ Platform configuration:
701
+
702
+ .. code-block:: python
703
+
704
+ >>> import brainstate.environ as env
705
+ >>>
706
+ >>> # Configure for GPU computation
707
+ >>> env.set(platform='gpu', precision=16)
708
+ >>>
709
+ >>> # Configure for multi-core CPU
710
+ >>> env.set(platform='cpu', host_device_count=4)
711
+
712
+ Custom parameters:
713
+
714
+ .. code-block:: python
715
+
716
+ >>> import brainstate.environ as env
717
+ >>>
718
+ >>> # Set custom parameters
719
+ >>> env.set(
720
+ ... experiment_name='test_001',
721
+ ... random_seed=42,
722
+ ... log_level='DEBUG'
723
+ ... )
724
+ >>>
725
+ >>> # Retrieve custom parameters
726
+ >>> print(env.get('experiment_name')) # 'test_001'
727
+
728
+ Notes
729
+ -----
730
+ - Platform changes only take effect at program start
731
+ - Some JAX configurations require restart to take effect
732
+ - Custom parameters can be any hashable key-value pairs
199
733
  """
200
- return get('mode')
734
+ # Handle special parameters
735
+ if platform is not None:
736
+ set_platform(platform)
201
737
 
738
+ if host_device_count is not None:
739
+ set_host_device_count(host_device_count)
202
740
 
203
- def get_platform() -> str:
204
- """Get the computing platform.
741
+ if precision is not None:
742
+ _validate_precision(precision)
743
+ _set_jax_precision(precision)
744
+ kwargs[PRECISION] = precision
205
745
 
206
- Returns
207
- -------
208
- platform: str
209
- Either 'cpu', 'gpu' or 'tpu'.
210
- """
211
- return devices()[0].platform
746
+ if dt is not None:
747
+ if not u.math.isscalar(dt):
748
+ raise TypeError(f"'{DT}' must be a scalar number, got {type(dt)}")
749
+ kwargs[DT] = dt
212
750
 
751
+ # Update global settings
752
+ _ENV_STATE.settings.update(kwargs)
213
753
 
214
- def get_host_device_count():
754
+ # Trigger registered callbacks
755
+ for key, value in kwargs.items():
756
+ if key in _ENV_STATE.functions:
757
+ try:
758
+ _ENV_STATE.functions[key](value)
759
+ except Exception as e:
760
+ warnings.warn(
761
+ f"Callback for '{key}' raised an exception: {e}",
762
+ RuntimeWarning
763
+ )
764
+
765
+
766
+ def get_dt() -> float:
215
767
  """
216
- Get the number of host devices.
768
+ Get the current numerical integration time step.
217
769
 
218
770
  Returns
219
771
  -------
220
- n: int
221
- The number of host devices.
772
+ float
773
+ The time step value.
774
+
775
+ Raises
776
+ ------
777
+ KeyError
778
+ If dt is not set.
779
+
780
+ Examples
781
+ --------
782
+ .. code-block:: python
783
+
784
+ >>> import brainstate.environ as env
785
+ >>>
786
+ >>> env.set(dt=0.01)
787
+ >>> dt = env.get_dt()
788
+ >>> print(f"Time step: {dt} ms") # Time step: 0.01 ms
789
+ >>>
790
+ >>> # Use in computation
791
+ >>> with env.context(dt=0.001):
792
+ ... fine_dt = env.get_dt()
793
+ ... print(f"Fine time step: {fine_dt}") # 0.001
222
794
  """
223
- xla_flags = os.getenv("XLA_FLAGS", "")
224
- match = re.search(r"--xla_force_host_platform_device_count=(\d+)", xla_flags)
225
- return int(match.group(1)) if match else 1
795
+ return get(DT)
226
796
 
227
797
 
228
- def _get_precision() -> int | str:
798
+ def get_platform() -> PlatformType:
229
799
  """
230
- Get the default precision.
800
+ Get the current computing platform.
231
801
 
232
802
  Returns
233
803
  -------
234
- precision: int
235
- The default precision.
804
+ str
805
+ Platform name ('cpu', 'gpu', or 'tpu').
806
+
807
+ Examples
808
+ --------
809
+ .. code-block:: python
810
+
811
+ >>> import brainstate.environ as env
812
+ >>>
813
+ >>> platform = env.get_platform()
814
+ >>> print(f"Running on: {platform}")
815
+ >>>
816
+ >>> if platform == 'gpu':
817
+ ... print("GPU acceleration available")
818
+ ... else:
819
+ ... print(f"Using {platform.upper()}")
236
820
  """
237
- return get('precision')
821
+ return devices()[0].platform
238
822
 
239
823
 
240
- def get_precision() -> int:
824
+ def get_host_device_count() -> int:
241
825
  """
242
- Get the default precision.
826
+ Get the number of host devices.
243
827
 
244
828
  Returns
245
829
  -------
246
- precision: int
247
- The default precision.
830
+ int
831
+ Number of host devices configured.
832
+
833
+ Examples
834
+ --------
835
+ .. code-block:: python
836
+
837
+ >>> import brainstate.environ as env
838
+ >>>
839
+ >>> # Get device count
840
+ >>> n_devices = env.get_host_device_count()
841
+ >>> print(f"Host devices: {n_devices}")
842
+ >>>
843
+ >>> # Configure for parallel computation
844
+ >>> if n_devices > 1:
845
+ ... print(f"Can use {n_devices} devices for parallel computation")
248
846
  """
249
- precision = get('precision')
250
- if precision == 'bf16':
251
- return 16
252
- if isinstance(precision, int):
253
- return precision
254
- if isinstance(precision, str):
255
- return int(precision)
256
- raise ValueError(f'Unsupported precision: {precision}')
847
+ xla_flags = os.getenv("XLA_FLAGS", "")
848
+ match = re.search(r"--xla_force_host_platform_device_count=(\d+)", xla_flags)
849
+ return int(match.group(1)) if match else 1
257
850
 
258
851
 
259
- def set(
260
- platform: str = None,
261
- host_device_count: int = None,
262
- precision: int | str = None,
263
- mode: Mode = None,
264
- **kwargs
265
- ):
852
+ def set_platform(platform: PlatformType) -> None:
266
853
  """
267
- Set the global default computation environment.
268
-
854
+ Set the computing platform.
269
855
 
270
-
271
- Args:
272
- platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
273
- host_device_count: int. The number of host devices.
274
- precision: int, str. The default precision.
275
- mode: Mode. The computing mode.
276
- **kwargs: dict. Other environment settings.
856
+ Parameters
857
+ ----------
858
+ platform : str
859
+ Platform to use ('cpu', 'gpu', or 'tpu').
860
+
861
+ Raises
862
+ ------
863
+ ValueError
864
+ If platform is not supported.
865
+
866
+ Examples
867
+ --------
868
+ .. code-block:: python
869
+
870
+ >>> import brainstate.environ as env
871
+ >>>
872
+ >>> # Set to GPU
873
+ >>> env.set_platform('gpu')
874
+ >>>
875
+ >>> # Verify platform
876
+ >>> print(env.get_platform()) # 'gpu'
877
+
878
+ Notes
879
+ -----
880
+ Platform changes only take effect at program start. Changing platform
881
+ after JAX initialization may not have the expected effect.
277
882
  """
278
- if platform is not None:
279
- set_platform(platform)
280
- if host_device_count is not None:
281
- set_host_device_count(host_device_count)
282
- if precision is not None:
283
- _set_jax_precision(precision)
284
- kwargs['precision'] = precision
285
- if mode is not None:
286
- assert isinstance(mode, Mode), 'mode must be a Mode instance.'
287
- kwargs['mode'] = mode
883
+ if platform not in SUPPORTED_PLATFORMS:
884
+ raise ValueError(
885
+ f"Platform must be one of {SUPPORTED_PLATFORMS}, got '{platform}'"
886
+ )
288
887
 
289
- # set default environment
290
- DFAULT.settings.update(kwargs)
888
+ config.update("jax_platform_name", platform)
291
889
 
292
- # update the environment functions
293
- for k, v in kwargs.items():
294
- if k in DFAULT.functions:
295
- DFAULT.functions[k](v)
890
+ # Trigger callbacks
891
+ if PLATFORM in _ENV_STATE.functions:
892
+ _ENV_STATE.functions[PLATFORM](platform)
296
893
 
297
894
 
298
- def set_host_device_count(n):
895
+ def set_host_device_count(n: int) -> None:
299
896
  """
300
- By default, XLA considers all CPU cores as one device. This utility tells XLA
301
- that there are `n` host (CPU) devices available to use. As a consequence, this
302
- allows parallel mapping in JAX :func:`jax.pmap` to work in CPU platform.
897
+ Set the number of host (CPU) devices.
303
898
 
304
- .. note:: This utility only takes effect at the beginning of your program.
305
- Under the hood, this sets the environment variable
306
- `XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
307
- `[num_device]` is the desired number of CPU devices `n`.
899
+ This function configures XLA to treat CPU cores as separate devices,
900
+ enabling parallel computation with jax.pmap on CPU.
308
901
 
309
- .. warning:: Our understanding of the side effects of using the
310
- `xla_force_host_platform_device_count` flag in XLA is incomplete. If you
311
- observe some strange phenomenon when using this utility, please let us
312
- know through our issue or forum page. More information is available in this
313
- `JAX issue <https://github.com/google/jax/issues/1408>`_.
314
-
315
- :param int n: number of devices to use.
902
+ Parameters
903
+ ----------
904
+ n : int
905
+ Number of host devices to configure.
906
+
907
+ Raises
908
+ ------
909
+ ValueError
910
+ If n is not a positive integer.
911
+
912
+ Examples
913
+ --------
914
+ .. code-block:: python
915
+
916
+ >>> import brainstate.environ as env
917
+ >>> import jax
918
+ >>>
919
+ >>> # Configure 4 CPU devices
920
+ >>> env.set_host_device_count(4)
921
+ >>>
922
+ >>> # Use with pmap
923
+ >>> def parallel_fn(x):
924
+ ... return x * 2
925
+ >>>
926
+ >>> # This will work with 4 devices
927
+ >>> pmapped_fn = jax.pmap(parallel_fn)
928
+
929
+ Warnings
930
+ --------
931
+ This setting only takes effect at program start. The effects of using
932
+ xla_force_host_platform_device_count are not fully understood and may
933
+ cause unexpected behavior.
316
934
  """
935
+ if not isinstance(n, int) or n < 1:
936
+ raise ValueError(f"Host device count must be a positive integer, got {n}")
937
+
938
+ # Update XLA flags
317
939
  xla_flags = os.getenv("XLA_FLAGS", "")
318
- xla_flags = re.sub(r"--xla_force_host_platform_device_count=\S+", "", xla_flags).split()
319
- os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(n)] + xla_flags)
940
+ xla_flags = re.sub(
941
+ r"--xla_force_host_platform_device_count=\S+",
942
+ "",
943
+ xla_flags
944
+ ).split()
320
945
 
321
- # update the environment functions
322
- if 'host_device_count' in DFAULT.functions:
323
- DFAULT.functions['host_device_count'](n)
946
+ os.environ["XLA_FLAGS"] = " ".join(
947
+ [f"--xla_force_host_platform_device_count={n}"] + xla_flags
948
+ )
324
949
 
950
+ # Trigger callbacks
951
+ if HOST_DEVICE_COUNT in _ENV_STATE.functions:
952
+ _ENV_STATE.functions[HOST_DEVICE_COUNT](n)
325
953
 
326
- def set_platform(platform: str):
327
- """
328
- Changes platform to CPU, GPU, or TPU. This utility only takes
329
- effect at the beginning of your program.
330
954
 
331
- Args:
332
- platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
955
+ def set_precision(precision: PrecisionType) -> None:
956
+ """
957
+ Set the global numerical precision.
333
958
 
334
- Raises:
335
- ValueError: If the platform is not in ['cpu', 'gpu', 'tpu'].
959
+ Parameters
960
+ ----------
961
+ precision : int or str
962
+ Precision to use (8, 16, 32, 64, or 'bf16').
963
+
964
+ Raises
965
+ ------
966
+ ValueError
967
+ If precision is not supported.
968
+
969
+ Examples
970
+ --------
971
+ .. code-block:: python
972
+
973
+ >>> import brainstate.environ as env
974
+ >>> import jax.numpy as jnp
975
+ >>>
976
+ >>> # Set to 64-bit precision
977
+ >>> env.set_precision(64)
978
+ >>>
979
+ >>> # Arrays will use float64 by default
980
+ >>> x = jnp.array([1.0, 2.0, 3.0])
981
+ >>> print(x.dtype) # float64
982
+ >>>
983
+ >>> # Set to bfloat16 for efficiency
984
+ >>> env.set_precision('bf16')
336
985
  """
337
- assert platform in ['cpu', 'gpu', 'tpu']
338
- config.update("jax_platform_name", platform)
986
+ _validate_precision(precision)
987
+ _set_jax_precision(precision)
988
+ _ENV_STATE.settings[PRECISION] = precision
339
989
 
340
- # update the environment functions
341
- if 'platform' in DFAULT.functions:
342
- DFAULT.functions['platform'](platform)
990
+ # Trigger callbacks
991
+ if PRECISION in _ENV_STATE.functions:
992
+ _ENV_STATE.functions[PRECISION](precision)
343
993
 
344
994
 
345
- def _set_jax_precision(precision: int | str):
995
+ def get_precision() -> int:
346
996
  """
347
- Set the default precision.
997
+ Get the current numerical precision as an integer.
348
998
 
349
- Args:
350
- precision: int. The default precision.
999
+ Returns
1000
+ -------
1001
+ int
1002
+ Precision in bits (8, 16, 32, or 64).
1003
+
1004
+ Examples
1005
+ --------
1006
+ .. code-block:: python
1007
+
1008
+ >>> import brainstate.environ as env
1009
+ >>>
1010
+ >>> env.set_precision(32)
1011
+ >>> bits = env.get_precision()
1012
+ >>> print(f"Using {bits}-bit precision") # Using 32-bit precision
1013
+ >>>
1014
+ >>> # Special handling for bfloat16
1015
+ >>> env.set_precision('bf16')
1016
+ >>> print(env.get_precision()) # 16
1017
+
1018
+ Notes
1019
+ -----
1020
+ 'bf16' (bfloat16) is reported as 16-bit precision.
351
1021
  """
352
- # assert precision in [64, 32, 16, 'bf16', 8], f'Precision must be in [64, 32, 16, "bf16", 8]. But got {precision}.'
353
- if precision in [64, '64']:
1022
+ precision = get(PRECISION, default=DEFAULT_PRECISION)
1023
+
1024
+ if precision == 'bf16':
1025
+ return 16
1026
+ elif isinstance(precision, str):
1027
+ return int(precision)
1028
+ elif isinstance(precision, int):
1029
+ return precision
1030
+ else:
1031
+ raise ValueError(f"Invalid precision type: {type(precision)}")
1032
+
1033
+
1034
+ def _validate_precision(precision: PrecisionType) -> None:
1035
+ """Validate precision value."""
1036
+ if precision not in SUPPORTED_PRECISIONS and str(precision) not in map(str, SUPPORTED_PRECISIONS):
1037
+ raise ValueError(
1038
+ f"Precision must be one of {SUPPORTED_PRECISIONS}, got {precision}"
1039
+ )
1040
+
1041
+
1042
+ def _get_precision() -> PrecisionType:
1043
+ """Get raw precision value (including 'bf16')."""
1044
+ return get(PRECISION, default=DEFAULT_PRECISION)
1045
+
1046
+
1047
+ def _set_jax_precision(precision: PrecisionType) -> None:
1048
+ """Configure JAX precision settings."""
1049
+ # Enable/disable 64-bit mode
1050
+ if precision in (64, '64'):
354
1051
  config.update("jax_enable_x64", True)
355
1052
  else:
356
1053
  config.update("jax_enable_x64", False)
357
1054
 
358
1055
 
359
- @functools.lru_cache()
360
- def _get_uint(precision: int):
361
- if precision in [64, '64']:
1056
+ @functools.lru_cache(maxsize=16)
1057
+ def _get_uint(precision: PrecisionType) -> DTypeLike:
1058
+ """Get unsigned integer type for given precision."""
1059
+ if precision in (64, '64'):
362
1060
  return np.uint64
363
- elif precision in [32, '32']:
1061
+ elif precision in (32, '32'):
364
1062
  return np.uint32
365
- elif precision in [16, '16', 'bf16']:
1063
+ elif precision in (16, '16', 'bf16'):
366
1064
  return np.uint16
367
- elif precision in [8, '8']:
1065
+ elif precision in (8, '8'):
368
1066
  return np.uint8
369
1067
  else:
370
- raise ValueError(f'Unsupported precision: {precision}')
1068
+ raise ValueError(f"Unsupported precision: {precision}")
371
1069
 
372
1070
 
373
- @functools.lru_cache()
374
- def _get_int(precision: int):
375
- if precision in [64, '64']:
1071
+ @functools.lru_cache(maxsize=16)
1072
+ def _get_int(precision: PrecisionType) -> DTypeLike:
1073
+ """Get integer type for given precision."""
1074
+ if precision in (64, '64'):
376
1075
  return np.int64
377
- elif precision in [32, '32']:
1076
+ elif precision in (32, '32'):
378
1077
  return np.int32
379
- elif precision in [16, '16', 'bf16']:
1078
+ elif precision in (16, '16', 'bf16'):
380
1079
  return np.int16
381
- elif precision in [8, '8']:
1080
+ elif precision in (8, '8'):
382
1081
  return np.int8
383
1082
  else:
384
- raise ValueError(f'Unsupported precision: {precision}')
1083
+ raise ValueError(f"Unsupported precision: {precision}")
385
1084
 
386
1085
 
387
- @functools.lru_cache()
388
- def _get_float(precision: int):
389
- if precision in [64, '64']:
1086
+ @functools.lru_cache(maxsize=16)
1087
+ def _get_float(precision: PrecisionType) -> DTypeLike:
1088
+ """Get floating-point type for given precision."""
1089
+ if precision in (64, '64'):
390
1090
  return np.float64
391
- elif precision in [32, '32']:
1091
+ elif precision in (32, '32'):
392
1092
  return np.float32
393
- elif precision in [16, '16']:
1093
+ elif precision in (16, '16'):
394
1094
  return np.float16
395
- elif precision in ['bf16']:
1095
+ elif precision == 'bf16':
396
1096
  return jnp.bfloat16
397
- elif precision in [8, '8']:
1097
+ elif precision in (8, '8'):
398
1098
  return jnp.float8_e5m2
399
1099
  else:
400
- raise ValueError(f'Unsupported precision: {precision}')
1100
+ raise ValueError(f"Unsupported precision: {precision}")
401
1101
 
402
1102
 
403
- @functools.lru_cache()
404
- def _get_complex(precision: int):
405
- if precision == [64, '64']:
1103
+ @functools.lru_cache(maxsize=16)
1104
+ def _get_complex(precision: PrecisionType) -> DTypeLike:
1105
+ """Get complex type for given precision."""
1106
+ if precision in (64, '64'):
406
1107
  return np.complex128
407
- elif precision == [32, '32']:
408
- return np.complex64
409
- elif precision in [16, '16', 'bf16']:
410
- return np.complex64
411
- elif precision == [8, '8']:
1108
+ elif precision in (32, '32', 16, '16', 'bf16', 8, '8'):
412
1109
  return np.complex64
413
1110
  else:
414
- raise ValueError(f'Unsupported precision: {precision}')
1111
+ raise ValueError(f"Unsupported precision: {precision}")
415
1112
 
416
1113
 
417
1114
  def dftype() -> DTypeLike:
418
1115
  """
419
- Default floating data type.
1116
+ Get the default floating-point data type.
420
1117
 
421
- This function returns the default floating data type based on the current precision.
422
- If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
423
- you can use this function to get the default floating data type, and create the data by using ``dtype=dftype()``.
424
-
425
- For example, if the precision is set to 32, the default floating data type is ``np.float32``.
426
-
427
- >>> import brainstate as brainstate
428
- >>> import numpy as np
429
- >>> with brainstate.environ.context(precision=32):
430
- ... a = np.zeros(1, dtype=brainstate.environ.dftype())
431
- >>> print(a.dtype)
1118
+ This function returns the appropriate floating-point type based on
1119
+ the current precision setting, allowing dynamic type selection.
432
1120
 
433
1121
  Returns
434
1122
  -------
435
- float_dtype: DTypeLike
436
- The default floating data type.
1123
+ DTypeLike
1124
+ Default floating-point data type.
1125
+
1126
+ Examples
1127
+ --------
1128
+ .. code-block:: python
1129
+
1130
+ >>> import brainstate.environ as env
1131
+ >>> import jax.numpy as jnp
1132
+ >>>
1133
+ >>> # With 32-bit precision
1134
+ >>> env.set(precision=32)
1135
+ >>> x = jnp.zeros(10, dtype=env.dftype())
1136
+ >>> print(x.dtype) # float32
1137
+ >>>
1138
+ >>> # With 64-bit precision
1139
+ >>> with env.context(precision=64):
1140
+ ... y = jnp.ones(5, dtype=env.dftype())
1141
+ ... print(y.dtype) # float64
1142
+ >>>
1143
+ >>> # With bfloat16
1144
+ >>> env.set(precision='bf16')
1145
+ >>> z = jnp.array([1, 2, 3], dtype=env.dftype())
1146
+ >>> print(z.dtype) # bfloat16
1147
+
1148
+ See Also
1149
+ --------
1150
+ ditype : Default integer type
1151
+ dutype : Default unsigned integer type
1152
+ dctype : Default complex type
437
1153
  """
438
1154
  return _get_float(_get_precision())
439
1155
 
440
1156
 
441
1157
  def ditype() -> DTypeLike:
442
1158
  """
443
- Default integer data type.
444
-
445
- This function returns the default integer data type based on the current precision.
446
- If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
447
- you can use this function to get the default integer data type, and create the data by using ``dtype=ditype()``.
448
-
449
- For example, if the precision is set to 32, the default integer data type is ``np.int32``.
1159
+ Get the default integer data type.
450
1160
 
451
- >>> import brainstate as brainstate
452
- >>> import numpy as np
453
- >>> with brainstate.environ.context(precision=32):
454
- ... a = np.zeros(1, dtype=brainstate.environ.ditype())
455
- >>> print(a.dtype)
456
- int32
1161
+ This function returns the appropriate integer type based on
1162
+ the current precision setting.
457
1163
 
458
1164
  Returns
459
1165
  -------
460
- int_dtype: DTypeLike
461
- The default integer data type.
1166
+ DTypeLike
1167
+ Default integer data type.
1168
+
1169
+ Examples
1170
+ --------
1171
+ .. code-block:: python
1172
+
1173
+ >>> import brainstate.environ as env
1174
+ >>> import jax.numpy as jnp
1175
+ >>>
1176
+ >>> # With 32-bit precision
1177
+ >>> env.set(precision=32)
1178
+ >>> indices = jnp.arange(10, dtype=env.ditype())
1179
+ >>> print(indices.dtype) # int32
1180
+ >>>
1181
+ >>> # With 64-bit precision
1182
+ >>> with env.context(precision=64):
1183
+ ... big_indices = jnp.arange(1000, dtype=env.ditype())
1184
+ ... print(big_indices.dtype) # int64
1185
+
1186
+ See Also
1187
+ --------
1188
+ dftype : Default floating-point type
1189
+ dutype : Default unsigned integer type
462
1190
  """
463
1191
  return _get_int(_get_precision())
464
1192
 
465
1193
 
466
1194
  def dutype() -> DTypeLike:
467
1195
  """
468
- Default unsigned integer data type.
1196
+ Get the default unsigned integer data type.
469
1197
 
470
- This function returns the default unsigned integer data type based on the current precision.
471
- If you want the data dtype is changed with the setting of the precision
472
- by ``brainstate.environ.set(precision)``, you can use this function to get the default
473
- unsigned integer data type, and create the data by using ``dtype=dutype()``.
474
-
475
- For example, if the precision is set to 32, the default unsigned integer data type is ``np.uint32``.
476
-
477
- >>> import brainstate as brainstate
478
- >>> import numpy as np
479
- >>> with brainstate.environ.context(precision=32):
480
- ... a = np.zeros(1, dtype=brainstate.environ.dutype())
481
- >>> print(a.dtype)
482
- uint32
1198
+ This function returns the appropriate unsigned integer type based on
1199
+ the current precision setting.
483
1200
 
484
1201
  Returns
485
1202
  -------
486
- uint_dtype: DTypeLike
487
- The default unsigned integer data type.
1203
+ DTypeLike
1204
+ Default unsigned integer data type.
1205
+
1206
+ Examples
1207
+ --------
1208
+ .. code-block:: python
1209
+
1210
+ >>> import brainstate.environ as env
1211
+ >>> import jax.numpy as jnp
1212
+ >>>
1213
+ >>> # With 32-bit precision
1214
+ >>> env.set(precision=32)
1215
+ >>> counts = jnp.array([10, 20, 30], dtype=env.dutype())
1216
+ >>> print(counts.dtype) # uint32
1217
+ >>>
1218
+ >>> # With 16-bit precision
1219
+ >>> with env.context(precision=16):
1220
+ ... small_counts = jnp.array([1, 2, 3], dtype=env.dutype())
1221
+ ... print(small_counts.dtype) # uint16
1222
+
1223
+ See Also
1224
+ --------
1225
+ ditype : Default signed integer type
488
1226
  """
489
1227
  return _get_uint(_get_precision())
490
1228
 
491
1229
 
492
1230
  def dctype() -> DTypeLike:
493
1231
  """
494
- Default complex data type.
495
-
496
- This function returns the default complex data type based on the current precision.
497
- If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
498
- you can use this function to get the default complex data type, and create the data by using ``dtype=dctype()``.
1232
+ Get the default complex data type.
499
1233
 
500
- For example, if the precision is set to 32, the default complex data type is ``np.complex64``.
501
-
502
- >>> import brainstate as brainstate
503
- >>> import numpy as np
504
- >>> with brainstate.environ.context(precision=32):
505
- ... a = np.zeros(1, dtype=brainstate.environ.dctype())
506
- >>> print(a.dtype)
507
- complex64
1234
+ This function returns the appropriate complex type based on
1235
+ the current precision setting.
508
1236
 
509
1237
  Returns
510
1238
  -------
511
- complex_dtype: DTypeLike
512
- The default complex data type.
1239
+ DTypeLike
1240
+ Default complex data type.
1241
+
1242
+ Examples
1243
+ --------
1244
+ .. code-block:: python
1245
+
1246
+ >>> import brainstate.environ as env
1247
+ >>> import jax.numpy as jnp
1248
+ >>>
1249
+ >>> # With 32-bit precision
1250
+ >>> env.set(precision=32)
1251
+ >>> z = jnp.array([1+2j, 3+4j], dtype=env.dctype())
1252
+ >>> print(z.dtype) # complex64
1253
+ >>>
1254
+ >>> # With 64-bit precision
1255
+ >>> with env.context(precision=64):
1256
+ ... w = jnp.array([5+6j], dtype=env.dctype())
1257
+ ... print(w.dtype) # complex128
1258
+
1259
+ Notes
1260
+ -----
1261
+ Complex128 is only available with 64-bit precision.
1262
+ All other precisions use complex64.
513
1263
  """
514
1264
  return _get_complex(_get_precision())
515
1265
 
516
1266
 
517
- def tolerance():
518
- if get_precision() == 64:
1267
+ def tolerance() -> jnp.ndarray:
1268
+ """
1269
+ Get numerical tolerance based on current precision.
1270
+
1271
+ This function returns an appropriate tolerance value for numerical
1272
+ comparisons based on the current precision setting.
1273
+
1274
+ Returns
1275
+ -------
1276
+ jnp.ndarray
1277
+ Tolerance value as a scalar array.
1278
+
1279
+ Examples
1280
+ --------
1281
+ .. code-block:: python
1282
+
1283
+ >>> import brainstate.environ as env
1284
+ >>> import jax.numpy as jnp
1285
+ >>>
1286
+ >>> # Different tolerances for different precisions
1287
+ >>> env.set(precision=64)
1288
+ >>> tol64 = env.tolerance()
1289
+ >>> print(f"64-bit tolerance: {tol64}") # 1e-12
1290
+ >>>
1291
+ >>> env.set(precision=32)
1292
+ >>> tol32 = env.tolerance()
1293
+ >>> print(f"32-bit tolerance: {tol32}") # 1e-5
1294
+ >>>
1295
+ >>> # Use in numerical comparisons
1296
+ >>> def are_close(a, b):
1297
+ ... return jnp.abs(a - b) < env.tolerance()
1298
+
1299
+ Notes
1300
+ -----
1301
+ Tolerance values:
1302
+ - 64-bit: 1e-12
1303
+ - 32-bit: 1e-5
1304
+ - 16-bit and below: 1e-2
1305
+ """
1306
+ precision = get_precision()
1307
+
1308
+ if precision == 64:
519
1309
  return jnp.array(1e-12, dtype=np.float64)
520
- elif get_precision() == 32:
1310
+ elif precision == 32:
521
1311
  return jnp.array(1e-5, dtype=np.float32)
522
1312
  else:
523
1313
  return jnp.array(1e-2, dtype=np.float16)
524
1314
 
525
1315
 
526
- def register_default_behavior(key: str, behavior: Callable, replace_if_exist: bool = False):
1316
+ def register_default_behavior(
1317
+ key: str,
1318
+ behavior: Callable[[Any], None],
1319
+ replace_if_exist: bool = False
1320
+ ) -> None:
527
1321
  """
528
- Register a default behavior for a specific global key parameter.
1322
+ Register a callback for environment parameter changes.
529
1323
 
530
- For example, you can register a default behavior for the key 'dt' by::
1324
+ This function allows you to register custom behaviors that are
1325
+ triggered whenever a specific environment parameter is modified.
531
1326
 
532
- >>> import brainstate as brainstate
533
- >>> def dt_behavior(dt):
534
- ... print(f'Set the default dt to {dt}.')
535
- ...
536
- >>> brainstate.environ.register_default_behavior('dt', dt_behavior)
1327
+ Parameters
1328
+ ----------
1329
+ key : str
1330
+ Environment parameter key to monitor.
1331
+ behavior : Callable[[Any], None]
1332
+ Callback function that receives the new value.
1333
+ replace_if_exist : bool, default=False
1334
+ Whether to replace existing callback for this key.
1335
+
1336
+ Raises
1337
+ ------
1338
+ TypeError
1339
+ If behavior is not callable.
1340
+ ValueError
1341
+ If key already has a registered behavior and replace_if_exist is False.
1342
+
1343
+ Examples
1344
+ --------
1345
+ Basic callback registration:
1346
+
1347
+ .. code-block:: python
1348
+
1349
+ >>> import brainstate.environ as env
1350
+ >>>
1351
+ >>> # Define a callback
1352
+ >>> def on_dt_change(new_dt):
1353
+ ... print(f"Time step changed to: {new_dt}")
1354
+ >>>
1355
+ >>> # Register the callback
1356
+ >>> env.register_default_behavior('dt', on_dt_change)
1357
+ >>>
1358
+ >>> # Callback is triggered on changes
1359
+ >>> env.set(dt=0.01) # Prints: Time step changed to: 0.01
1360
+ >>>
1361
+ >>> with env.context(dt=0.001): # Prints: Time step changed to: 0.001
1362
+ ... pass # Prints: Time step changed to: 0.01 (on exit)
1363
+
1364
+ Complex behavior with validation:
1365
+
1366
+ .. code-block:: python
1367
+
1368
+ >>> import brainstate.environ as env
1369
+ >>>
1370
+ >>> def validate_batch_size(size):
1371
+ ... if not isinstance(size, int) or size <= 0:
1372
+ ... raise ValueError(f"Invalid batch size: {size}")
1373
+ ... if size > 1024:
1374
+ ... print(f"Warning: Large batch size {size} may cause OOM")
1375
+ >>>
1376
+ >>> env.register_default_behavior('batch_size', validate_batch_size)
1377
+ >>>
1378
+ >>> # Valid setting
1379
+ >>> env.set(batch_size=32) # OK
1380
+ >>>
1381
+ >>> # Invalid setting
1382
+ >>> # env.set(batch_size=-1) # Raises ValueError
1383
+
1384
+ Replacing existing behavior:
1385
+
1386
+ .. code-block:: python
1387
+
1388
+ >>> import brainstate.environ as env
1389
+ >>>
1390
+ >>> def old_behavior(value):
1391
+ ... print(f"Old: {value}")
1392
+ >>>
1393
+ >>> def new_behavior(value):
1394
+ ... print(f"New: {value}")
1395
+ >>>
1396
+ >>> env.register_default_behavior('key', old_behavior)
1397
+ >>> env.register_default_behavior('key', new_behavior, replace_if_exist=True)
1398
+ >>>
1399
+ >>> env.set(key='test') # Prints: New: test
1400
+
1401
+ See Also
1402
+ --------
1403
+ unregister_default_behavior : Remove registered callbacks
1404
+ list_registered_behaviors : List all registered callbacks
1405
+ """
1406
+ if not isinstance(key, str):
1407
+ raise TypeError(f"Key must be a string, got {type(key)}")
537
1408
 
538
- Then, when you set the default dt by `brainstate.environ.set(dt=0.1)`, the behavior
539
- `dt_behavior` will be called with
540
- `dt_behavior(0.1)`.
1409
+ if not callable(behavior):
1410
+ raise TypeError(f"Behavior must be callable, got {type(behavior)}")
541
1411
 
542
- >>> brainstate.environ.set(dt=0.1)
543
- Set the default dt to 0.1.
544
- >>> with brainstate.environ.context(dt=0.2):
545
- ... pass
546
- Set the default dt to 0.2.
547
- Set the default dt to 0.1.
1412
+ if key in _ENV_STATE.functions and not replace_if_exist:
1413
+ raise ValueError(
1414
+ f"Behavior for key '{key}' already registered. "
1415
+ f"Use replace_if_exist=True to override."
1416
+ )
548
1417
 
1418
+ _ENV_STATE.functions[key] = behavior
549
1419
 
550
- Args:
551
- key: str. The key to register.
552
- behavior: Callable. The behavior to register. It should be a callable.
553
- replace_if_exist: bool. Whether to replace the behavior if the key has been registered.
554
1420
 
1421
def unregister_default_behavior(key: str) -> bool:
    """
    Drop the callback registered for an environment parameter, if there is one.

    Parameters
    ----------
    key : str
        Name of the environment parameter whose callback should be removed.

    Returns
    -------
    bool
        ``True`` when a callback was found under ``key`` and deleted;
        ``False`` when nothing was registered under ``key``, in which case
        the call leaves the registry untouched.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> def callback(value):
        ...     print(f"Value: {value}")
        >>>
        >>> env.register_default_behavior('param', callback)
        >>> env.unregister_default_behavior('param')       # True
        >>> env.set(param='test')                          # callback no longer runs
        >>> env.unregister_default_behavior('nonexistent') # False
    """
    registry = _ENV_STATE.functions

    # Guard clause: an unknown key is reported, not treated as an error.
    if key not in registry:
        return False

    del registry[key]
    return True
1462
+
1463
+
1464
def list_registered_behaviors() -> List[str]:
    """
    Report which environment parameters currently have a callback attached.

    Returns
    -------
    list of str
        A freshly built list of the keys in the behavior registry, in the
        order the registry's ``keys()`` view yields them (insertion order
        when the registry is a plain ``dict``). Modifying the returned list
        does not modify the registry.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.environ as env
        >>>
        >>> env.register_default_behavior('param1', lambda x: None)
        >>> env.register_default_behavior('param2', lambda x: None)
        >>>
        >>> names = env.list_registered_behaviors()
        >>> print(names)  # e.g. [..., 'param1', 'param2']
        >>>
        >>> if 'dt' in names:
        ...     print("dt has a registered callback")
    """
    registered_keys = _ENV_STATE.functions.keys()
    # Unpack the live view into a new list so callers get a snapshot.
    return [*registered_keys]
561
1492
 
562
1493
 
563
- set(precision=32)
1494
# Initialize the default precision when this module is first imported.
# ``set`` here is this module's keyword-based environment setter (the
# builtin ``set`` would reject a ``precision=`` keyword), so this call
# stores DEFAULT_PRECISION under the 'precision' key before user code runs.
set(precision=DEFAULT_PRECISION)