brainstate-0.1.10-py2.py3-none-any.whl → brainstate-0.2.1-py2.py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
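Much of this listing reflects a module-level reorganization: the brainstate.compile and brainstate.augment subpackages are folded into a new brainstate.transform package, the brainstate.init initializers move under brainstate.nn, and the optim, functional, and surrogate modules are removed outright. The sketch below is an orientation aid inferred only from the file moves listed above; the import locations it shows are assumptions drawn from those renames, not a verified description of the 0.2.1 public API.

    # Orientation sketch inferred from the file renames listed above.
    # Whether the old import paths still resolve (for example via the new
    # brainstate/_deprecation.py module) is not verified here.
    import brainstate

    # 0.2.1: contents of the former brainstate.compile and brainstate.augment
    # subpackages (jit, loops, autograd, mapping, ...) now live here
    from brainstate import transform

    # 0.2.1: the former brainstate.init initializers now live under brainstate.nn
    from brainstate.nn import init

    # Removed in 0.2.1 with no replacement visible in this diff:
    #   brainstate.optim, brainstate.functional, brainstate.surrogate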
brainstate/environ.py CHANGED
@@ -1,563 +1,1495 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- import contextlib
19
- import dataclasses
20
- import functools
21
- import os
22
- import re
23
- import threading
24
- from collections import defaultdict
25
- from typing import Any, Callable, Dict, Hashable
26
-
27
- import numpy as np
28
- from jax import config, devices, numpy as jnp
29
- from jax.typing import DTypeLike
30
-
31
- from .mixin import Mode
32
-
33
- __all__ = [
34
- # functions for environment settings
35
- 'set', 'context', 'get', 'all', 'set_host_device_count', 'set_platform',
36
- # functions for getting default behaviors
37
- 'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_precision',
38
- # functions for default data types
39
- 'dftype', 'ditype', 'dutype', 'dctype',
40
- # others
41
- 'tolerance', 'register_default_behavior',
42
- ]
43
-
44
- # Default, there are several shared arguments in the global context.
45
- I = 'i' # the index of the current computation.
46
- T = 't' # the current time of the current computation.
47
- JIT_ERROR_CHECK = 'jit_error_check' # whether to record the current computation.
48
- FIT = 'fit' # whether to fit the model.
49
-
50
-
51
- @dataclasses.dataclass
52
- class DefaultContext(threading.local):
53
- # default environment settings
54
- settings: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
55
- # current environment settings
56
- contexts: defaultdict[Hashable, Any] = dataclasses.field(default_factory=lambda: defaultdict(list))
57
- # environment functions
58
- functions: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
59
-
60
-
61
- DFAULT = DefaultContext()
62
- _NOT_PROVIDE = object()
63
-
64
-
65
- @contextlib.contextmanager
66
- def context(**kwargs):
67
- r"""
68
- Context-manager that sets a computing environment for brain dynamics computation.
69
-
70
- In BrainPy, there are several basic computation settings when constructing models,
71
- including ``mode`` for controlling model computing behavior, ``dt`` for numerical
72
- integration, ``int_`` for integer precision, and ``float_`` for floating precision.
73
- :py:class:`~.environment`` provides a context for model construction and
74
- computation. In this temporal environment, models are constructed with the given
75
- ``mode``, ``dt``, ``int_``, etc., environment settings.
76
-
77
- For instance::
78
-
79
- >>> import brainstate as brainstate
80
- >>> with brainstate.environ.context(dt=0.1) as env:
81
- ... dt = brainstate.environ.get('dt')
82
- ... print(env)
83
-
84
- """
85
- if 'platform' in kwargs:
86
- raise ValueError('\n'
87
- 'Cannot set platform in "context" environment. \n'
88
- 'You should set platform in the global environment by "set_platform()" or "set()".')
89
- if 'host_device_count' in kwargs:
90
- raise ValueError('Cannot set host_device_count in environment context. '
91
- 'Please use set_host_device_count() or set() for the global setting.')
92
-
93
- if 'precision' in kwargs:
94
- last_precision = _get_precision()
95
- _set_jax_precision(kwargs['precision'])
96
-
97
- try:
98
- for k, v in kwargs.items():
99
-
100
- # update the current environment
101
- DFAULT.contexts[k].append(v)
102
-
103
- # restore the environment functions
104
- if k in DFAULT.functions:
105
- DFAULT.functions[k](v)
106
-
107
- # yield the current all environment information
108
- yield all()
109
- finally:
110
-
111
- for k, v in kwargs.items():
112
-
113
- # restore the current environment
114
- DFAULT.contexts[k].pop()
115
-
116
- # restore the environment functions
117
- if k in DFAULT.functions:
118
- DFAULT.functions[k](get(k))
119
-
120
- if 'precision' in kwargs:
121
- _set_jax_precision(last_precision)
122
-
123
-
124
- def get(key: str, default: Any = _NOT_PROVIDE, desc: str = None):
125
- """
126
- Get one of the default computation environment.
127
-
128
- Returns
129
- -------
130
- item: Any
131
- The default computation environment.
132
- """
133
- if key == 'platform':
134
- return get_platform()
135
-
136
- if key == 'host_device_count':
137
- return get_host_device_count()
138
-
139
- if key in DFAULT.contexts:
140
- if len(DFAULT.contexts[key]) > 0:
141
- return DFAULT.contexts[key][-1]
142
- if key in DFAULT.settings:
143
- return DFAULT.settings[key]
144
-
145
- if default is _NOT_PROVIDE:
146
- if desc is not None:
147
- raise KeyError(
148
- f"'{key}' is not found in the context. \n"
149
- f"You can set it by `brainstate.share.context({key}=value)` "
150
- f"locally or `brainstate.share.set({key}=value)` globally. \n"
151
- f"Description: {desc}"
152
- )
153
- else:
154
- raise KeyError(
155
- f"'{key}' is not found in the context. \n"
156
- f"You can set it by `brainstate.share.context({key}=value)` "
157
- f"locally or `brainstate.share.set({key}=value)` globally."
158
- )
159
- return default
160
-
161
-
162
- def all() -> dict:
163
- """
164
- Get all the current default computation environment.
165
-
166
- Returns
167
- -------
168
- r: dict
169
- The current default computation environment.
170
- """
171
- r = dict()
172
- for k, v in DFAULT.contexts.items():
173
- if v:
174
- r[k] = v[-1]
175
- for k, v in DFAULT.settings.items():
176
- if k not in r:
177
- r[k] = v
178
- return r
179
-
180
-
181
- def get_dt():
182
- """Get the numerical integrator precision.
183
-
184
- Returns
185
- -------
186
- dt : float
187
- Numerical integration precision.
188
- """
189
- return get('dt')
190
-
191
-
192
- def get_mode() -> Mode:
193
- """Get the default computing mode.
194
-
195
- References
196
- ----------
197
- mode: Mode
198
- The default computing mode.
199
- """
200
- return get('mode')
201
-
202
-
203
- def get_platform() -> str:
204
- """Get the computing platform.
205
-
206
- Returns
207
- -------
208
- platform: str
209
- Either 'cpu', 'gpu' or 'tpu'.
210
- """
211
- return devices()[0].platform
212
-
213
-
214
- def get_host_device_count():
215
- """
216
- Get the number of host devices.
217
-
218
- Returns
219
- -------
220
- n: int
221
- The number of host devices.
222
- """
223
- xla_flags = os.getenv("XLA_FLAGS", "")
224
- match = re.search(r"--xla_force_host_platform_device_count=(\d+)", xla_flags)
225
- return int(match.group(1)) if match else 1
226
-
227
-
228
- def _get_precision() -> int | str:
229
- """
230
- Get the default precision.
231
-
232
- Returns
233
- -------
234
- precision: int
235
- The default precision.
236
- """
237
- return get('precision')
238
-
239
-
240
- def get_precision() -> int:
241
- """
242
- Get the default precision.
243
-
244
- Returns
245
- -------
246
- precision: int
247
- The default precision.
248
- """
249
- precision = get('precision')
250
- if precision == 'bf16':
251
- return 16
252
- if isinstance(precision, int):
253
- return precision
254
- if isinstance(precision, str):
255
- return int(precision)
256
- raise ValueError(f'Unsupported precision: {precision}')
257
-
258
-
259
- def set(
260
- platform: str = None,
261
- host_device_count: int = None,
262
- precision: int | str = None,
263
- mode: Mode = None,
264
- **kwargs
265
- ):
266
- """
267
- Set the global default computation environment.
268
-
269
-
270
-
271
- Args:
272
- platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
273
- host_device_count: int. The number of host devices.
274
- precision: int, str. The default precision.
275
- mode: Mode. The computing mode.
276
- **kwargs: dict. Other environment settings.
277
- """
278
- if platform is not None:
279
- set_platform(platform)
280
- if host_device_count is not None:
281
- set_host_device_count(host_device_count)
282
- if precision is not None:
283
- _set_jax_precision(precision)
284
- kwargs['precision'] = precision
285
- if mode is not None:
286
- assert isinstance(mode, Mode), 'mode must be a Mode instance.'
287
- kwargs['mode'] = mode
288
-
289
- # set default environment
290
- DFAULT.settings.update(kwargs)
291
-
292
- # update the environment functions
293
- for k, v in kwargs.items():
294
- if k in DFAULT.functions:
295
- DFAULT.functions[k](v)
296
-
297
-
298
- def set_host_device_count(n):
299
- """
300
- By default, XLA considers all CPU cores as one device. This utility tells XLA
301
- that there are `n` host (CPU) devices available to use. As a consequence, this
302
- allows parallel mapping in JAX :func:`jax.pmap` to work in CPU platform.
303
-
304
- .. note:: This utility only takes effect at the beginning of your program.
305
- Under the hood, this sets the environment variable
306
- `XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
307
- `[num_device]` is the desired number of CPU devices `n`.
308
-
309
- .. warning:: Our understanding of the side effects of using the
310
- `xla_force_host_platform_device_count` flag in XLA is incomplete. If you
311
- observe some strange phenomenon when using this utility, please let us
312
- know through our issue or forum page. More information is available in this
313
- `JAX issue <https://github.com/google/jax/issues/1408>`_.
314
-
315
- :param int n: number of devices to use.
316
- """
317
- xla_flags = os.getenv("XLA_FLAGS", "")
318
- xla_flags = re.sub(r"--xla_force_host_platform_device_count=\S+", "", xla_flags).split()
319
- os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(n)] + xla_flags)
320
-
321
- # update the environment functions
322
- if 'host_device_count' in DFAULT.functions:
323
- DFAULT.functions['host_device_count'](n)
324
-
325
-
326
- def set_platform(platform: str):
327
- """
328
- Changes platform to CPU, GPU, or TPU. This utility only takes
329
- effect at the beginning of your program.
330
-
331
- Args:
332
- platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
333
-
334
- Raises:
335
- ValueError: If the platform is not in ['cpu', 'gpu', 'tpu'].
336
- """
337
- assert platform in ['cpu', 'gpu', 'tpu']
338
- config.update("jax_platform_name", platform)
339
-
340
- # update the environment functions
341
- if 'platform' in DFAULT.functions:
342
- DFAULT.functions['platform'](platform)
343
-
344
-
345
- def _set_jax_precision(precision: int | str):
346
- """
347
- Set the default precision.
348
-
349
- Args:
350
- precision: int. The default precision.
351
- """
352
- # assert precision in [64, 32, 16, 'bf16', 8], f'Precision must be in [64, 32, 16, "bf16", 8]. But got {precision}.'
353
- if precision in [64, '64']:
354
- config.update("jax_enable_x64", True)
355
- else:
356
- config.update("jax_enable_x64", False)
357
-
358
-
359
- @functools.lru_cache()
360
- def _get_uint(precision: int):
361
- if precision in [64, '64']:
362
- return np.uint64
363
- elif precision in [32, '32']:
364
- return np.uint32
365
- elif precision in [16, '16', 'bf16']:
366
- return np.uint16
367
- elif precision in [8, '8']:
368
- return np.uint8
369
- else:
370
- raise ValueError(f'Unsupported precision: {precision}')
371
-
372
-
373
- @functools.lru_cache()
374
- def _get_int(precision: int):
375
- if precision in [64, '64']:
376
- return np.int64
377
- elif precision in [32, '32']:
378
- return np.int32
379
- elif precision in [16, '16', 'bf16']:
380
- return np.int16
381
- elif precision in [8, '8']:
382
- return np.int8
383
- else:
384
- raise ValueError(f'Unsupported precision: {precision}')
385
-
386
-
387
- @functools.lru_cache()
388
- def _get_float(precision: int):
389
- if precision in [64, '64']:
390
- return np.float64
391
- elif precision in [32, '32']:
392
- return np.float32
393
- elif precision in [16, '16']:
394
- return np.float16
395
- elif precision in ['bf16']:
396
- return jnp.bfloat16
397
- elif precision in [8, '8']:
398
- return jnp.float8_e5m2
399
- else:
400
- raise ValueError(f'Unsupported precision: {precision}')
401
-
402
-
403
- @functools.lru_cache()
404
- def _get_complex(precision: int):
405
- if precision in [64, '64']:
406
- return np.complex128
407
- elif precision in [32, '32']:
408
- return np.complex64
409
- elif precision in [16, '16', 'bf16']:
410
- return np.complex64
411
- elif precision in [8, '8']:
412
- return np.complex64
413
- else:
414
- raise ValueError(f'Unsupported precision: {precision}')
415
-
416
-
417
- def dftype() -> DTypeLike:
418
- """
419
- Default floating data type.
420
-
421
- This function returns the default floating data type based on the current precision.
422
- If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
423
- you can use this function to get the default floating data type, and create the data by using ``dtype=dftype()``.
424
-
425
- For example, if the precision is set to 32, the default floating data type is ``np.float32``.
426
-
427
- >>> import brainstate as brainstate
428
- >>> import numpy as np
429
- >>> with brainstate.environ.context(precision=32):
430
- ... a = np.zeros(1, dtype=brainstate.environ.dftype())
431
- >>> print(a.dtype)
432
-
433
- Returns
434
- -------
435
- float_dtype: DTypeLike
436
- The default floating data type.
437
- """
438
- return _get_float(_get_precision())
439
-
440
-
441
- def ditype() -> DTypeLike:
442
- """
443
- Default integer data type.
444
-
445
- This function returns the default integer data type based on the current precision.
446
- If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
447
- you can use this function to get the default integer data type, and create the data by using ``dtype=ditype()``.
448
-
449
- For example, if the precision is set to 32, the default integer data type is ``np.int32``.
450
-
451
- >>> import brainstate as brainstate
452
- >>> import numpy as np
453
- >>> with brainstate.environ.context(precision=32):
454
- ... a = np.zeros(1, dtype=brainstate.environ.ditype())
455
- >>> print(a.dtype)
456
- int32
457
-
458
- Returns
459
- -------
460
- int_dtype: DTypeLike
461
- The default integer data type.
462
- """
463
- return _get_int(_get_precision())
464
-
465
-
466
- def dutype() -> DTypeLike:
467
- """
468
- Default unsigned integer data type.
469
-
470
- This function returns the default unsigned integer data type based on the current precision.
471
- If you want the data dtype is changed with the setting of the precision
472
- by ``brainstate.environ.set(precision)``, you can use this function to get the default
473
- unsigned integer data type, and create the data by using ``dtype=dutype()``.
474
-
475
- For example, if the precision is set to 32, the default unsigned integer data type is ``np.uint32``.
476
-
477
- >>> import brainstate as brainstate
478
- >>> import numpy as np
479
- >>> with brainstate.environ.context(precision=32):
480
- ... a = np.zeros(1, dtype=brainstate.environ.dutype())
481
- >>> print(a.dtype)
482
- uint32
483
-
484
- Returns
485
- -------
486
- uint_dtype: DTypeLike
487
- The default unsigned integer data type.
488
- """
489
- return _get_uint(_get_precision())
490
-
491
-
492
- def dctype() -> DTypeLike:
493
- """
494
- Default complex data type.
495
-
496
- This function returns the default complex data type based on the current precision.
497
- If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
498
- you can use this function to get the default complex data type, and create the data by using ``dtype=dctype()``.
499
-
500
- For example, if the precision is set to 32, the default complex data type is ``np.complex64``.
501
-
502
- >>> import brainstate as brainstate
503
- >>> import numpy as np
504
- >>> with brainstate.environ.context(precision=32):
505
- ... a = np.zeros(1, dtype=brainstate.environ.dctype())
506
- >>> print(a.dtype)
507
- complex64
508
-
509
- Returns
510
- -------
511
- complex_dtype: DTypeLike
512
- The default complex data type.
513
- """
514
- return _get_complex(_get_precision())
515
-
516
-
517
- def tolerance():
518
- if get_precision() == 64:
519
- return jnp.array(1e-12, dtype=np.float64)
520
- elif get_precision() == 32:
521
- return jnp.array(1e-5, dtype=np.float32)
522
- else:
523
- return jnp.array(1e-2, dtype=np.float16)
524
-
525
-
526
- def register_default_behavior(key: str, behavior: Callable, replace_if_exist: bool = False):
527
- """
528
- Register a default behavior for a specific global key parameter.
529
-
530
- For example, you can register a default behavior for the key 'dt' by::
531
-
532
- >>> import brainstate as brainstate
533
- >>> def dt_behavior(dt):
534
- ... print(f'Set the default dt to {dt}.')
535
- ...
536
- >>> brainstate.environ.register_default_behavior('dt', dt_behavior)
537
-
538
- Then, when you set the default dt by `brainstate.environ.set(dt=0.1)`, the behavior
539
- `dt_behavior` will be called with
540
- `dt_behavior(0.1)`.
541
-
542
- >>> brainstate.environ.set(dt=0.1)
543
- Set the default dt to 0.1.
544
- >>> with brainstate.environ.context(dt=0.2):
545
- ... pass
546
- Set the default dt to 0.2.
547
- Set the default dt to 0.1.
548
-
549
-
550
- Args:
551
- key: str. The key to register.
552
- behavior: Callable. The behavior to register. It should be a callable.
553
- replace_if_exist: bool. Whether to replace the behavior if the key has been registered.
554
-
555
- """
556
- assert isinstance(key, str), 'key must be a string.'
557
- assert callable(behavior), 'behavior must be a callable.'
558
- if not replace_if_exist:
559
- assert key not in DFAULT.functions, f'{key} has been registered.'
560
- DFAULT.functions[key] = behavior
561
-
562
-
563
- set(precision=32)
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Environment configuration and context management for BrainState.
18
+
19
+ This module provides comprehensive functionality for managing computational
20
+ environments, including platform selection, precision control, mode setting,
21
+ and context-based configuration management. It enables flexible configuration
22
+ of JAX-based computations with thread-safe context switching.
23
+
24
+ The module supports:
25
+ - Platform configuration (CPU, GPU, TPU)
26
+ - Precision control (8, 16, 32, 64 bit and bfloat16)
27
+ - Computation mode management
28
+ - Context-based temporary settings
29
+ - Default data type management
30
+ - Custom behavior registration
31
+
32
+ Examples
33
+ --------
34
+ Global environment configuration:
35
+
36
+ .. code-block:: python
37
+
38
+ >>> import brainstate as bs
39
+ >>> import brainstate.environ as env
40
+ >>>
41
+ >>> # Set global precision to 32-bit
42
+ >>> env.set(precision=32, dt=0.01, mode=bs.mixin.Training())
43
+ >>>
44
+ >>> # Get current settings
45
+ >>> print(env.get('precision')) # 32
46
+ >>> print(env.get('dt')) # 0.01
47
+
48
+ Context-based temporary settings:
49
+
50
+ .. code-block:: python
51
+
52
+ >>> import brainstate.environ as env
53
+ >>>
54
+ >>> # Temporarily change precision
55
+ >>> with env.context(precision=64, dt=0.001):
56
+ ... high_precision_result = compute_something()
57
+ ... print(env.get('precision')) # 64
58
+ >>> print(env.get('precision')) # Back to 32
59
+ """
60
+
61
+ import contextlib
62
+ import dataclasses
63
+ import functools
64
+ import os
65
+ import re
66
+ import threading
67
+ import warnings
68
+ from collections import defaultdict
69
+ from typing import Any, Callable, Dict, Hashable, Optional, Union, ContextManager, List
70
+
71
+ import brainunit as u
72
+ import numpy as np
73
+ from jax import config, devices, numpy as jnp
74
+ from jax.typing import DTypeLike
75
+
76
+ __all__ = [
77
+ # Core environment management
78
+ 'set',
79
+ 'get',
80
+ 'all',
81
+ 'pop',
82
+ 'context',
83
+ 'reset',
84
+
85
+ # Platform and device management
86
+ 'set_platform',
87
+ 'get_platform',
88
+ 'set_host_device_count',
89
+ 'get_host_device_count',
90
+
91
+ # Precision and data type management
92
+ 'set_precision',
93
+ 'get_precision',
94
+ 'dftype',
95
+ 'ditype',
96
+ 'dutype',
97
+ 'dctype',
98
+
99
+ # Mode and computation settings
100
+ 'get_dt',
101
+
102
+ # Utility functions
103
+ 'tolerance',
104
+ 'register_default_behavior',
105
+ 'unregister_default_behavior',
106
+ 'list_registered_behaviors',
107
+
108
+ # Constants
109
+ 'DEFAULT_PRECISION',
110
+ 'SUPPORTED_PLATFORMS',
111
+ 'SUPPORTED_PRECISIONS',
112
+
113
+ # Names
114
+ 'I',
115
+ 'T',
116
+ 'DT',
117
+ 'PRECISION',
118
+ 'PLATFORM',
119
+ 'HOST_DEVICE_COUNT',
120
+ 'JIT_ERROR_CHECK',
121
+ 'FIT',
122
+ ]
123
+
124
+ # Type definitions
125
+ # T = TypeVar('T')
126
+ PrecisionType = Union[int, str]
127
+ PlatformType = str
128
+
129
+ # Constants for environment keys
130
+ I = 'i' # Index of the current computation
131
+ T = 't' # Current time of the computation
132
+ DT = 'dt' # Time step for numerical integration
133
+ PRECISION = 'precision' # Numerical precision
134
+ PLATFORM = 'platform' # Computing platform
135
+ HOST_DEVICE_COUNT = 'host_device_count' # Number of host devices
136
+ JIT_ERROR_CHECK = 'jit_error_check' # JIT error checking flag
137
+ FIT = 'fit' # Model fitting flag
138
+
139
+ # Default values
140
+ DEFAULT_PRECISION = 32
141
+ SUPPORTED_PLATFORMS = ('cpu', 'gpu', 'tpu')
142
+ SUPPORTED_PRECISIONS = (8, 16, 32, 64, 'bf16')
143
+
144
+ # Sentinel value for missing arguments
145
+ _NOT_PROVIDED = object()
146
+
147
+
148
+ @dataclasses.dataclass
149
+ class EnvironmentState(threading.local):
150
+ """
151
+ Thread-local storage for environment configuration.
152
+
153
+ This class maintains separate configuration states for different threads,
154
+ ensuring thread-safe environment management in concurrent applications.
155
+
156
+ Attributes
157
+ ----------
158
+ settings : Dict[Hashable, Any]
159
+ Global default environment settings.
160
+ contexts : defaultdict[Hashable, List[Any]]
161
+ Stack of context-specific settings for nested contexts.
162
+ functions : Dict[Hashable, Callable]
163
+ Registered callback functions for environment changes.
164
+ locks : Dict[str, threading.Lock]
165
+ Thread locks for synchronized access to critical sections.
166
+ """
167
+ settings: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
168
+ contexts: defaultdict[Hashable, List[Any]] = dataclasses.field(default_factory=lambda: defaultdict(list))
169
+ functions: Dict[Hashable, Callable] = dataclasses.field(default_factory=dict)
170
+ locks: Dict[str, threading.Lock] = dataclasses.field(default_factory=lambda: defaultdict(threading.Lock))
171
+
172
+ def __post_init__(self):
173
+ """Initialize with default settings."""
174
+ # Set default precision if not already set
175
+ if PRECISION not in self.settings:
176
+ self.settings[PRECISION] = DEFAULT_PRECISION
177
+
178
+
179
+ # Global environment state
180
+ _ENV_STATE = EnvironmentState()
181
+
182
+
183
+ def reset() -> None:
184
+ """
185
+ Reset the environment to default settings.
186
+
187
+ This function clears all custom settings and restores the environment
188
+ to its initial state. Useful for testing or when starting fresh.
189
+
190
+ Examples
191
+ --------
192
+ .. code-block:: python
193
+
194
+ >>> import brainstate.environ as env
195
+ >>>
196
+ >>> # Set custom values
197
+ >>> env.set(dt=0.1, custom_param='value')
198
+ >>> print(env.get('custom_param')) # 'value'
199
+ >>>
200
+ >>> # Reset to defaults
201
+ >>> env.reset()
202
+ >>> print(env.get('custom_param', default=None)) # None
203
+
204
+ Notes
205
+ -----
206
+ This operation cannot be undone. All custom settings will be lost.
207
+ """
208
+ global _ENV_STATE
209
+ _ENV_STATE = EnvironmentState()
210
+ # Re-apply default precision
211
+ _set_jax_precision(DEFAULT_PRECISION)
212
+
213
+ warnings.warn(
214
+ "Environment has been reset to default settings. "
215
+ "All custom configurations have been cleared.",
216
+ UserWarning
217
+ )
218
+
219
+
220
+ @contextlib.contextmanager
221
+ def context(**kwargs) -> ContextManager[Dict[str, Any]]:
222
+ """
223
+ Context manager for temporary environment settings.
224
+
225
+ This context manager allows you to temporarily modify environment settings
226
+ within a specific scope. Settings are automatically restored when exiting
227
+ the context, even if an exception occurs.
228
+
229
+ Parameters
230
+ ----------
231
+ **kwargs
232
+ Environment settings to apply within the context.
233
+ Common parameters include:
234
+
235
+ - precision : int or str.
236
+ Numerical precision (8, 16, 32, 64, or 'bf16')
237
+ - dt : float.
238
+ Time step for numerical integration
239
+ - mode : :class:`Mode`.
240
+ Computation mode instance
241
+ - Any custom parameters registered via register_default_behavior
242
+
243
+ Yields
244
+ ------
245
+ dict
246
+ Current environment settings within the context.
247
+
248
+ Raises
249
+ ------
250
+ ValueError
251
+ If attempting to set platform or host_device_count in context
252
+ (these must be set globally).
253
+ TypeError
254
+ If invalid parameter types are provided.
255
+
256
+ Examples
257
+ --------
258
+ Basic usage with precision control:
259
+
260
+ .. code-block:: python
261
+
262
+ >>> import brainstate.environ as env
263
+ >>>
264
+ >>> # Set global precision
265
+ >>> env.set(precision=32)
266
+ >>>
267
+ >>> # Temporarily use higher precision
268
+ >>> with env.context(precision=64) as ctx:
269
+ ... print(f"Precision in context: {env.get('precision')}") # 64
270
+ ... print(f"Float type: {env.dftype()}") # float64
271
+ >>>
272
+ >>> print(f"Precision after context: {env.get('precision')}") # 32
273
+
274
+ Nested contexts:
275
+
276
+ .. code-block:: python
277
+
278
+ >>> import brainstate.environ as env
279
+ >>>
280
+ >>> with env.context(dt=0.1) as ctx1:
281
+ ... print(f"dt = {env.get('dt')}") # 0.1
282
+ ...
283
+ ... with env.context(dt=0.01) as ctx2:
284
+ ... print(f"dt = {env.get('dt')}") # 0.01
285
+ ...
286
+ ... print(f"dt = {env.get('dt')}") # 0.1
287
+
288
+ Error handling in context:
289
+
290
+ .. code-block:: python
291
+
292
+ >>> import brainstate.environ as env
293
+ >>>
294
+ >>> env.set(value=10)
295
+ >>> try:
296
+ ... with env.context(value=20):
297
+ ... print(env.get('value')) # 20
298
+ ... raise ValueError("Something went wrong")
299
+ ... except ValueError:
300
+ ... pass
301
+ >>>
302
+ >>> print(env.get('value')) # 10 (restored)
303
+
304
+ Notes
305
+ -----
306
+ - Platform and host_device_count cannot be set in context
307
+ - Contexts can be nested arbitrarily deep
308
+ - Settings are restored in reverse order when exiting
309
+ - Thread-safe: each thread maintains its own context stack
310
+ """
311
+ # Validate restricted parameters
312
+ if PLATFORM in kwargs:
313
+ raise ValueError(
314
+ f"Cannot set '{PLATFORM}' in context. "
315
+ f"Use set_platform() or set() for global configuration."
316
+ )
317
+ if HOST_DEVICE_COUNT in kwargs:
318
+ raise ValueError(
319
+ f"Cannot set '{HOST_DEVICE_COUNT}' in context. "
320
+ f"Use set_host_device_count() or set() for global configuration."
321
+ )
322
+
323
+ # Handle precision changes
324
+ original_precision = None
325
+ if PRECISION in kwargs:
326
+ original_precision = _get_precision()
327
+ _validate_precision(kwargs[PRECISION])
328
+ _set_jax_precision(kwargs[PRECISION])
329
+
330
+ try:
331
+ # Push new values onto context stacks
332
+ for key, value in kwargs.items():
333
+ with _ENV_STATE.locks[key]:
334
+ _ENV_STATE.contexts[key].append(value)
335
+
336
+ # Trigger registered callbacks
337
+ if key in _ENV_STATE.functions:
338
+ try:
339
+ _ENV_STATE.functions[key](value)
340
+ except Exception as e:
341
+ warnings.warn(
342
+ f"Callback for '{key}' raised an exception: {e}",
343
+ RuntimeWarning
344
+ )
345
+
346
+ # Yield current environment state
347
+ yield all()
348
+
349
+ finally:
350
+ # Restore previous values
351
+ for key in kwargs:
352
+ with _ENV_STATE.locks[key]:
353
+ if _ENV_STATE.contexts[key]:
354
+ _ENV_STATE.contexts[key].pop()
355
+
356
+ # Restore callbacks with previous value
357
+ if key in _ENV_STATE.functions:
358
+ try:
359
+ prev_value = get(key, default=None)
360
+ if prev_value is not None:
361
+ _ENV_STATE.functions[key](prev_value)
362
+ except Exception as e:
363
+ warnings.warn(
364
+ f"Callback restoration for '{key}' raised: {e}",
365
+ RuntimeWarning
366
+ )
367
+
368
+ # Restore precision if it was changed
369
+ if original_precision is not None:
370
+ _set_jax_precision(original_precision)
371
+
372
+
373
+ def get(key: str, default: Any = _NOT_PROVIDED, desc: Optional[str] = None) -> Any:
374
+ """
375
+ Get a value from the current environment.
376
+
377
+ This function retrieves values from the environment, checking first in
378
+ the context stack, then in global settings. Special handling is provided
379
+ for platform and device count parameters.
380
+
381
+ Parameters
382
+ ----------
383
+ key : str
384
+ The environment key to retrieve.
385
+ default : Any, optional
386
+ Default value to return if key is not found.
387
+ If not provided, raises KeyError for missing keys.
388
+ desc : str, optional
389
+ Description of the parameter for error messages.
390
+
391
+ Returns
392
+ -------
393
+ Any
394
+ The value associated with the key.
395
+
396
+ Raises
397
+ ------
398
+ KeyError
399
+ If key is not found and no default is provided.
400
+
401
+ Examples
402
+ --------
403
+ Basic retrieval:
404
+
405
+ .. code-block:: python
406
+
407
+ >>> import brainstate.environ as env
408
+ >>>
409
+ >>> env.set(learning_rate=0.001)
410
+ >>> lr = env.get('learning_rate')
411
+ >>> print(f"Learning rate: {lr}") # 0.001
412
+
413
+ With default value:
414
+
415
+ .. code-block:: python
416
+
417
+ >>> import brainstate.environ as env
418
+ >>>
419
+ >>> # Get with default
420
+ >>> batch_size = env.get('batch_size', default=32)
421
+ >>> print(f"Batch size: {batch_size}") # 32
422
+
423
+ Context-aware retrieval:
424
+
425
+ .. code-block:: python
426
+
427
+ >>> import brainstate.environ as env
428
+ >>>
429
+ >>> env.set(temperature=1.0)
430
+ >>> print(env.get('temperature')) # 1.0
431
+ >>>
432
+ >>> with env.context(temperature=0.5):
433
+ ... print(env.get('temperature')) # 0.5
434
+ >>>
435
+ >>> print(env.get('temperature')) # 1.0
436
+
437
+ Notes
438
+ -----
439
+ Special keys 'platform' and 'host_device_count' are handled separately
440
+ and retrieve system-level information.
441
+ """
442
+ # Special cases for platform-specific parameters
443
+ if key == PLATFORM:
444
+ return get_platform()
445
+ if key == HOST_DEVICE_COUNT:
446
+ return get_host_device_count()
447
+
448
+ # Check context stack first (most recent value)
449
+ with _ENV_STATE.locks[key]:
450
+ if key in _ENV_STATE.contexts and _ENV_STATE.contexts[key]:
451
+ return _ENV_STATE.contexts[key][-1]
452
+
453
+ # Check global settings
454
+ if key in _ENV_STATE.settings:
455
+ return _ENV_STATE.settings[key]
456
+
457
+ # Handle missing key
458
+ if default is _NOT_PROVIDED:
459
+ error_msg = f"Key '{key}' not found in environment."
460
+ if desc:
461
+ error_msg += f" Description: {desc}"
462
+ error_msg += (
463
+ f"\nSet it using:\n"
464
+ f" - env.set({key}=value) for global setting\n"
465
+ f" - env.context({key}=value) for temporary setting"
466
+ )
467
+ raise KeyError(error_msg)
468
+
469
+ return default
470
+
471
+
472
+ def all() -> Dict[str, Any]:
473
+ """
474
+ Get all current environment settings.
475
+
476
+ This function returns a dictionary containing all active environment
477
+ settings, with context values taking precedence over global settings.
478
+
479
+ Returns
480
+ -------
481
+ dict
482
+ Dictionary of all current environment settings.
483
+
484
+ Examples
485
+ --------
486
+ .. code-block:: python
487
+
488
+ >>> import brainstate.environ as env
489
+ >>>
490
+ >>> # Set various parameters
491
+ >>> env.set(precision=32, dt=0.01, debug=True)
492
+ >>>
493
+ >>> # Get all settings
494
+ >>> settings = env.all()
495
+ >>> print(settings)
496
+ {'precision': 32, 'dt': 0.01, 'debug': True}
497
+
498
+ >>> # Context overrides
499
+ >>> with env.context(precision=64, new_param='test'):
500
+ ... settings = env.all()
501
+ ... print(settings['precision']) # 64
502
+ ... print(settings['new_param']) # 'test'
503
+
504
+ Notes
505
+ -----
506
+ The returned dictionary is a snapshot and modifying it does not
507
+ affect the environment settings.
508
+ """
509
+ result = {}
510
+
511
+ # Add global settings
512
+ result.update(_ENV_STATE.settings)
513
+
514
+ # Override with context values (most recent)
515
+ for key, values in _ENV_STATE.contexts.items():
516
+ if values:
517
+ result[key] = values[-1]
518
+
519
+ return result
520
+
521
+
522
+ def pop(key: str, default: Any = _NOT_PROVIDED) -> Any:
523
+ """
524
+ Remove and return a value from the global environment.
525
+
526
+ This function removes a key from the global environment settings and
527
+ returns its value. If the key is not found, it returns the default
528
+ value if provided, or raises KeyError.
529
+
530
+ Note that this function only affects global settings, not context values.
531
+ Keys in active contexts are not affected.
532
+
533
+ Parameters
534
+ ----------
535
+ key : str
536
+ The environment key to remove.
537
+ default : Any, optional
538
+ Default value to return if key is not found.
539
+ If not provided, raises KeyError for missing keys.
540
+
541
+ Returns
542
+ -------
543
+ Any
544
+ The value that was removed from the environment.
545
+
546
+ Raises
547
+ ------
548
+ KeyError
549
+ If key is not found and no default is provided.
550
+ ValueError
551
+ If attempting to pop a key that is currently in a context.
552
+
553
+ Examples
554
+ --------
555
+ Basic usage:
556
+
557
+ .. code-block:: python
558
+
559
+ >>> import brainstate.environ as env
560
+ >>>
561
+ >>> # Set a value
562
+ >>> env.set(temp_param='temporary')
563
+ >>> print(env.get('temp_param')) # 'temporary'
564
+ >>>
565
+ >>> # Pop the value
566
+ >>> value = env.pop('temp_param')
567
+ >>> print(value) # 'temporary'
568
+ >>>
569
+ >>> # Value is now gone
570
+ >>> env.get('temp_param', default=None) # None
571
+
572
+ With default value:
573
+
574
+ .. code-block:: python
575
+
576
+ >>> import brainstate.environ as env
577
+ >>>
578
+ >>> # Pop non-existent key with default
579
+ >>> value = env.pop('nonexistent', default='default_value')
580
+ >>> print(value) # 'default_value'
581
+
582
+ Pop multiple values:
583
+
584
+ .. code-block:: python
585
+
586
+ >>> import brainstate.environ as env
587
+ >>>
588
+ >>> # Set multiple values
589
+ >>> env.set(param1='value1', param2='value2', param3='value3')
590
+ >>>
591
+ >>> # Pop them one by one
592
+ >>> v1 = env.pop('param1')
593
+ >>> v2 = env.pop('param2')
594
+ >>>
595
+ >>> # param3 still exists
596
+ >>> print(env.get('param3')) # 'value3'
597
+
598
+ Context protection:
599
+
600
+ .. code-block:: python
601
+
602
+ >>> import brainstate.environ as env
603
+ >>>
604
+ >>> env.set(protected='global_value')
605
+ >>>
606
+ >>> with env.context(protected='context_value'):
607
+ ... # Cannot pop a key that's in active context
608
+ ... try:
609
+ ... env.pop('protected')
610
+ ... except ValueError as e:
611
+ ... print("Cannot pop key in active context")
612
+
613
+ Notes
614
+ -----
615
+ - This function only removes keys from global settings
616
+ - Keys that are currently overridden in active contexts cannot be popped
617
+ - Special keys like 'platform' and 'host_device_count' can be popped but
618
+ their system-level values remain accessible through get_platform() etc.
619
+ - Registered callbacks are NOT triggered when popping values
620
+ """
621
+ # Check if key is currently in any active context
622
+ if key in _ENV_STATE.contexts and _ENV_STATE.contexts[key]:
623
+ raise ValueError(
624
+ f"Cannot pop key '{key}' while it is active in a context. "
625
+ f"The key is currently overridden in {len(_ENV_STATE.contexts[key])} context(s)."
626
+ )
627
+
628
+ # Check if key exists in global settings
629
+ if key in _ENV_STATE.settings:
630
+ # Remove and return the value
631
+ value = _ENV_STATE.settings.pop(key)
632
+
633
+ # Note: We don't trigger callbacks here as this is a removal operation
634
+ # If needed, users can register callbacks for removal separately
635
+
636
+ return value
637
+
638
+ # Key not found, handle default
639
+ if default is _NOT_PROVIDED:
640
+ raise KeyError(f"Key '{key}' not found in global environment settings.")
641
+
642
+ return default
643
+
644
+
645
+ def set(
646
+ platform: Optional[PlatformType] = None,
647
+ host_device_count: Optional[int] = None,
648
+ precision: Optional[PrecisionType] = None,
649
+ dt: Optional[float] = None,
650
+ **kwargs
651
+ ) -> None:
652
+ """
653
+ Set global environment configuration.
654
+
655
+ This function sets persistent global environment settings that remain
656
+ active until explicitly changed or the program terminates.
657
+
658
+ Parameters
659
+ ----------
660
+ platform : str, optional
661
+ Computing platform ('cpu', 'gpu', or 'tpu').
662
+ host_device_count : int, optional
663
+ Number of host devices for parallel computation.
664
+ precision : int or str, optional
665
+ Numerical precision (8, 16, 32, 64, or 'bf16').
666
+ mode : Mode, optional
667
+ Computation mode instance.
668
+ dt : float, optional
669
+ Time step for numerical integration.
670
+ **kwargs
671
+ Additional custom environment parameters.
672
+
673
+ Raises
674
+ ------
675
+ ValueError
676
+ If invalid platform or precision is specified.
677
+ TypeError
678
+ If mode is not a Mode instance.
679
+
680
+ Examples
681
+ --------
682
+ Basic configuration:
683
+
684
+ .. code-block:: python
685
+
686
+ >>> import brainstate as bs
687
+ >>> import brainstate.environ as env
688
+ >>>
689
+ >>> # Set multiple parameters
690
+ >>> env.set(
691
+ ... precision=32,
692
+ ... dt=0.01,
693
+ ... mode=bs.mixin.Training(),
694
+ ... debug=False
695
+ ... )
696
+ >>>
697
+ >>> print(env.get('precision')) # 32
698
+ >>> print(env.get('dt')) # 0.01
699
+
700
+ Platform configuration:
701
+
702
+ .. code-block:: python
703
+
704
+ >>> import brainstate.environ as env
705
+ >>>
706
+ >>> # Configure for GPU computation
707
+ >>> env.set(platform='gpu', precision=16)
708
+ >>>
709
+ >>> # Configure for multi-core CPU
710
+ >>> env.set(platform='cpu', host_device_count=4)
711
+
712
+ Custom parameters:
713
+
714
+ .. code-block:: python
715
+
716
+ >>> import brainstate.environ as env
717
+ >>>
718
+ >>> # Set custom parameters
719
+ >>> env.set(
720
+ ... experiment_name='test_001',
721
+ ... random_seed=42,
722
+ ... log_level='DEBUG'
723
+ ... )
724
+ >>>
725
+ >>> # Retrieve custom parameters
726
+ >>> print(env.get('experiment_name')) # 'test_001'
727
+
728
+ Notes
729
+ -----
730
+ - Platform changes only take effect at program start
731
+ - Some JAX configurations require restart to take effect
732
+ - Custom parameters can be any hashable key-value pairs
733
+ """
734
+ # Handle special parameters
735
+ if platform is not None:
736
+ set_platform(platform)
737
+
738
+ if host_device_count is not None:
739
+ set_host_device_count(host_device_count)
740
+
741
+ if precision is not None:
742
+ _validate_precision(precision)
743
+ _set_jax_precision(precision)
744
+ kwargs[PRECISION] = precision
745
+
746
+ if dt is not None:
747
+ if not u.math.isscalar(dt):
748
+ raise TypeError(f"'{DT}' must be a scalar number, got {type(dt)}")
749
+ kwargs[DT] = dt
750
+
751
+ # Update global settings
752
+ _ENV_STATE.settings.update(kwargs)
753
+
754
+ # Trigger registered callbacks
755
+ for key, value in kwargs.items():
756
+ if key in _ENV_STATE.functions:
757
+ try:
758
+ _ENV_STATE.functions[key](value)
759
+ except Exception as e:
760
+ warnings.warn(
761
+ f"Callback for '{key}' raised an exception: {e}",
762
+ RuntimeWarning
763
+ )
764
+
765
+
766
+ def get_dt() -> float:
767
+ """
768
+ Get the current numerical integration time step.
769
+
770
+ Returns
771
+ -------
772
+ float
773
+ The time step value.
774
+
775
+ Raises
776
+ ------
777
+ KeyError
778
+ If dt is not set.
779
+
780
+ Examples
781
+ --------
782
+ .. code-block:: python
783
+
784
+ >>> import brainstate.environ as env
785
+ >>>
786
+ >>> env.set(dt=0.01)
787
+ >>> dt = env.get_dt()
788
+ >>> print(f"Time step: {dt} ms") # Time step: 0.01 ms
789
+ >>>
790
+ >>> # Use in computation
791
+ >>> with env.context(dt=0.001):
792
+ ... fine_dt = env.get_dt()
793
+ ... print(f"Fine time step: {fine_dt}") # 0.001
794
+ """
795
+ return get(DT)
796
+
797
+
798
+ def get_platform() -> PlatformType:
799
+ """
800
+ Get the current computing platform.
801
+
802
+ Returns
803
+ -------
804
+ str
805
+ Platform name ('cpu', 'gpu', or 'tpu').
806
+
807
+ Examples
808
+ --------
809
+ .. code-block:: python
810
+
811
+ >>> import brainstate.environ as env
812
+ >>>
813
+ >>> platform = env.get_platform()
814
+ >>> print(f"Running on: {platform}")
815
+ >>>
816
+ >>> if platform == 'gpu':
817
+ ... print("GPU acceleration available")
818
+ ... else:
819
+ ... print(f"Using {platform.upper()}")
820
+ """
821
+ return devices()[0].platform
822
+
823
+
824
+ def get_host_device_count() -> int:
825
+ """
826
+ Get the number of host devices.
827
+
828
+ Returns
829
+ -------
830
+ int
831
+ Number of host devices configured.
832
+
833
+ Examples
834
+ --------
835
+ .. code-block:: python
836
+
837
+ >>> import brainstate.environ as env
838
+ >>>
839
+ >>> # Get device count
840
+ >>> n_devices = env.get_host_device_count()
841
+ >>> print(f"Host devices: {n_devices}")
842
+ >>>
843
+ >>> # Configure for parallel computation
844
+ >>> if n_devices > 1:
845
+ ... print(f"Can use {n_devices} devices for parallel computation")
846
+ """
847
+ xla_flags = os.getenv("XLA_FLAGS", "")
848
+ match = re.search(r"--xla_force_host_platform_device_count=(\d+)", xla_flags)
849
+ return int(match.group(1)) if match else 1
850
+
851
+
852
+ def set_platform(platform: PlatformType) -> None:
853
+ """
854
+ Set the computing platform.
855
+
856
+ Parameters
857
+ ----------
858
+ platform : str
859
+ Platform to use ('cpu', 'gpu', or 'tpu').
860
+
861
+ Raises
862
+ ------
863
+ ValueError
864
+ If platform is not supported.
865
+
866
+ Examples
867
+ --------
868
+ .. code-block:: python
869
+
870
+ >>> import brainstate.environ as env
871
+ >>>
872
+ >>> # Set to GPU
873
+ >>> env.set_platform('gpu')
874
+ >>>
875
+ >>> # Verify platform
876
+ >>> print(env.get_platform()) # 'gpu'
877
+
878
+ Notes
879
+ -----
880
+ Platform changes only take effect at program start. Changing platform
881
+ after JAX initialization may not have the expected effect.
882
+ """
883
+ if platform not in SUPPORTED_PLATFORMS:
884
+ raise ValueError(
885
+ f"Platform must be one of {SUPPORTED_PLATFORMS}, got '{platform}'"
886
+ )
887
+
888
+ config.update("jax_platform_name", platform)
889
+
890
+ # Trigger callbacks
891
+ if PLATFORM in _ENV_STATE.functions:
892
+ _ENV_STATE.functions[PLATFORM](platform)
893
+
894
+
895
+ def set_host_device_count(n: int) -> None:
896
+ """
897
+ Set the number of host (CPU) devices.
898
+
899
+ This function configures XLA to treat CPU cores as separate devices,
900
+ enabling parallel computation with jax.pmap on CPU.
901
+
902
+ Parameters
903
+ ----------
904
+ n : int
905
+ Number of host devices to configure.
906
+
907
+ Raises
908
+ ------
909
+ ValueError
910
+ If n is not a positive integer.
911
+
912
+ Examples
913
+ --------
914
+ .. code-block:: python
915
+
916
+ >>> import brainstate.environ as env
917
+ >>> import jax
918
+ >>>
919
+ >>> # Configure 4 CPU devices
920
+ >>> env.set_host_device_count(4)
921
+ >>>
922
+ >>> # Use with pmap
923
+ >>> def parallel_fn(x):
924
+ ... return x * 2
925
+ >>>
926
+ >>> # This will work with 4 devices
927
+ >>> pmapped_fn = jax.pmap(parallel_fn)
928
+
929
+ Warnings
930
+ --------
931
+ This setting only takes effect at program start. The effects of using
932
+ xla_force_host_platform_device_count are not fully understood and may
933
+ cause unexpected behavior.
934
+ """
935
+ if not isinstance(n, int) or n < 1:
936
+ raise ValueError(f"Host device count must be a positive integer, got {n}")
937
+
938
+ # Update XLA flags
939
+ xla_flags = os.getenv("XLA_FLAGS", "")
940
+ xla_flags = re.sub(
941
+ r"--xla_force_host_platform_device_count=\S+",
942
+ "",
943
+ xla_flags
944
+ ).split()
945
+
946
+ os.environ["XLA_FLAGS"] = " ".join(
947
+ [f"--xla_force_host_platform_device_count={n}"] + xla_flags
948
+ )
949
+
950
+ # Trigger callbacks
951
+ if HOST_DEVICE_COUNT in _ENV_STATE.functions:
952
+ _ENV_STATE.functions[HOST_DEVICE_COUNT](n)
953
+
954
+
955
+ def set_precision(precision: PrecisionType) -> None:
956
+ """
957
+ Set the global numerical precision.
958
+
959
+ Parameters
960
+ ----------
961
+ precision : int or str
962
+ Precision to use (8, 16, 32, 64, or 'bf16').
963
+
964
+ Raises
965
+ ------
966
+ ValueError
967
+ If precision is not supported.
968
+
969
+ Examples
970
+ --------
971
+ .. code-block:: python
972
+
973
+ >>> import brainstate.environ as env
974
+ >>> import jax.numpy as jnp
975
+ >>>
976
+ >>> # Set to 64-bit precision
977
+ >>> env.set_precision(64)
978
+ >>>
979
+ >>> # Arrays will use float64 by default
980
+ >>> x = jnp.array([1.0, 2.0, 3.0])
981
+ >>> print(x.dtype) # float64
982
+ >>>
983
+ >>> # Set to bfloat16 for efficiency
984
+ >>> env.set_precision('bf16')
985
+ """
986
+ _validate_precision(precision)
987
+ _set_jax_precision(precision)
988
+ _ENV_STATE.settings[PRECISION] = precision
989
+
990
+ # Trigger callbacks
991
+ if PRECISION in _ENV_STATE.functions:
992
+ _ENV_STATE.functions[PRECISION](precision)
993
+
994
+
995
+ def get_precision() -> int:
996
+ """
997
+ Get the current numerical precision as an integer.
998
+
999
+ Returns
1000
+ -------
1001
+ int
1002
+ Precision in bits (8, 16, 32, or 64).
1003
+
1004
+ Examples
1005
+ --------
1006
+ .. code-block:: python
1007
+
1008
+ >>> import brainstate.environ as env
1009
+ >>>
1010
+ >>> env.set_precision(32)
1011
+ >>> bits = env.get_precision()
1012
+ >>> print(f"Using {bits}-bit precision") # Using 32-bit precision
1013
+ >>>
1014
+ >>> # Special handling for bfloat16
1015
+ >>> env.set_precision('bf16')
1016
+ >>> print(env.get_precision()) # 16
1017
+
1018
+ Notes
1019
+ -----
1020
+ 'bf16' (bfloat16) is reported as 16-bit precision.
1021
+ """
1022
+ precision = get(PRECISION, default=DEFAULT_PRECISION)
1023
+
1024
+ if precision == 'bf16':
1025
+ return 16
1026
+ elif isinstance(precision, str):
1027
+ return int(precision)
1028
+ elif isinstance(precision, int):
1029
+ return precision
1030
+ else:
1031
+ raise ValueError(f"Invalid precision type: {type(precision)}")
1032
+
1033
+
1034
+ def _validate_precision(precision: PrecisionType) -> None:
+     """Validate precision value."""
+     if precision not in SUPPORTED_PRECISIONS and str(precision) not in map(str, SUPPORTED_PRECISIONS):
+         raise ValueError(
+             f"Precision must be one of {SUPPORTED_PRECISIONS}, got {precision}"
+         )
+
+
+ def _get_precision() -> PrecisionType:
+     """Get raw precision value (including 'bf16')."""
+     return get(PRECISION, default=DEFAULT_PRECISION)
+
+
+ def _set_jax_precision(precision: PrecisionType) -> None:
+     """Configure JAX precision settings."""
+     # Enable/disable 64-bit mode
+     if precision in (64, '64'):
+         config.update("jax_enable_x64", True)
+     else:
+         config.update("jax_enable_x64", False)
+
+
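For reference, toggling ``jax_enable_x64`` is the only JAX-level change this helper makes; its observable effect on default dtypes is sketched below with plain JAX calls (assuming a fresh process):

.. code-block:: python

    import jax
    import jax.numpy as jnp

    jax.config.update("jax_enable_x64", False)
    print(jnp.asarray(1.0).dtype)  # float32: Python floats default to 32-bit

    jax.config.update("jax_enable_x64", True)
    print(jnp.asarray(1.0).dtype)  # float64 once 64-bit mode is enabled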
+ @functools.lru_cache(maxsize=16)
+ def _get_uint(precision: PrecisionType) -> DTypeLike:
+     """Get unsigned integer type for given precision."""
+     if precision in (64, '64'):
+         return np.uint64
+     elif precision in (32, '32'):
+         return np.uint32
+     elif precision in (16, '16', 'bf16'):
+         return np.uint16
+     elif precision in (8, '8'):
+         return np.uint8
+     else:
+         raise ValueError(f"Unsupported precision: {precision}")
+
+
+ @functools.lru_cache(maxsize=16)
+ def _get_int(precision: PrecisionType) -> DTypeLike:
+     """Get integer type for given precision."""
+     if precision in (64, '64'):
+         return np.int64
+     elif precision in (32, '32'):
+         return np.int32
+     elif precision in (16, '16', 'bf16'):
+         return np.int16
+     elif precision in (8, '8'):
+         return np.int8
+     else:
+         raise ValueError(f"Unsupported precision: {precision}")
+
+
+ @functools.lru_cache(maxsize=16)
+ def _get_float(precision: PrecisionType) -> DTypeLike:
+     """Get floating-point type for given precision."""
+     if precision in (64, '64'):
+         return np.float64
+     elif precision in (32, '32'):
+         return np.float32
+     elif precision in (16, '16'):
+         return np.float16
+     elif precision == 'bf16':
+         return jnp.bfloat16
+     elif precision in (8, '8'):
+         return jnp.float8_e5m2
+     else:
+         raise ValueError(f"Unsupported precision: {precision}")
+
+
+ @functools.lru_cache(maxsize=16)
+ def _get_complex(precision: PrecisionType) -> DTypeLike:
+     """Get complex type for given precision."""
+     if precision in (64, '64'):
+         return np.complex128
+     elif precision in (32, '32', 16, '16', 'bf16', 8, '8'):
+         return np.complex64
+     else:
+         raise ValueError(f"Unsupported precision: {precision}")
+
+
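Taken together, these helpers define the precision-to-dtype mapping below. The dictionary is a restatement for quick reference only; ``PRECISION_TO_DTYPES`` is not a name exported by the module:

.. code-block:: python

    # (float, int, uint, complex) per precision, derived from the branches above.
    PRECISION_TO_DTYPES = {
        64:     ('float64',     'int64', 'uint64', 'complex128'),
        32:     ('float32',     'int32', 'uint32', 'complex64'),
        16:     ('float16',     'int16', 'uint16', 'complex64'),
        'bf16': ('bfloat16',    'int16', 'uint16', 'complex64'),
        8:      ('float8_e5m2', 'int8',  'uint8',  'complex64'),
    }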
+ def dftype() -> DTypeLike:
+     """
+     Get the default floating-point data type.
+
+     This function returns the appropriate floating-point type based on
+     the current precision setting, allowing dynamic type selection.
+
+     Returns
+     -------
+     DTypeLike
+         Default floating-point data type.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate.environ as env
+         >>> import jax.numpy as jnp
+         >>>
+         >>> # With 32-bit precision
+         >>> env.set(precision=32)
+         >>> x = jnp.zeros(10, dtype=env.dftype())
+         >>> print(x.dtype)  # float32
+         >>>
+         >>> # With 64-bit precision
+         >>> with env.context(precision=64):
+         ...     y = jnp.ones(5, dtype=env.dftype())
+         ...     print(y.dtype)  # float64
+         >>>
+         >>> # With bfloat16
+         >>> env.set(precision='bf16')
+         >>> z = jnp.array([1, 2, 3], dtype=env.dftype())
+         >>> print(z.dtype)  # bfloat16
+
+     See Also
+     --------
+     ditype : Default integer type
+     dutype : Default unsigned integer type
+     dctype : Default complex type
+     """
+     return _get_float(_get_precision())
+
+
+ def ditype() -> DTypeLike:
+     """
+     Get the default integer data type.
+
+     This function returns the appropriate integer type based on
+     the current precision setting.
+
+     Returns
+     -------
+     DTypeLike
+         Default integer data type.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate.environ as env
+         >>> import jax.numpy as jnp
+         >>>
+         >>> # With 32-bit precision
+         >>> env.set(precision=32)
+         >>> indices = jnp.arange(10, dtype=env.ditype())
+         >>> print(indices.dtype)  # int32
+         >>>
+         >>> # With 64-bit precision
+         >>> with env.context(precision=64):
+         ...     big_indices = jnp.arange(1000, dtype=env.ditype())
+         ...     print(big_indices.dtype)  # int64
+
+     See Also
+     --------
+     dftype : Default floating-point type
+     dutype : Default unsigned integer type
+     """
+     return _get_int(_get_precision())
+
+
+ def dutype() -> DTypeLike:
+     """
+     Get the default unsigned integer data type.
+
+     This function returns the appropriate unsigned integer type based on
+     the current precision setting.
+
+     Returns
+     -------
+     DTypeLike
+         Default unsigned integer data type.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate.environ as env
+         >>> import jax.numpy as jnp
+         >>>
+         >>> # With 32-bit precision
+         >>> env.set(precision=32)
+         >>> counts = jnp.array([10, 20, 30], dtype=env.dutype())
+         >>> print(counts.dtype)  # uint32
+         >>>
+         >>> # With 16-bit precision
+         >>> with env.context(precision=16):
+         ...     small_counts = jnp.array([1, 2, 3], dtype=env.dutype())
+         ...     print(small_counts.dtype)  # uint16
+
+     See Also
+     --------
+     ditype : Default signed integer type
+     """
+     return _get_uint(_get_precision())
+
+
+ def dctype() -> DTypeLike:
+     """
+     Get the default complex data type.
+
+     This function returns the appropriate complex type based on
+     the current precision setting.
+
+     Returns
+     -------
+     DTypeLike
+         Default complex data type.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate.environ as env
+         >>> import jax.numpy as jnp
+         >>>
+         >>> # With 32-bit precision
+         >>> env.set(precision=32)
+         >>> z = jnp.array([1+2j, 3+4j], dtype=env.dctype())
+         >>> print(z.dtype)  # complex64
+         >>>
+         >>> # With 64-bit precision
+         >>> with env.context(precision=64):
+         ...     w = jnp.array([5+6j], dtype=env.dctype())
+         ...     print(w.dtype)  # complex128
+
+     Notes
+     -----
+     Complex128 is only available with 64-bit precision.
+     All other precisions use complex64.
+     """
+     return _get_complex(_get_precision())
+
+
+ def tolerance() -> jnp.ndarray:
+     """
+     Get numerical tolerance based on current precision.
+
+     This function returns an appropriate tolerance value for numerical
+     comparisons based on the current precision setting.
+
+     Returns
+     -------
+     jnp.ndarray
+         Tolerance value as a scalar array.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate.environ as env
+         >>> import jax.numpy as jnp
+         >>>
+         >>> # Different tolerances for different precisions
+         >>> env.set(precision=64)
+         >>> tol64 = env.tolerance()
+         >>> print(f"64-bit tolerance: {tol64}")  # 1e-12
+         >>>
+         >>> env.set(precision=32)
+         >>> tol32 = env.tolerance()
+         >>> print(f"32-bit tolerance: {tol32}")  # 1e-5
+         >>>
+         >>> # Use in numerical comparisons
+         >>> def are_close(a, b):
+         ...     return jnp.abs(a - b) < env.tolerance()
+
+     Notes
+     -----
+     Tolerance values:
+
+     - 64-bit: 1e-12
+     - 32-bit: 1e-5
+     - 16-bit and below: 1e-2
+     """
+     precision = get_precision()
+
+     if precision == 64:
+         return jnp.array(1e-12, dtype=np.float64)
+     elif precision == 32:
+         return jnp.array(1e-5, dtype=np.float32)
+     else:
+         return jnp.array(1e-2, dtype=np.float16)
+
+
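As a usage sketch, the precision-aware tolerance pairs naturally with ``jnp.allclose``; the ``assert_allclose`` helper and the choice of passing the value as ``atol`` are illustrative, not part of the module:

.. code-block:: python

    import brainstate.environ as env
    import jax.numpy as jnp

    def assert_allclose(a, b):
        # Use the environment's precision-dependent tolerance as the absolute tolerance.
        tol = env.tolerance()
        assert jnp.allclose(a, b, atol=tol), f"arrays differ by more than {tol}"

    with env.context(precision=32):
        assert_allclose(jnp.array([1.0, 2.0]), jnp.array([1.0, 2.0 + 1e-7]))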
+ def register_default_behavior(
+     key: str,
+     behavior: Callable[[Any], None],
+     replace_if_exist: bool = False
+ ) -> None:
+     """
+     Register a callback for environment parameter changes.
+
+     This function allows you to register custom behaviors that are
+     triggered whenever a specific environment parameter is modified.
+
+     Parameters
+     ----------
+     key : str
+         Environment parameter key to monitor.
+     behavior : Callable[[Any], None]
+         Callback function that receives the new value.
+     replace_if_exist : bool, default=False
+         Whether to replace existing callback for this key.
+
+     Raises
+     ------
+     TypeError
+         If behavior is not callable.
+     ValueError
+         If key already has a registered behavior and replace_if_exist is False.
+
+     Examples
+     --------
+     Basic callback registration:
+
+     .. code-block:: python
+
+         >>> import brainstate.environ as env
+         >>>
+         >>> # Define a callback
+         >>> def on_dt_change(new_dt):
+         ...     print(f"Time step changed to: {new_dt}")
+         >>>
+         >>> # Register the callback
+         >>> env.register_default_behavior('dt', on_dt_change)
+         >>>
+         >>> # Callback is triggered on changes
+         >>> env.set(dt=0.01)  # Prints: Time step changed to: 0.01
+         >>>
+         >>> with env.context(dt=0.001):  # Prints: Time step changed to: 0.001
+         ...     pass  # Prints: Time step changed to: 0.01 (on exit)
+
+     Complex behavior with validation:
+
+     .. code-block:: python
+
+         >>> import brainstate.environ as env
+         >>>
+         >>> def validate_batch_size(size):
+         ...     if not isinstance(size, int) or size <= 0:
+         ...         raise ValueError(f"Invalid batch size: {size}")
+         ...     if size > 1024:
+         ...         print(f"Warning: Large batch size {size} may cause OOM")
+         >>>
+         >>> env.register_default_behavior('batch_size', validate_batch_size)
+         >>>
+         >>> # Valid setting
+         >>> env.set(batch_size=32)  # OK
+         >>>
+         >>> # Invalid setting
+         >>> # env.set(batch_size=-1)  # Raises ValueError
+
+     Replacing existing behavior:
+
+     .. code-block:: python
+
+         >>> import brainstate.environ as env
+         >>>
+         >>> def old_behavior(value):
+         ...     print(f"Old: {value}")
+         >>>
+         >>> def new_behavior(value):
+         ...     print(f"New: {value}")
+         >>>
+         >>> env.register_default_behavior('key', old_behavior)
+         >>> env.register_default_behavior('key', new_behavior, replace_if_exist=True)
+         >>>
+         >>> env.set(key='test')  # Prints: New: test
+
+     See Also
+     --------
+     unregister_default_behavior : Remove registered callbacks
+     list_registered_behaviors : List all registered callbacks
+     """
+     if not isinstance(key, str):
+         raise TypeError(f"Key must be a string, got {type(key)}")
+
+     if not callable(behavior):
+         raise TypeError(f"Behavior must be callable, got {type(behavior)}")
+
+     if key in _ENV_STATE.functions and not replace_if_exist:
+         raise ValueError(
+             f"Behavior for key '{key}' already registered. "
+             f"Use replace_if_exist=True to override."
+         )
+
+     _ENV_STATE.functions[key] = behavior
+
+
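A common pattern with this hook is keeping a derived quantity in sync with a setting. The sketch below assumes a ``'dt'`` key and uses illustrative names (``T_TOTAL``, ``recompute_steps``, ``derived``); only the register/set calls come from the module:

.. code-block:: python

    import brainstate.environ as env

    T_TOTAL = 100.0                 # total simulated time (illustrative)
    derived = {'num_steps': None}   # holds the value kept in sync with 'dt'

    def recompute_steps(new_dt):
        # Recompute the step count whenever the time step changes.
        derived['num_steps'] = int(round(T_TOTAL / new_dt))

    env.register_default_behavior('dt', recompute_steps, replace_if_exist=True)
    env.set(dt=0.1)
    print(derived['num_steps'])  # 1000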
+ def unregister_default_behavior(key: str) -> bool:
+     """
+     Remove a registered callback for an environment parameter.
+
+     Parameters
+     ----------
+     key : str
+         Environment parameter key.
+
+     Returns
+     -------
+     bool
+         True if a callback was removed, False if no callback existed.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate.environ as env
+         >>>
+         >>> # Register a callback
+         >>> def callback(value):
+         ...     print(f"Value: {value}")
+         >>>
+         >>> env.register_default_behavior('param', callback)
+         >>>
+         >>> # Remove the callback
+         >>> removed = env.unregister_default_behavior('param')
+         >>> print(f"Callback removed: {removed}")  # True
+         >>>
+         >>> # No callback triggers now
+         >>> env.set(param='test')  # No output
+         >>>
+         >>> # Removing non-existent callback
+         >>> removed = env.unregister_default_behavior('nonexistent')
+         >>> print(f"Callback removed: {removed}")  # False
+     """
+     if key in _ENV_STATE.functions:
+         del _ENV_STATE.functions[key]
+         return True
+     return False
+
+
+ def list_registered_behaviors() -> List[str]:
+     """
+     List all keys with registered callbacks.
+
+     Returns
+     -------
+     list of str
+         Keys that have registered behavior callbacks.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate.environ as env
+         >>>
+         >>> # Register some callbacks
+         >>> env.register_default_behavior('param1', lambda x: None)
+         >>> env.register_default_behavior('param2', lambda x: None)
+         >>>
+         >>> # List registered behaviors
+         >>> behaviors = env.list_registered_behaviors()
+         >>> print(f"Registered: {behaviors}")  # ['param1', 'param2']
+         >>>
+         >>> # Check if specific behavior is registered
+         >>> if 'dt' in behaviors:
+         ...     print("dt has a registered callback")
+     """
+     return list(_ENV_STATE.functions.keys())
+
+
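For completeness, the listing and unregistration helpers combine into a simple cleanup pattern (e.g. in a test teardown); the loop itself is a sketch, not an API of the module:

.. code-block:: python

    import brainstate.environ as env

    # Remove every registered behavior, e.g. to isolate test cases.
    for key in env.list_registered_behaviors():
        env.unregister_default_behavior(key)

    assert env.list_registered_behaviors() == []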
+ # Initialize default precision on module load
+ set(precision=DEFAULT_PRECISION)