brainstate 0.0.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. brainstate/__init__.py +45 -0
  2. brainstate/_module.py +1466 -0
  3. brainstate/_module_test.py +133 -0
  4. brainstate/_state.py +378 -0
  5. brainstate/_state_test.py +41 -0
  6. brainstate/_utils.py +21 -0
  7. brainstate/environ.py +375 -0
  8. brainstate/functional/__init__.py +25 -0
  9. brainstate/functional/_activations.py +754 -0
  10. brainstate/functional/_normalization.py +69 -0
  11. brainstate/functional/_spikes.py +90 -0
  12. brainstate/init/__init__.py +26 -0
  13. brainstate/init/_base.py +36 -0
  14. brainstate/init/_generic.py +175 -0
  15. brainstate/init/_random_inits.py +489 -0
  16. brainstate/init/_regular_inits.py +109 -0
  17. brainstate/math/__init__.py +21 -0
  18. brainstate/math/_einops.py +787 -0
  19. brainstate/math/_einops_parsing.py +169 -0
  20. brainstate/math/_einops_parsing_test.py +126 -0
  21. brainstate/math/_einops_test.py +346 -0
  22. brainstate/math/_misc.py +298 -0
  23. brainstate/math/_misc_test.py +58 -0
  24. brainstate/mixin.py +373 -0
  25. brainstate/mixin_test.py +73 -0
  26. brainstate/nn/__init__.py +68 -0
  27. brainstate/nn/_base.py +248 -0
  28. brainstate/nn/_connections.py +686 -0
  29. brainstate/nn/_dynamics.py +406 -0
  30. brainstate/nn/_elementwise.py +1437 -0
  31. brainstate/nn/_misc.py +132 -0
  32. brainstate/nn/_normalizations.py +389 -0
  33. brainstate/nn/_others.py +100 -0
  34. brainstate/nn/_poolings.py +1228 -0
  35. brainstate/nn/_poolings_test.py +231 -0
  36. brainstate/nn/_projection/__init__.py +32 -0
  37. brainstate/nn/_projection/_align_post.py +528 -0
  38. brainstate/nn/_projection/_align_pre.py +599 -0
  39. brainstate/nn/_projection/_delta.py +241 -0
  40. brainstate/nn/_projection/_utils.py +17 -0
  41. brainstate/nn/_projection/_vanilla.py +101 -0
  42. brainstate/nn/_rate_rnns.py +393 -0
  43. brainstate/nn/_readout.py +130 -0
  44. brainstate/nn/_synouts.py +166 -0
  45. brainstate/nn/functional/__init__.py +25 -0
  46. brainstate/nn/functional/_activations.py +754 -0
  47. brainstate/nn/functional/_normalization.py +69 -0
  48. brainstate/nn/functional/_spikes.py +90 -0
  49. brainstate/nn/init/__init__.py +26 -0
  50. brainstate/nn/init/_base.py +36 -0
  51. brainstate/nn/init/_generic.py +175 -0
  52. brainstate/nn/init/_random_inits.py +489 -0
  53. brainstate/nn/init/_regular_inits.py +109 -0
  54. brainstate/nn/surrogate.py +1740 -0
  55. brainstate/optim/__init__.py +23 -0
  56. brainstate/optim/_lr_scheduler.py +486 -0
  57. brainstate/optim/_lr_scheduler_test.py +36 -0
  58. brainstate/optim/_sgd_optimizer.py +1148 -0
  59. brainstate/random.py +5148 -0
  60. brainstate/random_test.py +576 -0
  61. brainstate/surrogate.py +1740 -0
  62. brainstate/transform/__init__.py +36 -0
  63. brainstate/transform/_autograd.py +585 -0
  64. brainstate/transform/_autograd_test.py +1183 -0
  65. brainstate/transform/_control.py +665 -0
  66. brainstate/transform/_controls_test.py +220 -0
  67. brainstate/transform/_jit.py +239 -0
  68. brainstate/transform/_jit_error.py +158 -0
  69. brainstate/transform/_jit_test.py +102 -0
  70. brainstate/transform/_make_jaxpr.py +573 -0
  71. brainstate/transform/_make_jaxpr_test.py +133 -0
  72. brainstate/transform/_progress_bar.py +113 -0
  73. brainstate/typing.py +69 -0
  74. brainstate/util.py +747 -0
  75. brainstate-0.0.1.dist-info/LICENSE +202 -0
  76. brainstate-0.0.1.dist-info/METADATA +101 -0
  77. brainstate-0.0.1.dist-info/RECORD +79 -0
  78. brainstate-0.0.1.dist-info/WHEEL +6 -0
  79. brainstate-0.0.1.dist-info/top_level.txt +1 -0
brainstate/environ.py ADDED
@@ -0,0 +1,375 @@
+# -*- coding: utf-8 -*-
+
+
+import contextlib
+import functools
+import os
+import re
+from collections import defaultdict
+from typing import Any
+
+import numpy as np
+from jax import config, devices, numpy as jnp
+from jax._src.typing import DTypeLike
+
+from .mixin import Mode
+from .util import MemScaling, IdMemScaling
+
+__all__ = [
+    'set', 'context', 'get', 'all',
+    'set_host_device_count', 'set_platform',
+    'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
+    'tolerance',
+    'dftype', 'ditype', 'dutype', 'dctype',
+]
+
+
+# By default, there are several shared arguments in the global context.
+I = 'i'  # the index of the current computation.
+T = 't'  # the current time of the current computation.
+JIT_ERROR_CHECK = 'jit_error_check'  # whether to check errors in jit-compiled computations.
+FIT = 'fit'  # whether to fit the model.
+
+_NOT_PROVIDE = object()
+_environment_defaults = dict()
+_environment_contexts = defaultdict(list)
+
+
+@contextlib.contextmanager
+def context(**kwargs):
+    r"""
+    Context manager that sets a computing environment for brain dynamics computation.
+
+    In brainstate, there are several basic computation settings when constructing models,
+    including ``mode`` for controlling model computing behavior, ``dt`` for numerical
+    integration, ``int_`` for integer precision, and ``float_`` for floating-point
+    precision. :py:func:`~.context` provides a temporary environment for model
+    construction and computation, in which models are constructed with the given
+    ``mode``, ``dt``, ``int_``, etc., settings.
+
+    For instance::
+
+        >>> import brainstate as bst
+        >>> with bst.environ.context(dt=0.1) as env:
+        ...     dt = bst.environ.get('dt')
+        ...     print(env)
+
+    """
+    if 'platform' in kwargs:
+        raise ValueError('Cannot set platform in environment context. '
+                         'Please use set_platform() or set() for the global setting.')
+    if 'host_device_count' in kwargs:
+        raise ValueError('Cannot set host_device_count in environment context. '
+                         'Please use set_host_device_count() or set() for the global setting.')
+    if 'precision' in kwargs:
+        last_precision = get_precision()
+        _set_jax_precision(kwargs['precision'])
+
+    try:
+        # update the current environment
+        for k, v in kwargs.items():
+            _environment_contexts[k].append(v)
+        # yield all current environment information
+        yield all()
+    finally:
+        for k, v in kwargs.items():
+            _environment_contexts[k].pop()
+        if 'precision' in kwargs:
+            _set_jax_precision(last_precision)
+
+
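A minimal usage sketch (not part of the diff; it assumes only the `brainstate.environ` API added above): because `context` pushes each value onto a per-key stack and pops it on exit, nested settings shadow outer ones and are restored afterwards:

    >>> import brainstate as bst
    >>> with bst.environ.context(dt=0.1):
    ...     with bst.environ.context(dt=0.05):
    ...         print(bst.environ.get('dt'))  # 0.05 -- the innermost value wins
    ...     print(bst.environ.get('dt'))      # 0.1 -- restored when the inner block exits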
+def get(key: str, default: Any = _NOT_PROVIDE, desc: str = None):
+    """
+    Get one of the default computation environment settings.
+
+    Parameters
+    ----------
+    key: str
+      The environment setting key.
+    default: Any
+      The fallback value returned when ``key`` is not found.
+    desc: str
+      An optional description appended to the error message.
+
+    Returns
+    -------
+    item: Any
+      The value of the requested environment setting.
+    """
+    if key == 'platform':
+        return get_platform()
+
+    if key == 'host_device_count':
+        return get_host_device_count()
+
+    if key in _environment_contexts:
+        if len(_environment_contexts[key]) > 0:
+            return _environment_contexts[key][-1]
+    if key in _environment_defaults:
+        return _environment_defaults[key]
+
+    if default is _NOT_PROVIDE:
+        msg = (f"'{key}' is not found in the context. \n"
+               f"You can set it by `brainstate.environ.context({key}=value)` "
+               f"locally or `brainstate.environ.set({key}=value)` globally.")
+        if desc is not None:
+            msg += f"\nDescription: {desc}"
+        raise KeyError(msg)
+    return default
+
+
+def all() -> dict:
+    """
+    Get all the current default computation environment settings.
+    """
+    r = dict()
+    for k, v in _environment_contexts.items():
+        if v:
+            r[k] = v[-1]
+    for k, v in _environment_defaults.items():
+        if k not in r:
+            r[k] = v
+    return r
+
+
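A sketch of the lookup order implemented above (context stack first, then global defaults, then the caller's `default`); `my_key` is a hypothetical key, not a real environment setting:

    >>> import brainstate as bst
    >>> bst.environ.get('dt')                   # 0.1, from the module-level set() call at the bottom of this file
    >>> bst.environ.get('my_key', default=3)    # 3: unknown key, so the default is returned
    >>> bst.environ.get('my_key', desc='demo')  # KeyError, with 'demo' appended to the message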
+def get_dt():
+    """Get the numerical integration time step.
+
+    Returns
+    -------
+    dt : float
+      The numerical integration time step.
+    """
+    return get('dt')
+
+
+def get_mode() -> Mode:
+    """Get the default computing mode.
+
+    Returns
+    -------
+    mode: Mode
+      The default computing mode.
+    """
+    return get('mode')
+
+
+def get_mem_scaling() -> MemScaling:
+    """Get the default membrane scaling.
+
+    Returns
+    -------
+    mem_scaling: MemScaling
+      The default membrane scaling.
+    """
+    return get('mem_scaling')
+
+
+def get_platform() -> str:
+    """Get the computing platform.
+
+    Returns
+    -------
+    platform: str
+      Either 'cpu', 'gpu' or 'tpu'.
+    """
+    return devices()[0].platform
+
+
+def get_host_device_count():
+    """
+    Get the number of host devices.
+
+    Returns
+    -------
+    n: int
+      The number of host devices.
+    """
+    xla_flags = os.getenv("XLA_FLAGS", "")
+    match = re.search(r"--xla_force_host_platform_device_count=(\d+)", xla_flags)
+    return int(match.group(1)) if match else 1
+
+
+def get_precision() -> int:
+    """
+    Get the default precision.
+
+    Returns
+    -------
+    precision: int
+      The default precision.
+    """
+    return get('precision')
+
+
+def set(
+    platform: str = None,
+    host_device_count: int = None,
+    mem_scaling: MemScaling = None,
+    precision: int = None,
+    mode: Mode = None,
+    **kwargs
+):
+    """
+    Set the global default computation environment.
+
+    Args:
+      platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
+      host_device_count: int. The number of host devices.
+      mem_scaling: MemScaling. The membrane scaling.
+      precision: int. The default precision.
+      mode: Mode. The computing mode.
+      **kwargs: dict. Other environment settings.
+    """
+    if platform is not None:
+        set_platform(platform)
+    if host_device_count is not None:
+        set_host_device_count(host_device_count)
+    if mem_scaling is not None:
+        assert isinstance(mem_scaling, MemScaling), 'mem_scaling must be a MemScaling instance.'
+        kwargs['mem_scaling'] = mem_scaling
+    if precision is not None:
+        _set_jax_precision(precision)
+        kwargs['precision'] = precision
+    if mode is not None:
+        assert isinstance(mode, Mode), 'mode must be a Mode instance.'
+        kwargs['mode'] = mode
+    _environment_defaults.update(kwargs)
+
+
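A sketch of setting global defaults (any keyword is accepted; `platform`, `host_device_count`, `precision`, `mem_scaling`, and `mode` get the special handling shown above):

    >>> import brainstate as bst
    >>> bst.environ.set(dt=0.05, precision=64)  # updates the global defaults and enables jax_enable_x64
    >>> bst.environ.get('dt')                   # 0.05
    >>> bst.environ.get_precision()             # 64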
+def set_host_device_count(n):
+    """
+    By default, XLA considers all CPU cores as one device. This utility tells XLA
+    that there are `n` host (CPU) devices available to use. As a consequence, this
+    allows parallel mapping with JAX :func:`jax.pmap` to work on the CPU platform.
+
+    .. note:: This utility only takes effect at the beginning of your program.
+       Under the hood, this sets the environment variable
+       `XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
+       `[num_devices]` is the desired number of CPU devices `n`.
+
+    .. warning:: Our understanding of the side effects of using the
+       `xla_force_host_platform_device_count` flag in XLA is incomplete. If you
+       observe some strange phenomenon when using this utility, please let us
+       know through our issue or forum page. More information is available in this
+       `JAX issue <https://github.com/google/jax/issues/1408>`_.
+
+    :param int n: number of devices to use.
+    """
+    xla_flags = os.getenv("XLA_FLAGS", "")
+    xla_flags = re.sub(r"--xla_force_host_platform_device_count=\S+", "", xla_flags).split()
+    os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(n)] + xla_flags)
+
+
+def set_platform(platform: str):
+    """
+    Changes platform to CPU, GPU, or TPU. This utility only takes
+    effect at the beginning of your program.
+    """
+    assert platform in ['cpu', 'gpu', 'tpu'], f"platform must be 'cpu', 'gpu' or 'tpu', but got {platform}."
+    config.update("jax_platform_name", platform)
+
+
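Both utilities only work if they run before JAX initializes its backends; a sketch under that assumption:

    >>> import brainstate as bst  # at the very top of the program
    >>> bst.environ.set(platform='cpu', host_device_count=8)
    >>> import jax
    >>> jax.local_device_count()  # 8, if XLA picked up the flag before backend initialization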
+def _set_jax_precision(precision: int):
+    """
+    Set the default precision.
+
+    Args:
+      precision: int. The default precision.
+    """
+    assert precision in [64, 32, 16, 8], f'Precision must be in [64, 32, 16, 8]. But got {precision}.'
+    if precision == 64:
+        config.update("jax_enable_x64", True)
+    else:
+        config.update("jax_enable_x64", False)
+
+
+@functools.lru_cache()
+def _get_uint(precision: int):
+    if precision == 64:
+        return np.uint64
+    elif precision == 32:
+        return np.uint32
+    elif precision == 16:
+        return np.uint16
+    elif precision == 8:
+        return np.uint8
+    else:
+        raise ValueError(f'Unsupported precision: {precision}')
+
+
+@functools.lru_cache()
+def _get_int(precision: int):
+    if precision == 64:
+        return np.int64
+    elif precision == 32:
+        return np.int32
+    elif precision == 16:
+        return np.int16
+    elif precision == 8:
+        return np.int8
+    else:
+        raise ValueError(f'Unsupported precision: {precision}')
+
+
+@functools.lru_cache()
+def _get_float(precision: int):
+    if precision == 64:
+        return np.float64
+    elif precision == 32:
+        return np.float32
+    elif precision == 16:
+        return jnp.bfloat16  # could also be np.float16
+    else:
+        raise ValueError(f'Unsupported precision: {precision}')
+
+
+@functools.lru_cache()
+def _get_complex(precision: int):
+    if precision == 64:
+        return np.complex128
+    elif precision == 32:
+        return np.complex64
+    elif precision == 16:
+        # NumPy has no 32-bit complex dtype (the original ``np.complex32`` does not
+        # exist and would raise AttributeError), so fall back to the smallest
+        # available complex type.
+        return np.complex64
+    else:
+        raise ValueError(f'Unsupported precision: {precision}')
+
+
+def dftype() -> DTypeLike:
+    """
+    Default floating-point data type.
+    """
+    return _get_float(get_precision())
+
+
+def ditype() -> DTypeLike:
+    """
+    Default integer data type.
+    """
+    return _get_int(get_precision())
+
+
+def dutype() -> DTypeLike:
+    """
+    Default unsigned integer data type.
+    """
+    return _get_uint(get_precision())
+
+
+def dctype() -> DTypeLike:
+    """
+    Default complex data type.
+    """
+    return _get_complex(get_precision())
+
+
+def tolerance():
+    if get_precision() == 64:
+        return jnp.array(1e-12, dtype=np.float64)
+    elif get_precision() == 32:
+        return jnp.array(1e-5, dtype=np.float32)
+    else:
+        return jnp.array(1e-2, dtype=np.float16)
+
+
+set(dt=0.1, precision=32, mode=Mode(), mem_scaling=IdMemScaling())
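A sketch of how the `d*type` helpers track the active precision (values follow the `_get_*` tables above):

    >>> import brainstate as bst
    >>> bst.environ.dftype(), bst.environ.ditype()  # (np.float32, np.int32) under the default 32-bit precision
    >>> with bst.environ.context(precision=16):
    ...     bst.environ.dftype()                    # jnp.bfloat16, not np.float16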
@@ -0,0 +1,25 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+from ._activations import *
+from ._activations import __all__ as __activations_all__
+from ._normalization import *
+from ._normalization import __all__ as __normalization_all__
+from ._spikes import *
+from ._spikes import __all__ as __spikes_all__
+
+__all__ = __spikes_all__ + __normalization_all__ + __activations_all__
+
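The hunk above arrives without its filename header in this diff view; by its +25 line count it is one of the two `functional/__init__.py` files in the list. It uses a common re-export idiom: each submodule declares `__all__`, and the package `__init__` star-imports the submodules and concatenates their `__all__` lists, so `from <package> import *` exposes exactly the union of the submodules' public names. A minimal sketch of the idiom with hypothetical names:

    # _spikes.py
    __all__ = ['spike_fn']

    # __init__.py
    from ._spikes import *
    from ._spikes import __all__ as __spikes_all__
    __all__ = list(__spikes_all__)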