brainstate 0.1.8__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133) hide show
  1. brainstate/__init__.py +58 -51
  2. brainstate/_compatible_import.py +148 -148
  3. brainstate/_state.py +1605 -1663
  4. brainstate/_state_test.py +52 -52
  5. brainstate/_utils.py +47 -47
  6. brainstate/augment/__init__.py +30 -30
  7. brainstate/augment/_autograd.py +778 -778
  8. brainstate/augment/_autograd_test.py +1289 -1289
  9. brainstate/augment/_eval_shape.py +99 -99
  10. brainstate/augment/_eval_shape_test.py +38 -38
  11. brainstate/augment/_mapping.py +1060 -1060
  12. brainstate/augment/_mapping_test.py +597 -597
  13. brainstate/augment/_random.py +151 -151
  14. brainstate/compile/__init__.py +38 -38
  15. brainstate/compile/_ad_checkpoint.py +204 -204
  16. brainstate/compile/_ad_checkpoint_test.py +49 -49
  17. brainstate/compile/_conditions.py +256 -256
  18. brainstate/compile/_conditions_test.py +220 -220
  19. brainstate/compile/_error_if.py +92 -92
  20. brainstate/compile/_error_if_test.py +52 -52
  21. brainstate/compile/_jit.py +346 -346
  22. brainstate/compile/_jit_test.py +143 -143
  23. brainstate/compile/_loop_collect_return.py +536 -536
  24. brainstate/compile/_loop_collect_return_test.py +58 -58
  25. brainstate/compile/_loop_no_collection.py +184 -184
  26. brainstate/compile/_loop_no_collection_test.py +50 -50
  27. brainstate/compile/_make_jaxpr.py +888 -888
  28. brainstate/compile/_make_jaxpr_test.py +156 -156
  29. brainstate/compile/_progress_bar.py +202 -202
  30. brainstate/compile/_unvmap.py +159 -159
  31. brainstate/compile/_util.py +147 -147
  32. brainstate/environ.py +563 -563
  33. brainstate/environ_test.py +62 -62
  34. brainstate/functional/__init__.py +27 -26
  35. brainstate/graph/__init__.py +29 -29
  36. brainstate/graph/_graph_node.py +244 -244
  37. brainstate/graph/_graph_node_test.py +73 -73
  38. brainstate/graph/_graph_operation.py +1738 -1738
  39. brainstate/graph/_graph_operation_test.py +563 -563
  40. brainstate/init/__init__.py +26 -26
  41. brainstate/init/_base.py +52 -52
  42. brainstate/init/_generic.py +244 -244
  43. brainstate/init/_random_inits.py +553 -553
  44. brainstate/init/_random_inits_test.py +149 -149
  45. brainstate/init/_regular_inits.py +105 -105
  46. brainstate/init/_regular_inits_test.py +50 -50
  47. brainstate/mixin.py +365 -363
  48. brainstate/mixin_test.py +77 -73
  49. brainstate/nn/__init__.py +135 -131
  50. brainstate/{functional → nn}/_activations.py +808 -813
  51. brainstate/{functional → nn}/_activations_test.py +331 -331
  52. brainstate/nn/_collective_ops.py +514 -514
  53. brainstate/nn/_collective_ops_test.py +43 -43
  54. brainstate/nn/_common.py +178 -178
  55. brainstate/nn/_conv.py +501 -501
  56. brainstate/nn/_conv_test.py +238 -238
  57. brainstate/nn/_delay.py +509 -502
  58. brainstate/nn/_delay_test.py +238 -184
  59. brainstate/nn/_dropout.py +426 -426
  60. brainstate/nn/_dropout_test.py +100 -100
  61. brainstate/nn/_dynamics.py +1343 -1343
  62. brainstate/nn/_dynamics_test.py +78 -78
  63. brainstate/nn/_elementwise.py +1119 -1119
  64. brainstate/nn/_elementwise_test.py +169 -169
  65. brainstate/nn/_embedding.py +58 -58
  66. brainstate/nn/_exp_euler.py +92 -92
  67. brainstate/nn/_exp_euler_test.py +35 -35
  68. brainstate/nn/_fixedprob.py +239 -239
  69. brainstate/nn/_fixedprob_test.py +114 -114
  70. brainstate/nn/_inputs.py +608 -608
  71. brainstate/nn/_linear.py +424 -424
  72. brainstate/nn/_linear_mv.py +83 -83
  73. brainstate/nn/_linear_mv_test.py +120 -120
  74. brainstate/nn/_linear_test.py +107 -107
  75. brainstate/nn/_ltp.py +28 -28
  76. brainstate/nn/_module.py +377 -377
  77. brainstate/nn/_module_test.py +40 -40
  78. brainstate/nn/_neuron.py +705 -705
  79. brainstate/nn/_neuron_test.py +161 -161
  80. brainstate/nn/_normalizations.py +975 -918
  81. brainstate/nn/_normalizations_test.py +73 -73
  82. brainstate/{functional → nn}/_others.py +46 -46
  83. brainstate/nn/_poolings.py +1177 -1177
  84. brainstate/nn/_poolings_test.py +217 -217
  85. brainstate/nn/_projection.py +486 -486
  86. brainstate/nn/_rate_rnns.py +554 -554
  87. brainstate/nn/_rate_rnns_test.py +63 -63
  88. brainstate/nn/_readout.py +209 -209
  89. brainstate/nn/_readout_test.py +53 -53
  90. brainstate/nn/_stp.py +236 -236
  91. brainstate/nn/_synapse.py +505 -505
  92. brainstate/nn/_synapse_test.py +131 -131
  93. brainstate/nn/_synaptic_projection.py +423 -423
  94. brainstate/nn/_synouts.py +162 -162
  95. brainstate/nn/_synouts_test.py +57 -57
  96. brainstate/nn/_utils.py +89 -89
  97. brainstate/nn/metrics.py +388 -388
  98. brainstate/optim/__init__.py +38 -38
  99. brainstate/optim/_base.py +64 -64
  100. brainstate/optim/_lr_scheduler.py +448 -448
  101. brainstate/optim/_lr_scheduler_test.py +50 -50
  102. brainstate/optim/_optax_optimizer.py +152 -152
  103. brainstate/optim/_optax_optimizer_test.py +53 -53
  104. brainstate/optim/_sgd_optimizer.py +1104 -1104
  105. brainstate/random/__init__.py +24 -24
  106. brainstate/random/_rand_funs.py +3616 -3616
  107. brainstate/random/_rand_funs_test.py +567 -567
  108. brainstate/random/_rand_seed.py +210 -210
  109. brainstate/random/_rand_seed_test.py +48 -48
  110. brainstate/random/_rand_state.py +1409 -1409
  111. brainstate/random/_random_for_unit.py +52 -52
  112. brainstate/surrogate.py +1957 -1957
  113. brainstate/transform.py +23 -23
  114. brainstate/typing.py +304 -304
  115. brainstate/util/__init__.py +50 -50
  116. brainstate/util/caller.py +98 -98
  117. brainstate/util/error.py +55 -55
  118. brainstate/util/filter.py +469 -469
  119. brainstate/util/others.py +540 -540
  120. brainstate/util/pretty_pytree.py +945 -945
  121. brainstate/util/pretty_pytree_test.py +159 -159
  122. brainstate/util/pretty_repr.py +328 -328
  123. brainstate/util/pretty_table.py +2954 -2954
  124. brainstate/util/scaling.py +258 -258
  125. brainstate/util/struct.py +523 -523
  126. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
  127. brainstate-0.1.9.dist-info/RECORD +130 -0
  128. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
  129. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
  130. brainstate/functional/_normalization.py +0 -81
  131. brainstate/functional/_spikes.py +0 -204
  132. brainstate-0.1.8.dist-info/RECORD +0 -132
  133. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
brainstate/environ.py CHANGED
@@ -1,563 +1,563 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- import contextlib
19
- import dataclasses
20
- import functools
21
- import os
22
- import re
23
- import threading
24
- from collections import defaultdict
25
- from typing import Any, Callable, Dict, Hashable
26
-
27
- import numpy as np
28
- from jax import config, devices, numpy as jnp
29
- from jax.typing import DTypeLike
30
-
31
- from .mixin import Mode
32
-
33
- __all__ = [
34
- # functions for environment settings
35
- 'set', 'context', 'get', 'all', 'set_host_device_count', 'set_platform',
36
- # functions for getting default behaviors
37
- 'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_precision',
38
- # functions for default data types
39
- 'dftype', 'ditype', 'dutype', 'dctype',
40
- # others
41
- 'tolerance', 'register_default_behavior',
42
- ]
43
-
44
- # Default, there are several shared arguments in the global context.
45
- I = 'i' # the index of the current computation.
46
- T = 't' # the current time of the current computation.
47
- JIT_ERROR_CHECK = 'jit_error_check' # whether to record the current computation.
48
- FIT = 'fit' # whether to fit the model.
49
-
50
-
51
- @dataclasses.dataclass
52
- class DefaultContext(threading.local):
53
- # default environment settings
54
- settings: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
55
- # current environment settings
56
- contexts: defaultdict[Hashable, Any] = dataclasses.field(default_factory=lambda: defaultdict(list))
57
- # environment functions
58
- functions: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
59
-
60
-
61
- DFAULT = DefaultContext()
62
- _NOT_PROVIDE = object()
63
-
64
-
65
- @contextlib.contextmanager
66
- def context(**kwargs):
67
- r"""
68
- Context-manager that sets a computing environment for brain dynamics computation.
69
-
70
- In BrainPy, there are several basic computation settings when constructing models,
71
- including ``mode`` for controlling model computing behavior, ``dt`` for numerical
72
- integration, ``int_`` for integer precision, and ``float_`` for floating precision.
73
- :py:class:`~.environment`` provides a context for model construction and
74
- computation. In this temporal environment, models are constructed with the given
75
- ``mode``, ``dt``, ``int_``, etc., environment settings.
76
-
77
- For instance::
78
-
79
- >>> import brainstate as brainstate
80
- >>> with brainstate.environ.context(dt=0.1) as env:
81
- ... dt = brainstate.environ.get('dt')
82
- ... print(env)
83
-
84
- """
85
- if 'platform' in kwargs:
86
- raise ValueError('\n'
87
- 'Cannot set platform in "context" environment. \n'
88
- 'You should set platform in the global environment by "set_platform()" or "set()".')
89
- if 'host_device_count' in kwargs:
90
- raise ValueError('Cannot set host_device_count in environment context. '
91
- 'Please use set_host_device_count() or set() for the global setting.')
92
-
93
- if 'precision' in kwargs:
94
- last_precision = _get_precision()
95
- _set_jax_precision(kwargs['precision'])
96
-
97
- try:
98
- for k, v in kwargs.items():
99
-
100
- # update the current environment
101
- DFAULT.contexts[k].append(v)
102
-
103
- # restore the environment functions
104
- if k in DFAULT.functions:
105
- DFAULT.functions[k](v)
106
-
107
- # yield the current all environment information
108
- yield all()
109
- finally:
110
-
111
- for k, v in kwargs.items():
112
-
113
- # restore the current environment
114
- DFAULT.contexts[k].pop()
115
-
116
- # restore the environment functions
117
- if k in DFAULT.functions:
118
- DFAULT.functions[k](get(k))
119
-
120
- if 'precision' in kwargs:
121
- _set_jax_precision(last_precision)
122
-
123
-
124
- def get(key: str, default: Any = _NOT_PROVIDE, desc: str = None):
125
- """
126
- Get one of the default computation environment.
127
-
128
- Returns
129
- -------
130
- item: Any
131
- The default computation environment.
132
- """
133
- if key == 'platform':
134
- return get_platform()
135
-
136
- if key == 'host_device_count':
137
- return get_host_device_count()
138
-
139
- if key in DFAULT.contexts:
140
- if len(DFAULT.contexts[key]) > 0:
141
- return DFAULT.contexts[key][-1]
142
- if key in DFAULT.settings:
143
- return DFAULT.settings[key]
144
-
145
- if default is _NOT_PROVIDE:
146
- if desc is not None:
147
- raise KeyError(
148
- f"'{key}' is not found in the context. \n"
149
- f"You can set it by `brainstate.share.context({key}=value)` "
150
- f"locally or `brainstate.share.set({key}=value)` globally. \n"
151
- f"Description: {desc}"
152
- )
153
- else:
154
- raise KeyError(
155
- f"'{key}' is not found in the context. \n"
156
- f"You can set it by `brainstate.share.context({key}=value)` "
157
- f"locally or `brainstate.share.set({key}=value)` globally."
158
- )
159
- return default
160
-
161
-
162
- def all() -> dict:
163
- """
164
- Get all the current default computation environment.
165
-
166
- Returns
167
- -------
168
- r: dict
169
- The current default computation environment.
170
- """
171
- r = dict()
172
- for k, v in DFAULT.contexts.items():
173
- if v:
174
- r[k] = v[-1]
175
- for k, v in DFAULT.settings.items():
176
- if k not in r:
177
- r[k] = v
178
- return r
179
-
180
-
181
- def get_dt():
182
- """Get the numerical integrator precision.
183
-
184
- Returns
185
- -------
186
- dt : float
187
- Numerical integration precision.
188
- """
189
- return get('dt')
190
-
191
-
192
- def get_mode() -> Mode:
193
- """Get the default computing mode.
194
-
195
- References
196
- ----------
197
- mode: Mode
198
- The default computing mode.
199
- """
200
- return get('mode')
201
-
202
-
203
- def get_platform() -> str:
204
- """Get the computing platform.
205
-
206
- Returns
207
- -------
208
- platform: str
209
- Either 'cpu', 'gpu' or 'tpu'.
210
- """
211
- return devices()[0].platform
212
-
213
-
214
- def get_host_device_count():
215
- """
216
- Get the number of host devices.
217
-
218
- Returns
219
- -------
220
- n: int
221
- The number of host devices.
222
- """
223
- xla_flags = os.getenv("XLA_FLAGS", "")
224
- match = re.search(r"--xla_force_host_platform_device_count=(\d+)", xla_flags)
225
- return int(match.group(1)) if match else 1
226
-
227
-
228
- def _get_precision() -> int | str:
229
- """
230
- Get the default precision.
231
-
232
- Returns
233
- -------
234
- precision: int
235
- The default precision.
236
- """
237
- return get('precision')
238
-
239
-
240
- def get_precision() -> int:
241
- """
242
- Get the default precision.
243
-
244
- Returns
245
- -------
246
- precision: int
247
- The default precision.
248
- """
249
- precision = get('precision')
250
- if precision == 'bf16':
251
- return 16
252
- if isinstance(precision, int):
253
- return precision
254
- if isinstance(precision, str):
255
- return int(precision)
256
- raise ValueError(f'Unsupported precision: {precision}')
257
-
258
-
259
- def set(
260
- platform: str = None,
261
- host_device_count: int = None,
262
- precision: int | str = None,
263
- mode: Mode = None,
264
- **kwargs
265
- ):
266
- """
267
- Set the global default computation environment.
268
-
269
-
270
-
271
- Args:
272
- platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
273
- host_device_count: int. The number of host devices.
274
- precision: int, str. The default precision.
275
- mode: Mode. The computing mode.
276
- **kwargs: dict. Other environment settings.
277
- """
278
- if platform is not None:
279
- set_platform(platform)
280
- if host_device_count is not None:
281
- set_host_device_count(host_device_count)
282
- if precision is not None:
283
- _set_jax_precision(precision)
284
- kwargs['precision'] = precision
285
- if mode is not None:
286
- assert isinstance(mode, Mode), 'mode must be a Mode instance.'
287
- kwargs['mode'] = mode
288
-
289
- # set default environment
290
- DFAULT.settings.update(kwargs)
291
-
292
- # update the environment functions
293
- for k, v in kwargs.items():
294
- if k in DFAULT.functions:
295
- DFAULT.functions[k](v)
296
-
297
-
298
- def set_host_device_count(n):
299
- """
300
- By default, XLA considers all CPU cores as one device. This utility tells XLA
301
- that there are `n` host (CPU) devices available to use. As a consequence, this
302
- allows parallel mapping in JAX :func:`jax.pmap` to work in CPU platform.
303
-
304
- .. note:: This utility only takes effect at the beginning of your program.
305
- Under the hood, this sets the environment variable
306
- `XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
307
- `[num_device]` is the desired number of CPU devices `n`.
308
-
309
- .. warning:: Our understanding of the side effects of using the
310
- `xla_force_host_platform_device_count` flag in XLA is incomplete. If you
311
- observe some strange phenomenon when using this utility, please let us
312
- know through our issue or forum page. More information is available in this
313
- `JAX issue <https://github.com/google/jax/issues/1408>`_.
314
-
315
- :param int n: number of devices to use.
316
- """
317
- xla_flags = os.getenv("XLA_FLAGS", "")
318
- xla_flags = re.sub(r"--xla_force_host_platform_device_count=\S+", "", xla_flags).split()
319
- os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(n)] + xla_flags)
320
-
321
- # update the environment functions
322
- if 'host_device_count' in DFAULT.functions:
323
- DFAULT.functions['host_device_count'](n)
324
-
325
-
326
- def set_platform(platform: str):
327
- """
328
- Changes platform to CPU, GPU, or TPU. This utility only takes
329
- effect at the beginning of your program.
330
-
331
- Args:
332
- platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
333
-
334
- Raises:
335
- ValueError: If the platform is not in ['cpu', 'gpu', 'tpu'].
336
- """
337
- assert platform in ['cpu', 'gpu', 'tpu']
338
- config.update("jax_platform_name", platform)
339
-
340
- # update the environment functions
341
- if 'platform' in DFAULT.functions:
342
- DFAULT.functions['platform'](platform)
343
-
344
-
345
- def _set_jax_precision(precision: int | str):
346
- """
347
- Set the default precision.
348
-
349
- Args:
350
- precision: int. The default precision.
351
- """
352
- # assert precision in [64, 32, 16, 'bf16', 8], f'Precision must be in [64, 32, 16, "bf16", 8]. But got {precision}.'
353
- if precision in [64, '64']:
354
- config.update("jax_enable_x64", True)
355
- else:
356
- config.update("jax_enable_x64", False)
357
-
358
-
359
- @functools.lru_cache()
360
- def _get_uint(precision: int):
361
- if precision in [64, '64']:
362
- return np.uint64
363
- elif precision in [32, '32']:
364
- return np.uint32
365
- elif precision in [16, '16', 'bf16']:
366
- return np.uint16
367
- elif precision in [8, '8']:
368
- return np.uint8
369
- else:
370
- raise ValueError(f'Unsupported precision: {precision}')
371
-
372
-
373
- @functools.lru_cache()
374
- def _get_int(precision: int):
375
- if precision in [64, '64']:
376
- return np.int64
377
- elif precision in [32, '32']:
378
- return np.int32
379
- elif precision in [16, '16', 'bf16']:
380
- return np.int16
381
- elif precision in [8, '8']:
382
- return np.int8
383
- else:
384
- raise ValueError(f'Unsupported precision: {precision}')
385
-
386
-
387
- @functools.lru_cache()
388
- def _get_float(precision: int):
389
- if precision in [64, '64']:
390
- return np.float64
391
- elif precision in [32, '32']:
392
- return np.float32
393
- elif precision in [16, '16']:
394
- return np.float16
395
- elif precision in ['bf16']:
396
- return jnp.bfloat16
397
- elif precision in [8, '8']:
398
- return jnp.float8_e5m2
399
- else:
400
- raise ValueError(f'Unsupported precision: {precision}')
401
-
402
-
403
- @functools.lru_cache()
404
- def _get_complex(precision: int):
405
- if precision == [64, '64']:
406
- return np.complex128
407
- elif precision == [32, '32']:
408
- return np.complex64
409
- elif precision in [16, '16', 'bf16']:
410
- return np.complex64
411
- elif precision == [8, '8']:
412
- return np.complex64
413
- else:
414
- raise ValueError(f'Unsupported precision: {precision}')
415
-
416
-
417
- def dftype() -> DTypeLike:
418
- """
419
- Default floating data type.
420
-
421
- This function returns the default floating data type based on the current precision.
422
- If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
423
- you can use this function to get the default floating data type, and create the data by using ``dtype=dftype()``.
424
-
425
- For example, if the precision is set to 32, the default floating data type is ``np.float32``.
426
-
427
- >>> import brainstate as brainstate
428
- >>> import numpy as np
429
- >>> with brainstate.environ.context(precision=32):
430
- ... a = np.zeros(1, dtype=brainstate.environ.dftype())
431
- >>> print(a.dtype)
432
-
433
- Returns
434
- -------
435
- float_dtype: DTypeLike
436
- The default floating data type.
437
- """
438
- return _get_float(_get_precision())
439
-
440
-
441
- def ditype() -> DTypeLike:
442
- """
443
- Default integer data type.
444
-
445
- This function returns the default integer data type based on the current precision.
446
- If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
447
- you can use this function to get the default integer data type, and create the data by using ``dtype=ditype()``.
448
-
449
- For example, if the precision is set to 32, the default integer data type is ``np.int32``.
450
-
451
- >>> import brainstate as brainstate
452
- >>> import numpy as np
453
- >>> with brainstate.environ.context(precision=32):
454
- ... a = np.zeros(1, dtype=brainstate.environ.ditype())
455
- >>> print(a.dtype)
456
- int32
457
-
458
- Returns
459
- -------
460
- int_dtype: DTypeLike
461
- The default integer data type.
462
- """
463
- return _get_int(_get_precision())
464
-
465
-
466
- def dutype() -> DTypeLike:
467
- """
468
- Default unsigned integer data type.
469
-
470
- This function returns the default unsigned integer data type based on the current precision.
471
- If you want the data dtype is changed with the setting of the precision
472
- by ``brainstate.environ.set(precision)``, you can use this function to get the default
473
- unsigned integer data type, and create the data by using ``dtype=dutype()``.
474
-
475
- For example, if the precision is set to 32, the default unsigned integer data type is ``np.uint32``.
476
-
477
- >>> import brainstate as brainstate
478
- >>> import numpy as np
479
- >>> with brainstate.environ.context(precision=32):
480
- ... a = np.zeros(1, dtype=brainstate.environ.dutype())
481
- >>> print(a.dtype)
482
- uint32
483
-
484
- Returns
485
- -------
486
- uint_dtype: DTypeLike
487
- The default unsigned integer data type.
488
- """
489
- return _get_uint(_get_precision())
490
-
491
-
492
- def dctype() -> DTypeLike:
493
- """
494
- Default complex data type.
495
-
496
- This function returns the default complex data type based on the current precision.
497
- If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
498
- you can use this function to get the default complex data type, and create the data by using ``dtype=dctype()``.
499
-
500
- For example, if the precision is set to 32, the default complex data type is ``np.complex64``.
501
-
502
- >>> import brainstate as brainstate
503
- >>> import numpy as np
504
- >>> with brainstate.environ.context(precision=32):
505
- ... a = np.zeros(1, dtype=brainstate.environ.dctype())
506
- >>> print(a.dtype)
507
- complex64
508
-
509
- Returns
510
- -------
511
- complex_dtype: DTypeLike
512
- The default complex data type.
513
- """
514
- return _get_complex(_get_precision())
515
-
516
-
517
- def tolerance():
518
- if get_precision() == 64:
519
- return jnp.array(1e-12, dtype=np.float64)
520
- elif get_precision() == 32:
521
- return jnp.array(1e-5, dtype=np.float32)
522
- else:
523
- return jnp.array(1e-2, dtype=np.float16)
524
-
525
-
526
- def register_default_behavior(key: str, behavior: Callable, replace_if_exist: bool = False):
527
- """
528
- Register a default behavior for a specific global key parameter.
529
-
530
- For example, you can register a default behavior for the key 'dt' by::
531
-
532
- >>> import brainstate as brainstate
533
- >>> def dt_behavior(dt):
534
- ... print(f'Set the default dt to {dt}.')
535
- ...
536
- >>> brainstate.environ.register_default_behavior('dt', dt_behavior)
537
-
538
- Then, when you set the default dt by `brainstate.environ.set(dt=0.1)`, the behavior
539
- `dt_behavior` will be called with
540
- `dt_behavior(0.1)`.
541
-
542
- >>> brainstate.environ.set(dt=0.1)
543
- Set the default dt to 0.1.
544
- >>> with brainstate.environ.context(dt=0.2):
545
- ... pass
546
- Set the default dt to 0.2.
547
- Set the default dt to 0.1.
548
-
549
-
550
- Args:
551
- key: str. The key to register.
552
- behavior: Callable. The behavior to register. It should be a callable.
553
- replace_if_exist: bool. Whether to replace the behavior if the key has been registered.
554
-
555
- """
556
- assert isinstance(key, str), 'key must be a string.'
557
- assert callable(behavior), 'behavior must be a callable.'
558
- if not replace_if_exist:
559
- assert key not in DFAULT.functions, f'{key} has been registered.'
560
- DFAULT.functions[key] = behavior
561
-
562
-
563
- set(precision=32)
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ import contextlib
19
+ import dataclasses
20
+ import functools
21
+ import os
22
+ import re
23
+ import threading
24
+ from collections import defaultdict
25
+ from typing import Any, Callable, Dict, Hashable
26
+
27
+ import numpy as np
28
+ from jax import config, devices, numpy as jnp
29
+ from jax.typing import DTypeLike
30
+
31
+ from .mixin import Mode
32
+
33
+ __all__ = [
34
+ # functions for environment settings
35
+ 'set', 'context', 'get', 'all', 'set_host_device_count', 'set_platform',
36
+ # functions for getting default behaviors
37
+ 'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_precision',
38
+ # functions for default data types
39
+ 'dftype', 'ditype', 'dutype', 'dctype',
40
+ # others
41
+ 'tolerance', 'register_default_behavior',
42
+ ]
43
+
44
+ # Default, there are several shared arguments in the global context.
45
+ I = 'i' # the index of the current computation.
46
+ T = 't' # the current time of the current computation.
47
+ JIT_ERROR_CHECK = 'jit_error_check' # whether to record the current computation.
48
+ FIT = 'fit' # whether to fit the model.
49
+
50
+
51
+ @dataclasses.dataclass
52
+ class DefaultContext(threading.local):
53
+ # default environment settings
54
+ settings: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
55
+ # current environment settings
56
+ contexts: defaultdict[Hashable, Any] = dataclasses.field(default_factory=lambda: defaultdict(list))
57
+ # environment functions
58
+ functions: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
59
+
60
+
61
+ DFAULT = DefaultContext()
62
+ _NOT_PROVIDE = object()
63
+
64
+
65
+ @contextlib.contextmanager
66
+ def context(**kwargs):
67
+ r"""
68
+ Context-manager that sets a computing environment for brain dynamics computation.
69
+
70
+ In BrainPy, there are several basic computation settings when constructing models,
71
+ including ``mode`` for controlling model computing behavior, ``dt`` for numerical
72
+ integration, ``int_`` for integer precision, and ``float_`` for floating precision.
73
+ :py:class:`~.environment`` provides a context for model construction and
74
+ computation. In this temporal environment, models are constructed with the given
75
+ ``mode``, ``dt``, ``int_``, etc., environment settings.
76
+
77
+ For instance::
78
+
79
+ >>> import brainstate as brainstate
80
+ >>> with brainstate.environ.context(dt=0.1) as env:
81
+ ... dt = brainstate.environ.get('dt')
82
+ ... print(env)
83
+
84
+ """
85
+ if 'platform' in kwargs:
86
+ raise ValueError('\n'
87
+ 'Cannot set platform in "context" environment. \n'
88
+ 'You should set platform in the global environment by "set_platform()" or "set()".')
89
+ if 'host_device_count' in kwargs:
90
+ raise ValueError('Cannot set host_device_count in environment context. '
91
+ 'Please use set_host_device_count() or set() for the global setting.')
92
+
93
+ if 'precision' in kwargs:
94
+ last_precision = _get_precision()
95
+ _set_jax_precision(kwargs['precision'])
96
+
97
+ try:
98
+ for k, v in kwargs.items():
99
+
100
+ # update the current environment
101
+ DFAULT.contexts[k].append(v)
102
+
103
+ # restore the environment functions
104
+ if k in DFAULT.functions:
105
+ DFAULT.functions[k](v)
106
+
107
+ # yield the current all environment information
108
+ yield all()
109
+ finally:
110
+
111
+ for k, v in kwargs.items():
112
+
113
+ # restore the current environment
114
+ DFAULT.contexts[k].pop()
115
+
116
+ # restore the environment functions
117
+ if k in DFAULT.functions:
118
+ DFAULT.functions[k](get(k))
119
+
120
+ if 'precision' in kwargs:
121
+ _set_jax_precision(last_precision)
122
+
123
+
124
+ def get(key: str, default: Any = _NOT_PROVIDE, desc: str = None):
125
+ """
126
+ Get one of the default computation environment.
127
+
128
+ Returns
129
+ -------
130
+ item: Any
131
+ The default computation environment.
132
+ """
133
+ if key == 'platform':
134
+ return get_platform()
135
+
136
+ if key == 'host_device_count':
137
+ return get_host_device_count()
138
+
139
+ if key in DFAULT.contexts:
140
+ if len(DFAULT.contexts[key]) > 0:
141
+ return DFAULT.contexts[key][-1]
142
+ if key in DFAULT.settings:
143
+ return DFAULT.settings[key]
144
+
145
+ if default is _NOT_PROVIDE:
146
+ if desc is not None:
147
+ raise KeyError(
148
+ f"'{key}' is not found in the context. \n"
149
+ f"You can set it by `brainstate.share.context({key}=value)` "
150
+ f"locally or `brainstate.share.set({key}=value)` globally. \n"
151
+ f"Description: {desc}"
152
+ )
153
+ else:
154
+ raise KeyError(
155
+ f"'{key}' is not found in the context. \n"
156
+ f"You can set it by `brainstate.share.context({key}=value)` "
157
+ f"locally or `brainstate.share.set({key}=value)` globally."
158
+ )
159
+ return default
160
+
161
+
162
+ def all() -> dict:
163
+ """
164
+ Get all the current default computation environment.
165
+
166
+ Returns
167
+ -------
168
+ r: dict
169
+ The current default computation environment.
170
+ """
171
+ r = dict()
172
+ for k, v in DFAULT.contexts.items():
173
+ if v:
174
+ r[k] = v[-1]
175
+ for k, v in DFAULT.settings.items():
176
+ if k not in r:
177
+ r[k] = v
178
+ return r
179
+
180
+
181
+ def get_dt():
182
+ """Get the numerical integrator precision.
183
+
184
+ Returns
185
+ -------
186
+ dt : float
187
+ Numerical integration precision.
188
+ """
189
+ return get('dt')
190
+
191
+
192
+ def get_mode() -> Mode:
193
+ """Get the default computing mode.
194
+
195
+ References
196
+ ----------
197
+ mode: Mode
198
+ The default computing mode.
199
+ """
200
+ return get('mode')
201
+
202
+
203
+ def get_platform() -> str:
204
+ """Get the computing platform.
205
+
206
+ Returns
207
+ -------
208
+ platform: str
209
+ Either 'cpu', 'gpu' or 'tpu'.
210
+ """
211
+ return devices()[0].platform
212
+
213
+
214
def get_host_device_count():
    """
    Read the host device count requested through ``XLA_FLAGS``.

    The value is parsed from the ``--xla_force_host_platform_device_count=<n>``
    flag of the ``XLA_FLAGS`` environment variable.

    Returns
    -------
    n: int
      The number of host devices, or ``1`` when the flag is absent.
    """
    flags = os.environ.get("XLA_FLAGS", "")
    found = re.search(r"--xla_force_host_platform_device_count=(\d+)", flags)
    if not found:
        return 1
    return int(found.group(1))
226
+
227
+
228
def _get_precision() -> int | str:
    """
    Look up the raw ``'precision'`` entry of the current environment.

    Unlike :func:`get_precision`, the stored value is returned untouched,
    so it may be an integer or a string such as ``'bf16'``.

    Returns
    -------
    precision: int | str
      The default precision exactly as it was stored.
    """
    raw_precision = get('precision')
    return raw_precision
238
+
239
+
240
def get_precision() -> int:
    """
    Get the default precision as a bit width.

    ``'bf16'`` is reported as ``16``, integers are returned unchanged and
    any other string is converted with ``int()``.

    Returns
    -------
    precision: int
      The default precision.

    Raises
    ------
    ValueError
      If the stored precision is neither an ``int`` nor a ``str``.
    """
    stored = get('precision')
    if stored == 'bf16':
        return 16
    if isinstance(stored, (int, str)):
        # Integers pass through; numeric strings such as '32' are parsed.
        return stored if isinstance(stored, int) else int(stored)
    raise ValueError(f'Unsupported precision: {stored}')
257
+
258
+
259
def set(
    platform: str = None,
    host_device_count: int = None,
    precision: int | str = None,
    mode: Mode = None,
    **kwargs
):
    """
    Set the global default computation environment.

    Arguments left as ``None`` are skipped. ``precision`` and ``mode`` are
    stored alongside ``kwargs`` in the global settings, and every stored key
    that has a registered behavior function triggers that function with the
    new value.

    Args:
      platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
      host_device_count: int. The number of host devices.
      precision: int, str. The default precision.
      mode: Mode. The computing mode.
      **kwargs: dict. Other environment settings.
    """
    # Platform and host device count are applied through their own setters
    # and are not recorded in the global settings below.
    if platform is not None:
        set_platform(platform)
    if host_device_count is not None:
        set_host_device_count(host_device_count)

    # Precision and mode join the remaining settings.
    if precision is not None:
        _set_jax_precision(precision)
        kwargs['precision'] = precision
    if mode is not None:
        assert isinstance(mode, Mode), 'mode must be a Mode instance.'
        kwargs['mode'] = mode

    # Record the new global defaults.
    DFAULT.settings.update(kwargs)

    # Notify the behavior registered for each updated key, if any.
    for name, value in kwargs.items():
        if name in DFAULT.functions:
            on_change = DFAULT.functions[name]
            on_change(value)
296
+
297
+
298
def set_host_device_count(n):
    """
    Tell XLA that ``n`` host (CPU) devices are available.

    By default, XLA considers all CPU cores as one device. Exposing several
    host devices allows parallel mapping in JAX :func:`jax.pmap` to work on
    the CPU platform.

    .. note:: This utility only takes effect at the beginning of your program.
        Under the hood, this sets the environment variable
        `XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
        `[num_devices]` is the desired number of CPU devices `n`. Any previous
        occurrence of that flag is removed; other flags are kept.

    .. warning:: The side effects of the
        `xla_force_host_platform_device_count` flag in XLA are not fully
        understood. If you observe strange behavior when using this utility,
        please report it. More information is available in this
        `JAX issue <https://github.com/google/jax/issues/1408>`_.

    :param int n: number of devices to use.
    """
    previous = os.getenv("XLA_FLAGS", "")
    # Drop any earlier device-count flag, keeping the remaining flags.
    other_flags = re.sub(r"--xla_force_host_platform_device_count=\S+", "", previous).split()
    device_flag = f"--xla_force_host_platform_device_count={n}"
    os.environ["XLA_FLAGS"] = " ".join([device_flag, *other_flags])

    # Notify the registered behavior, if any.
    if 'host_device_count' in DFAULT.functions:
        on_change = DFAULT.functions['host_device_count']
        on_change(n)
324
+
325
+
326
def set_platform(platform: str):
    """
    Changes platform to CPU, GPU, or TPU. This utility only takes
    effect at the beginning of your program.

    Args:
      platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.

    Raises:
      AssertionError: If the platform is not in ['cpu', 'gpu', 'tpu'].
    """
    supported = ('cpu', 'gpu', 'tpu')
    assert platform in supported
    config.update("jax_platform_name", platform)

    # Notify the registered behavior, if any.
    if 'platform' in DFAULT.functions:
        on_change = DFAULT.functions['platform']
        on_change(platform)
343
+
344
+
345
+ def _set_jax_precision(precision: int | str):
346
+ """
347
+ Set the default precision.
348
+
349
+ Args:
350
+ precision: int. The default precision.
351
+ """
352
+ # assert precision in [64, 32, 16, 'bf16', 8], f'Precision must be in [64, 32, 16, "bf16", 8]. But got {precision}.'
353
+ if precision in [64, '64']:
354
+ config.update("jax_enable_x64", True)
355
+ else:
356
+ config.update("jax_enable_x64", False)
357
+
358
+
359
+ @functools.lru_cache()
360
+ def _get_uint(precision: int):
361
+ if precision in [64, '64']:
362
+ return np.uint64
363
+ elif precision in [32, '32']:
364
+ return np.uint32
365
+ elif precision in [16, '16', 'bf16']:
366
+ return np.uint16
367
+ elif precision in [8, '8']:
368
+ return np.uint8
369
+ else:
370
+ raise ValueError(f'Unsupported precision: {precision}')
371
+
372
+
373
+ @functools.lru_cache()
374
+ def _get_int(precision: int):
375
+ if precision in [64, '64']:
376
+ return np.int64
377
+ elif precision in [32, '32']:
378
+ return np.int32
379
+ elif precision in [16, '16', 'bf16']:
380
+ return np.int16
381
+ elif precision in [8, '8']:
382
+ return np.int8
383
+ else:
384
+ raise ValueError(f'Unsupported precision: {precision}')
385
+
386
+
387
+ @functools.lru_cache()
388
+ def _get_float(precision: int):
389
+ if precision in [64, '64']:
390
+ return np.float64
391
+ elif precision in [32, '32']:
392
+ return np.float32
393
+ elif precision in [16, '16']:
394
+ return np.float16
395
+ elif precision in ['bf16']:
396
+ return jnp.bfloat16
397
+ elif precision in [8, '8']:
398
+ return jnp.float8_e5m2
399
+ else:
400
+ raise ValueError(f'Unsupported precision: {precision}')
401
+
402
+
403
+ @functools.lru_cache()
404
+ def _get_complex(precision: int):
405
+ if precision == [64, '64']:
406
+ return np.complex128
407
+ elif precision == [32, '32']:
408
+ return np.complex64
409
+ elif precision in [16, '16', 'bf16']:
410
+ return np.complex64
411
+ elif precision == [8, '8']:
412
+ return np.complex64
413
+ else:
414
+ raise ValueError(f'Unsupported precision: {precision}')
415
+
416
+
417
def dftype() -> DTypeLike:
    """
    Default floating data type.

    The returned type follows the precision currently in effect. Create
    arrays with ``dtype=dftype()`` when their data type should change
    together with the precision configured through
    ``brainstate.environ.set(precision=...)``.

    For example, if the precision is set to 32, the default floating data type is ``np.float32``.

    >>> import brainstate
    >>> import numpy as np
    >>> with brainstate.environ.context(precision=32):
    ...    a = np.zeros(1, dtype=brainstate.environ.dftype())
    >>> print(a.dtype)
    float32

    Returns
    -------
    float_dtype: DTypeLike
      The default floating data type.
    """
    current_precision = _get_precision()
    return _get_float(current_precision)
439
+
440
+
441
def ditype() -> DTypeLike:
    """
    Default integer data type.

    The returned type follows the precision currently in effect. Create
    arrays with ``dtype=ditype()`` when their data type should change
    together with the precision configured through
    ``brainstate.environ.set(precision=...)``.

    For example, if the precision is set to 32, the default integer data type is ``np.int32``.

    >>> import brainstate
    >>> import numpy as np
    >>> with brainstate.environ.context(precision=32):
    ...    a = np.zeros(1, dtype=brainstate.environ.ditype())
    >>> print(a.dtype)
    int32

    Returns
    -------
    int_dtype: DTypeLike
      The default integer data type.
    """
    current_precision = _get_precision()
    return _get_int(current_precision)
464
+
465
+
466
def dutype() -> DTypeLike:
    """
    Default unsigned integer data type.

    The returned type follows the precision currently in effect. Create
    arrays with ``dtype=dutype()`` when their data type should change
    together with the precision configured through
    ``brainstate.environ.set(precision=...)``.

    For example, if the precision is set to 32, the default unsigned integer data type is ``np.uint32``.

    >>> import brainstate
    >>> import numpy as np
    >>> with brainstate.environ.context(precision=32):
    ...    a = np.zeros(1, dtype=brainstate.environ.dutype())
    >>> print(a.dtype)
    uint32

    Returns
    -------
    uint_dtype: DTypeLike
      The default unsigned integer data type.
    """
    current_precision = _get_precision()
    return _get_uint(current_precision)
490
+
491
+
492
def dctype() -> DTypeLike:
    """
    Default complex data type.

    The returned type follows the precision currently in effect. Create
    arrays with ``dtype=dctype()`` when their data type should change
    together with the precision configured through
    ``brainstate.environ.set(precision=...)``.

    For example, if the precision is set to 32, the default complex data type is ``np.complex64``.

    >>> import brainstate
    >>> import numpy as np
    >>> with brainstate.environ.context(precision=32):
    ...    a = np.zeros(1, dtype=brainstate.environ.dctype())
    >>> print(a.dtype)
    complex64

    Returns
    -------
    complex_dtype: DTypeLike
      The default complex data type.
    """
    current_precision = _get_precision()
    return _get_complex(current_precision)
515
+
516
+
517
def tolerance():
    """
    Return a tolerance value matched to the current precision.

    Returns
    -------
    tol
      ``1e-12`` built with ``dtype=np.float64`` when the precision is 64,
      ``1e-5`` built with ``dtype=np.float32`` when it is 32, and ``1e-2``
      built with ``dtype=np.float16`` for every other precision.
    """
    bits = get_precision()
    if bits == 64:
        return jnp.array(1e-12, dtype=np.float64)
    if bits == 32:
        return jnp.array(1e-5, dtype=np.float32)
    return jnp.array(1e-2, dtype=np.float16)
524
+
525
+
526
def register_default_behavior(key: str, behavior: Callable, replace_if_exist: bool = False):
    """
    Register a default behavior for a specific global key parameter.

    The behavior is stored under ``key`` and is called with the new value
    whenever that key is set.

    For example, you can register a default behavior for the key 'dt' by::

    >>> import brainstate
    >>> def dt_behavior(dt):
    ...     print(f'Set the default dt to {dt}.')
    ...
    >>> brainstate.environ.register_default_behavior('dt', dt_behavior)

    Then, when you set the default dt by `brainstate.environ.set(dt=0.1)`, the behavior
    `dt_behavior` will be called with `dt_behavior(0.1)`.

    >>> brainstate.environ.set(dt=0.1)
    Set the default dt to 0.1.
    >>> with brainstate.environ.context(dt=0.2):
    ...     pass
    Set the default dt to 0.2.
    Set the default dt to 0.1.

    Args:
      key: str. The key to register.
      behavior: Callable. The behavior to register. It should be a callable.
      replace_if_exist: bool. Whether to replace the behavior if the key has been registered.

    Raises:
      AssertionError: If ``key`` is not a string, ``behavior`` is not callable,
        or ``key`` is already registered while ``replace_if_exist`` is false.
    """
    assert isinstance(key, str), 'key must be a string.'
    assert callable(behavior), 'behavior must be a callable.'
    # An existing registration may only be overwritten on explicit request.
    assert replace_if_exist or key not in DFAULT.functions, f'{key} has been registered.'
    DFAULT.functions[key] = behavior
561
+
562
+
563
# Module-level call executed on import: makes 32 the stored default precision,
# which in turn switches JAX's ``jax_enable_x64`` flag off.
+ set(precision=32)