brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175) hide show
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
brainstate/environ.py CHANGED
@@ -1,29 +1,47 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
1
16
  # -*- coding: utf-8 -*-
2
17
 
18
+ from __future__ import annotations
3
19
 
4
20
  import contextlib
21
+ import dataclasses
5
22
  import functools
6
23
  import os
7
24
  import re
25
+ import threading
8
26
  from collections import defaultdict
9
- from typing import Any, Callable
27
+ from typing import Any, Callable, Dict, Hashable
10
28
 
11
29
  import numpy as np
12
30
  from jax import config, devices, numpy as jnp
13
- from jax._src.typing import DTypeLike
31
+ from jax.typing import DTypeLike
14
32
 
15
33
  from .mixin import Mode
16
- from .util import MemScaling, IdMemScaling
34
+ from .util import MemScaling
17
35
 
18
36
  __all__ = [
19
- # functions for environment settings
20
- 'set', 'context', 'get', 'all', 'set_host_device_count', 'set_platform',
21
- # functions for getting default behaviors
22
- 'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
23
- # functions for default data types
24
- 'dftype', 'ditype', 'dutype', 'dctype',
25
- # others
26
- 'tolerance', 'register_default_behavior',
37
+ # functions for environment settings
38
+ 'set', 'context', 'get', 'all', 'set_host_device_count', 'set_platform',
39
+ # functions for getting default behaviors
40
+ 'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
41
+ # functions for default data types
42
+ 'dftype', 'ditype', 'dutype', 'dctype',
43
+ # others
44
+ 'tolerance', 'register_default_behavior',
27
45
  ]
28
46
 
29
47
  # Default, there are several shared arguments in the global context.
@@ -32,196 +50,205 @@ T = 't' # the current time of the current computation.
32
50
  JIT_ERROR_CHECK = 'jit_error_check' # whether to record the current computation.
33
51
  FIT = 'fit' # whether to fit the model.
34
52
 
53
+
54
+ @dataclasses.dataclass
55
+ class DefaultContext(threading.local):
56
+ # default environment settings
57
+ settings: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
58
+ # current environment settings
59
+ contexts: defaultdict[Hashable, Any] = dataclasses.field(default_factory=lambda: defaultdict(list))
60
+ # environment functions
61
+ functions: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
62
+
63
+
64
+ DFAULT = DefaultContext()
35
65
  _NOT_PROVIDE = object()
36
- _environment_defaults = dict() # default environment settings
37
- _environment_contexts = defaultdict(list) # current environment settings
38
- _environment_functions = dict() # environment functions
39
66
 
40
67
 
41
68
  @contextlib.contextmanager
42
69
  def context(**kwargs):
43
- r"""
44
- Context-manager that sets a computing environment for brain dynamics computation.
45
-
46
- In BrainPy, there are several basic computation settings when constructing models,
47
- including ``mode`` for controlling model computing behavior, ``dt`` for numerical
48
- integration, ``int_`` for integer precision, and ``float_`` for floating precision.
49
- :py:class:`~.environment`` provides a context for model construction and
50
- computation. In this temporal environment, models are constructed with the given
51
- ``mode``, ``dt``, ``int_``, etc., environment settings.
52
-
53
- For instance::
54
-
55
- >>> import brainstate as bst
56
- >>> with bst.environ.context(dt=0.1) as env:
57
- ... dt = bst.environ.get('dt')
58
- ... print(env)
59
-
60
- """
61
- if 'platform' in kwargs:
62
- raise ValueError('\n'
63
- 'Cannot set platform in "context" environment. \n'
64
- 'You should set platform in the global environment by "set_platform()" or "set()".')
65
- if 'host_device_count' in kwargs:
66
- raise ValueError('Cannot set host_device_count in environment context. '
67
- 'Please use set_host_device_count() or set() for the global setting.')
68
-
69
- if 'precision' in kwargs:
70
- last_precision = get_precision()
71
- _set_jax_precision(kwargs['precision'])
72
-
73
- try:
74
- for k, v in kwargs.items():
70
+ r"""
71
+ Context-manager that sets a computing environment for brain dynamics computation.
72
+
73
+ In BrainPy, there are several basic computation settings when constructing models,
74
+ including ``mode`` for controlling model computing behavior, ``dt`` for numerical
75
+ integration, ``int_`` for integer precision, and ``float_`` for floating precision.
76
+ :py:class:`~.environment`` provides a context for model construction and
77
+ computation. In this temporal environment, models are constructed with the given
78
+ ``mode``, ``dt``, ``int_``, etc., environment settings.
79
+
80
+ For instance::
81
+
82
+ >>> import brainstate as bst
83
+ >>> with bst.environ.context(dt=0.1) as env:
84
+ ... dt = bst.environ.get('dt')
85
+ ... print(env)
86
+
87
+ """
88
+ if 'platform' in kwargs:
89
+ raise ValueError('\n'
90
+ 'Cannot set platform in "context" environment. \n'
91
+ 'You should set platform in the global environment by "set_platform()" or "set()".')
92
+ if 'host_device_count' in kwargs:
93
+ raise ValueError('Cannot set host_device_count in environment context. '
94
+ 'Please use set_host_device_count() or set() for the global setting.')
75
95
 
76
- # update the current environment
77
- _environment_contexts[k].append(v)
96
+ if 'precision' in kwargs:
97
+ last_precision = get_precision()
98
+ _set_jax_precision(kwargs['precision'])
78
99
 
79
- # restore the environment functions
80
- if k in _environment_functions:
81
- _environment_functions[k](v)
100
+ try:
101
+ for k, v in kwargs.items():
82
102
 
83
- # yield the current all environment information
84
- yield all()
85
- finally:
103
+ # update the current environment
104
+ DFAULT.contexts[k].append(v)
86
105
 
87
- for k, v in kwargs.items():
106
+ # restore the environment functions
107
+ if k in DFAULT.functions:
108
+ DFAULT.functions[k](v)
88
109
 
89
- # restore the current environment
90
- _environment_contexts[k].pop()
110
+ # yield the current all environment information
111
+ yield all()
112
+ finally:
91
113
 
92
- # restore the environment functions
93
- if k in _environment_functions:
94
- _environment_functions[k](get(k))
114
+ for k, v in kwargs.items():
95
115
 
96
- if 'precision' in kwargs:
97
- _set_jax_precision(last_precision)
116
+ # restore the current environment
117
+ DFAULT.contexts[k].pop()
118
+
119
+ # restore the environment functions
120
+ if k in DFAULT.functions:
121
+ DFAULT.functions[k](get(k))
122
+
123
+ if 'precision' in kwargs:
124
+ _set_jax_precision(last_precision)
98
125
 
99
126
 
100
127
  def get(key: str, default: Any = _NOT_PROVIDE, desc: str = None):
101
- """
102
- Get one of the default computation environment.
103
-
104
- Returns
105
- -------
106
- item: Any
107
- The default computation environment.
108
- """
109
- if key == 'platform':
110
- return get_platform()
111
-
112
- if key == 'host_device_count':
113
- return get_host_device_count()
114
-
115
- if key in _environment_contexts:
116
- if len(_environment_contexts[key]) > 0:
117
- return _environment_contexts[key][-1]
118
- if key in _environment_defaults:
119
- return _environment_defaults[key]
120
-
121
- if default is _NOT_PROVIDE:
122
- if desc is not None:
123
- raise KeyError(
124
- f"'{key}' is not found in the context. \n"
125
- f"You can set it by `brainstate.share.context({key}=value)` "
126
- f"locally or `brainstate.share.set({key}=value)` globally. \n"
127
- f"Description: {desc}"
128
- )
129
- else:
130
- raise KeyError(
131
- f"'{key}' is not found in the context. \n"
132
- f"You can set it by `brainstate.share.context({key}=value)` "
133
- f"locally or `brainstate.share.set({key}=value)` globally."
134
- )
135
- return default
128
+ """
129
+ Get one of the default computation environment.
130
+
131
+ Returns
132
+ -------
133
+ item: Any
134
+ The default computation environment.
135
+ """
136
+ if key == 'platform':
137
+ return get_platform()
138
+
139
+ if key == 'host_device_count':
140
+ return get_host_device_count()
141
+
142
+ if key in DFAULT.contexts:
143
+ if len(DFAULT.contexts[key]) > 0:
144
+ return DFAULT.contexts[key][-1]
145
+ if key in DFAULT.settings:
146
+ return DFAULT.settings[key]
147
+
148
+ if default is _NOT_PROVIDE:
149
+ if desc is not None:
150
+ raise KeyError(
151
+ f"'{key}' is not found in the context. \n"
152
+ f"You can set it by `brainstate.share.context({key}=value)` "
153
+ f"locally or `brainstate.share.set({key}=value)` globally. \n"
154
+ f"Description: {desc}"
155
+ )
156
+ else:
157
+ raise KeyError(
158
+ f"'{key}' is not found in the context. \n"
159
+ f"You can set it by `brainstate.share.context({key}=value)` "
160
+ f"locally or `brainstate.share.set({key}=value)` globally."
161
+ )
162
+ return default
136
163
 
137
164
 
138
165
  def all() -> dict:
139
- """
140
- Get all the current default computation environment.
141
-
142
- Returns
143
- -------
144
- r: dict
145
- The current default computation environment.
146
- """
147
- r = dict()
148
- for k, v in _environment_contexts.items():
149
- if v:
150
- r[k] = v[-1]
151
- for k, v in _environment_defaults.items():
152
- if k not in r:
153
- r[k] = v
154
- return r
166
+ """
167
+ Get all the current default computation environment.
168
+
169
+ Returns
170
+ -------
171
+ r: dict
172
+ The current default computation environment.
173
+ """
174
+ r = dict()
175
+ for k, v in DFAULT.contexts.items():
176
+ if v:
177
+ r[k] = v[-1]
178
+ for k, v in DFAULT.settings.items():
179
+ if k not in r:
180
+ r[k] = v
181
+ return r
155
182
 
156
183
 
157
184
  def get_dt():
158
- """Get the numerical integrator precision.
185
+ """Get the numerical integrator precision.
159
186
 
160
- Returns
161
- -------
162
- dt : float
163
- Numerical integration precision.
164
- """
165
- return get('dt')
187
+ Returns
188
+ -------
189
+ dt : float
190
+ Numerical integration precision.
191
+ """
192
+ return get('dt')
166
193
 
167
194
 
168
195
  def get_mode() -> Mode:
169
- """Get the default computing mode.
196
+ """Get the default computing mode.
170
197
 
171
- References
172
- ----------
173
- mode: Mode
174
- The default computing mode.
175
- """
176
- return get('mode')
198
+ References
199
+ ----------
200
+ mode: Mode
201
+ The default computing mode.
202
+ """
203
+ return get('mode')
177
204
 
178
205
 
179
206
  def get_mem_scaling() -> MemScaling:
180
- """Get the default computing membrane_scaling.
207
+ """Get the default computing membrane_scaling.
181
208
 
182
- Returns
183
- -------
184
- membrane_scaling: MemScaling
185
- The default computing membrane_scaling.
186
- """
187
- return get('mem_scaling')
209
+ Returns
210
+ -------
211
+ membrane_scaling: MemScaling
212
+ The default computing membrane_scaling.
213
+ """
214
+ return get('mem_scaling')
188
215
 
189
216
 
190
217
  def get_platform() -> str:
191
- """Get the computing platform.
218
+ """Get the computing platform.
192
219
 
193
- Returns
194
- -------
195
- platform: str
196
- Either 'cpu', 'gpu' or 'tpu'.
197
- """
198
- return devices()[0].platform
220
+ Returns
221
+ -------
222
+ platform: str
223
+ Either 'cpu', 'gpu' or 'tpu'.
224
+ """
225
+ return devices()[0].platform
199
226
 
200
227
 
201
228
  def get_host_device_count():
202
- """
203
- Get the number of host devices.
229
+ """
230
+ Get the number of host devices.
204
231
 
205
- Returns
206
- -------
207
- n: int
208
- The number of host devices.
209
- """
210
- xla_flags = os.getenv("XLA_FLAGS", "")
211
- match = re.search(r"--xla_force_host_platform_device_count=(\d+)", xla_flags)
212
- return int(match.group(1)) if match else 1
232
+ Returns
233
+ -------
234
+ n: int
235
+ The number of host devices.
236
+ """
237
+ xla_flags = os.getenv("XLA_FLAGS", "")
238
+ match = re.search(r"--xla_force_host_platform_device_count=(\d+)", xla_flags)
239
+ return int(match.group(1)) if match else 1
213
240
 
214
241
 
215
242
  def get_precision() -> int:
216
- """
217
- Get the default precision.
243
+ """
244
+ Get the default precision.
218
245
 
219
- Returns
220
- -------
221
- precision: int
222
- The default precision.
223
- """
224
- return get('precision')
246
+ Returns
247
+ -------
248
+ precision: int
249
+ The default precision.
250
+ """
251
+ return get('precision')
225
252
 
226
253
 
227
254
  def set(
@@ -232,300 +259,300 @@ def set(
232
259
  mode: Mode = None,
233
260
  **kwargs
234
261
  ):
235
- """
236
- Set the global default computation environment.
237
-
238
-
239
-
240
- Args:
241
- platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
242
- host_device_count: int. The number of host devices.
243
- mem_scaling: MemScaling. The membrane scaling.
244
- precision: int. The default precision.
245
- mode: Mode. The computing mode.
246
- **kwargs: dict. Other environment settings.
247
- """
248
- if platform is not None:
249
- set_platform(platform)
250
- if host_device_count is not None:
251
- set_host_device_count(host_device_count)
252
- if mem_scaling is not None:
253
- assert isinstance(mem_scaling, MemScaling), 'mem_scaling must be a MemScaling instance.'
254
- kwargs['mem_scaling'] = mem_scaling
255
- if precision is not None:
256
- _set_jax_precision(precision)
257
- kwargs['precision'] = precision
258
- if mode is not None:
259
- assert isinstance(mode, Mode), 'mode must be a Mode instance.'
260
- kwargs['mode'] = mode
261
-
262
- # set default environment
263
- _environment_defaults.update(kwargs)
264
-
265
- # update the environment functions
266
- for k, v in kwargs.items():
267
- if k in _environment_functions:
268
- _environment_functions[k](v)
262
+ """
263
+ Set the global default computation environment.
264
+
265
+
266
+
267
+ Args:
268
+ platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
269
+ host_device_count: int. The number of host devices.
270
+ mem_scaling: MemScaling. The membrane scaling.
271
+ precision: int. The default precision.
272
+ mode: Mode. The computing mode.
273
+ **kwargs: dict. Other environment settings.
274
+ """
275
+ if platform is not None:
276
+ set_platform(platform)
277
+ if host_device_count is not None:
278
+ set_host_device_count(host_device_count)
279
+ if mem_scaling is not None:
280
+ assert isinstance(mem_scaling, MemScaling), 'mem_scaling must be a MemScaling instance.'
281
+ kwargs['mem_scaling'] = mem_scaling
282
+ if precision is not None:
283
+ _set_jax_precision(precision)
284
+ kwargs['precision'] = precision
285
+ if mode is not None:
286
+ assert isinstance(mode, Mode), 'mode must be a Mode instance.'
287
+ kwargs['mode'] = mode
288
+
289
+ # set default environment
290
+ DFAULT.settings.update(kwargs)
291
+
292
+ # update the environment functions
293
+ for k, v in kwargs.items():
294
+ if k in DFAULT.functions:
295
+ DFAULT.functions[k](v)
269
296
 
270
297
 
271
298
  def set_host_device_count(n):
272
- """
273
- By default, XLA considers all CPU cores as one device. This utility tells XLA
274
- that there are `n` host (CPU) devices available to use. As a consequence, this
275
- allows parallel mapping in JAX :func:`jax.pmap` to work in CPU platform.
299
+ """
300
+ By default, XLA considers all CPU cores as one device. This utility tells XLA
301
+ that there are `n` host (CPU) devices available to use. As a consequence, this
302
+ allows parallel mapping in JAX :func:`jax.pmap` to work in CPU platform.
276
303
 
277
- .. note:: This utility only takes effect at the beginning of your program.
278
- Under the hood, this sets the environment variable
279
- `XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
280
- `[num_device]` is the desired number of CPU devices `n`.
304
+ .. note:: This utility only takes effect at the beginning of your program.
305
+ Under the hood, this sets the environment variable
306
+ `XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
307
+ `[num_device]` is the desired number of CPU devices `n`.
281
308
 
282
- .. warning:: Our understanding of the side effects of using the
283
- `xla_force_host_platform_device_count` flag in XLA is incomplete. If you
284
- observe some strange phenomenon when using this utility, please let us
285
- know through our issue or forum page. More information is available in this
286
- `JAX issue <https://github.com/google/jax/issues/1408>`_.
309
+ .. warning:: Our understanding of the side effects of using the
310
+ `xla_force_host_platform_device_count` flag in XLA is incomplete. If you
311
+ observe some strange phenomenon when using this utility, please let us
312
+ know through our issue or forum page. More information is available in this
313
+ `JAX issue <https://github.com/google/jax/issues/1408>`_.
287
314
 
288
- :param int n: number of devices to use.
289
- """
290
- xla_flags = os.getenv("XLA_FLAGS", "")
291
- xla_flags = re.sub(r"--xla_force_host_platform_device_count=\S+", "", xla_flags).split()
292
- os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(n)] + xla_flags)
315
+ :param int n: number of devices to use.
316
+ """
317
+ xla_flags = os.getenv("XLA_FLAGS", "")
318
+ xla_flags = re.sub(r"--xla_force_host_platform_device_count=\S+", "", xla_flags).split()
319
+ os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(n)] + xla_flags)
293
320
 
294
- # update the environment functions
295
- if 'host_device_count' in _environment_functions:
296
- _environment_functions['host_device_count'](n)
321
+ # update the environment functions
322
+ if 'host_device_count' in DFAULT.functions:
323
+ DFAULT.functions['host_device_count'](n)
297
324
 
298
325
 
299
326
  def set_platform(platform: str):
300
- """
301
- Changes platform to CPU, GPU, or TPU. This utility only takes
302
- effect at the beginning of your program.
327
+ """
328
+ Changes platform to CPU, GPU, or TPU. This utility only takes
329
+ effect at the beginning of your program.
303
330
 
304
- Args:
305
- platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
331
+ Args:
332
+ platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
306
333
 
307
- Raises:
308
- ValueError: If the platform is not in ['cpu', 'gpu', 'tpu'].
309
- """
310
- assert platform in ['cpu', 'gpu', 'tpu']
311
- config.update("jax_platform_name", platform)
334
+ Raises:
335
+ ValueError: If the platform is not in ['cpu', 'gpu', 'tpu'].
336
+ """
337
+ assert platform in ['cpu', 'gpu', 'tpu']
338
+ config.update("jax_platform_name", platform)
312
339
 
313
- # update the environment functions
314
- if 'platform' in _environment_functions:
315
- _environment_functions['platform'](platform)
340
+ # update the environment functions
341
+ if 'platform' in DFAULT.functions:
342
+ DFAULT.functions['platform'](platform)
316
343
 
317
344
 
318
345
  def _set_jax_precision(precision: int):
319
- """
320
- Set the default precision.
321
-
322
- Args:
323
- precision: int. The default precision.
324
- """
325
- assert precision in [64, 32, 16, 8], f'Precision must be in [64, 32, 16, 8]. But got {precision}.'
326
- if precision == 64:
327
- config.update("jax_enable_x64", True)
328
- else:
329
- config.update("jax_enable_x64", False)
346
+ """
347
+ Set the default precision.
348
+
349
+ Args:
350
+ precision: int. The default precision.
351
+ """
352
+ assert precision in [64, 32, 16, 8], f'Precision must be in [64, 32, 16, 8]. But got {precision}.'
353
+ if precision == 64:
354
+ config.update("jax_enable_x64", True)
355
+ else:
356
+ config.update("jax_enable_x64", False)
330
357
 
331
358
 
332
359
  @functools.lru_cache()
333
360
  def _get_uint(precision: int):
334
- if precision == 64:
335
- return np.uint64
336
- elif precision == 32:
337
- return np.uint32
338
- elif precision == 16:
339
- return np.uint16
340
- elif precision == 8:
341
- return np.uint8
342
- else:
343
- raise ValueError(f'Unsupported precision: {precision}')
361
+ if precision == 64:
362
+ return np.uint64
363
+ elif precision == 32:
364
+ return np.uint32
365
+ elif precision == 16:
366
+ return np.uint16
367
+ elif precision == 8:
368
+ return np.uint8
369
+ else:
370
+ raise ValueError(f'Unsupported precision: {precision}')
344
371
 
345
372
 
346
373
  @functools.lru_cache()
347
374
  def _get_int(precision: int):
348
- if precision == 64:
349
- return np.int64
350
- elif precision == 32:
351
- return np.int32
352
- elif precision == 16:
353
- return np.int16
354
- elif precision == 8:
355
- return np.int8
356
- else:
357
- raise ValueError(f'Unsupported precision: {precision}')
375
+ if precision == 64:
376
+ return np.int64
377
+ elif precision == 32:
378
+ return np.int32
379
+ elif precision == 16:
380
+ return np.int16
381
+ elif precision == 8:
382
+ return np.int8
383
+ else:
384
+ raise ValueError(f'Unsupported precision: {precision}')
358
385
 
359
386
 
360
387
  @functools.lru_cache()
361
388
  def _get_float(precision: int):
362
- if precision == 64:
363
- return np.float64
364
- elif precision == 32:
365
- return np.float32
366
- elif precision == 16:
367
- return jnp.bfloat16
368
- # return np.float16
369
- else:
370
- raise ValueError(f'Unsupported precision: {precision}')
389
+ if precision == 64:
390
+ return np.float64
391
+ elif precision == 32:
392
+ return np.float32
393
+ elif precision == 16:
394
+ return jnp.bfloat16
395
+ # return np.float16
396
+ else:
397
+ raise ValueError(f'Unsupported precision: {precision}')
371
398
 
372
399
 
373
400
  @functools.lru_cache()
374
401
  def _get_complex(precision: int):
375
- if precision == 64:
376
- return np.complex128
377
- elif precision == 32:
378
- return np.complex64
379
- elif precision == 16:
380
- return np.complex32
381
- else:
382
- raise ValueError(f'Unsupported precision: {precision}')
402
+ if precision == 64:
403
+ return np.complex128
404
+ elif precision == 32:
405
+ return np.complex64
406
+ elif precision == 16:
407
+ return np.complex32
408
+ else:
409
+ raise ValueError(f'Unsupported precision: {precision}')
383
410
 
384
411
 
385
412
  def dftype() -> DTypeLike:
386
- """
387
- Default floating data type.
413
+ """
414
+ Default floating data type.
388
415
 
389
- This function returns the default floating data type based on the current precision.
390
- If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
391
- you can use this function to get the default floating data type, and create the data by using ``dtype=dftype()``.
416
+ This function returns the default floating data type based on the current precision.
417
+ If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
418
+ you can use this function to get the default floating data type, and create the data by using ``dtype=dftype()``.
392
419
 
393
- For example, if the precision is set to 32, the default floating data type is ``np.float32``.
420
+ For example, if the precision is set to 32, the default floating data type is ``np.float32``.
394
421
 
395
- >>> import brainstate as bst
396
- >>> import numpy as np
397
- >>> with bst.environ.context(precision=32):
398
- ... a = np.zeros(1, dtype=bst.environ.dftype())
399
- >>> print(a.dtype)
422
+ >>> import brainstate as bst
423
+ >>> import numpy as np
424
+ >>> with bst.environ.context(precision=32):
425
+ ... a = np.zeros(1, dtype=bst.environ.dftype())
426
+ >>> print(a.dtype)
400
427
 
401
- Returns
402
- -------
403
- float_dtype: DTypeLike
404
- The default floating data type.
405
- """
406
- return _get_float(get_precision())
428
+ Returns
429
+ -------
430
+ float_dtype: DTypeLike
431
+ The default floating data type.
432
+ """
433
+ return _get_float(get_precision())
407
434
 
408
435
 
409
436
  def ditype() -> DTypeLike:
410
- """
411
- Default integer data type.
437
+ """
438
+ Default integer data type.
412
439
 
413
- This function returns the default integer data type based on the current precision.
414
- If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
415
- you can use this function to get the default integer data type, and create the data by using ``dtype=ditype()``.
440
+ This function returns the default integer data type based on the current precision.
441
+ If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
442
+ you can use this function to get the default integer data type, and create the data by using ``dtype=ditype()``.
416
443
 
417
- For example, if the precision is set to 32, the default integer data type is ``np.int32``.
444
+ For example, if the precision is set to 32, the default integer data type is ``np.int32``.
418
445
 
419
- >>> import brainstate as bst
420
- >>> import numpy as np
421
- >>> with bst.environ.context(precision=32):
422
- ... a = np.zeros(1, dtype=bst.environ.ditype())
423
- >>> print(a.dtype)
424
- int32
446
+ >>> import brainstate as bst
447
+ >>> import numpy as np
448
+ >>> with bst.environ.context(precision=32):
449
+ ... a = np.zeros(1, dtype=bst.environ.ditype())
450
+ >>> print(a.dtype)
451
+ int32
425
452
 
426
- Returns
427
- -------
428
- int_dtype: DTypeLike
429
- The default integer data type.
430
- """
431
- return _get_int(get_precision())
453
+ Returns
454
+ -------
455
+ int_dtype: DTypeLike
456
+ The default integer data type.
457
+ """
458
+ return _get_int(get_precision())
432
459
 
433
460
 
434
461
  def dutype() -> DTypeLike:
435
- """
436
- Default unsigned integer data type.
462
+ """
463
+ Default unsigned integer data type.
437
464
 
438
- This function returns the default unsigned integer data type based on the current precision.
439
- If you want the data dtype is changed with the setting of the precision
440
- by ``brainstate.environ.set(precision)``, you can use this function to get the default
441
- unsigned integer data type, and create the data by using ``dtype=dutype()``.
465
+ This function returns the default unsigned integer data type based on the current precision.
466
+ If you want the data dtype is changed with the setting of the precision
467
+ by ``brainstate.environ.set(precision)``, you can use this function to get the default
468
+ unsigned integer data type, and create the data by using ``dtype=dutype()``.
442
469
 
443
- For example, if the precision is set to 32, the default unsigned integer data type is ``np.uint32``.
470
+ For example, if the precision is set to 32, the default unsigned integer data type is ``np.uint32``.
444
471
 
445
- >>> import brainstate as bst
446
- >>> import numpy as np
447
- >>> with bst.environ.context(precision=32):
448
- ... a = np.zeros(1, dtype=bst.environ.dutype())
449
- >>> print(a.dtype)
450
- uint32
472
+ >>> import brainstate as bst
473
+ >>> import numpy as np
474
+ >>> with bst.environ.context(precision=32):
475
+ ... a = np.zeros(1, dtype=bst.environ.dutype())
476
+ >>> print(a.dtype)
477
+ uint32
451
478
 
452
- Returns
453
- -------
454
- uint_dtype: DTypeLike
455
- The default unsigned integer data type.
456
- """
457
- return _get_uint(get_precision())
479
+ Returns
480
+ -------
481
+ uint_dtype: DTypeLike
482
+ The default unsigned integer data type.
483
+ """
484
+ return _get_uint(get_precision())
458
485
 
459
486
 
460
487
  def dctype() -> DTypeLike:
461
- """
462
- Default complex data type.
488
+ """
489
+ Default complex data type.
463
490
 
464
- This function returns the default complex data type based on the current precision.
465
- If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
466
- you can use this function to get the default complex data type, and create the data by using ``dtype=dctype()``.
491
+ This function returns the default complex data type based on the current precision.
492
+ If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
493
+ you can use this function to get the default complex data type, and create the data by using ``dtype=dctype()``.
467
494
 
468
- For example, if the precision is set to 32, the default complex data type is ``np.complex64``.
495
+ For example, if the precision is set to 32, the default complex data type is ``np.complex64``.
469
496
 
470
- >>> import brainstate as bst
471
- >>> import numpy as np
472
- >>> with bst.environ.context(precision=32):
473
- ... a = np.zeros(1, dtype=bst.environ.dctype())
474
- >>> print(a.dtype)
475
- complex64
497
+ >>> import brainstate as bst
498
+ >>> import numpy as np
499
+ >>> with bst.environ.context(precision=32):
500
+ ... a = np.zeros(1, dtype=bst.environ.dctype())
501
+ >>> print(a.dtype)
502
+ complex64
476
503
 
477
- Returns
478
- -------
479
- complex_dtype: DTypeLike
480
- The default complex data type.
481
- """
482
- return _get_complex(get_precision())
504
+ Returns
505
+ -------
506
+ complex_dtype: DTypeLike
507
+ The default complex data type.
508
+ """
509
+ return _get_complex(get_precision())
483
510
 
484
511
 
485
512
  def tolerance():
486
- if get_precision() == 64:
487
- return jnp.array(1e-12, dtype=np.float64)
488
- elif get_precision() == 32:
489
- return jnp.array(1e-5, dtype=np.float32)
490
- else:
491
- return jnp.array(1e-2, dtype=np.float16)
513
+ if get_precision() == 64:
514
+ return jnp.array(1e-12, dtype=np.float64)
515
+ elif get_precision() == 32:
516
+ return jnp.array(1e-5, dtype=np.float32)
517
+ else:
518
+ return jnp.array(1e-2, dtype=np.float16)
492
519
 
493
520
 
494
521
  def register_default_behavior(key: str, behavior: Callable, replace_if_exist: bool = False):
495
- """
496
- Register a default behavior for a specific global key parameter.
522
+ """
523
+ Register a default behavior for a specific global key parameter.
497
524
 
498
- For example, you can register a default behavior for the key 'dt' by::
525
+ For example, you can register a default behavior for the key 'dt' by::
499
526
 
500
- >>> import brainstate as bst
501
- >>> def dt_behavior(dt):
502
- ... print(f'Set the default dt to {dt}.')
503
- ...
504
- >>> bst.environ.register_default_behavior('dt', dt_behavior)
527
+ >>> import brainstate as bst
528
+ >>> def dt_behavior(dt):
529
+ ... print(f'Set the default dt to {dt}.')
530
+ ...
531
+ >>> bst.environ.register_default_behavior('dt', dt_behavior)
505
532
 
506
- Then, when you set the default dt by `brainstate.environ.set(dt=0.1)`, the behavior
507
- `dt_behavior` will be called with
508
- `dt_behavior(0.1)`.
533
+ Then, when you set the default dt by `brainstate.environ.set(dt=0.1)`, the behavior
534
+ `dt_behavior` will be called with
535
+ `dt_behavior(0.1)`.
509
536
 
510
- >>> bst.environ.set(dt=0.1)
511
- Set the default dt to 0.1.
512
- >>> with bst.environ.context(dt=0.2):
513
- ... pass
514
- Set the default dt to 0.2.
515
- Set the default dt to 0.1.
537
+ >>> bst.environ.set(dt=0.1)
538
+ Set the default dt to 0.1.
539
+ >>> with bst.environ.context(dt=0.2):
540
+ ... pass
541
+ Set the default dt to 0.2.
542
+ Set the default dt to 0.1.
516
543
 
517
544
 
518
- Args:
519
- key: str. The key to register.
520
- behavior: Callable. The behavior to register. It should be a callable.
521
- replace_if_exist: bool. Whether to replace the behavior if the key has been registered.
545
+ Args:
546
+ key: str. The key to register.
547
+ behavior: Callable. The behavior to register. It should be a callable.
548
+ replace_if_exist: bool. Whether to replace the behavior if the key has been registered.
522
549
 
523
- """
524
- assert isinstance(key, str), 'key must be a string.'
525
- assert callable(behavior), 'behavior must be a callable.'
526
- if not replace_if_exist:
527
- assert key not in _environment_functions, f'{key} has been registered.'
528
- _environment_functions[key] = behavior
550
+ """
551
+ assert isinstance(key, str), 'key must be a string.'
552
+ assert callable(behavior), 'behavior must be a callable.'
553
+ if not replace_if_exist:
554
+ assert key not in DFAULT.functions, f'{key} has been registered.'
555
+ DFAULT.functions[key] = behavior
529
556
 
530
557
 
531
- set(dt=0.1, precision=32, mode=Mode(), mem_scaling=IdMemScaling())
558
+ set(precision=32)