brainstate 0.0.2.post20240913__py2.py3-none-any.whl → 0.0.2.post20241010__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. brainstate/__init__.py +4 -2
  2. brainstate/_module.py +102 -67
  3. brainstate/_state.py +2 -2
  4. brainstate/_visualization.py +47 -0
  5. brainstate/environ.py +116 -9
  6. brainstate/environ_test.py +56 -0
  7. brainstate/functional/_activations.py +134 -56
  8. brainstate/functional/_activations_test.py +331 -0
  9. brainstate/functional/_normalization.py +21 -10
  10. brainstate/init/_generic.py +4 -2
  11. brainstate/mixin.py +1 -1
  12. brainstate/nn/__init__.py +7 -2
  13. brainstate/nn/_base.py +2 -2
  14. brainstate/nn/_connections.py +4 -4
  15. brainstate/nn/_dynamics.py +5 -5
  16. brainstate/nn/_elementwise.py +9 -9
  17. brainstate/nn/_embedding.py +3 -3
  18. brainstate/nn/_normalizations.py +3 -3
  19. brainstate/nn/_others.py +2 -2
  20. brainstate/nn/_poolings.py +6 -6
  21. brainstate/nn/_rate_rnns.py +1 -1
  22. brainstate/nn/_readout.py +1 -1
  23. brainstate/nn/_synouts.py +1 -1
  24. brainstate/nn/event/__init__.py +25 -0
  25. brainstate/nn/event/_misc.py +34 -0
  26. brainstate/nn/event/csr.py +312 -0
  27. brainstate/nn/event/csr_test.py +118 -0
  28. brainstate/nn/event/fixed_probability.py +276 -0
  29. brainstate/nn/event/fixed_probability_test.py +127 -0
  30. brainstate/nn/event/linear.py +220 -0
  31. brainstate/nn/event/linear_test.py +111 -0
  32. brainstate/nn/metrics.py +390 -0
  33. brainstate/optim/__init__.py +5 -1
  34. brainstate/optim/_optax_optimizer.py +208 -0
  35. brainstate/optim/_optax_optimizer_test.py +14 -0
  36. brainstate/random/__init__.py +24 -0
  37. brainstate/{random.py → random/_rand_funs.py} +7 -1596
  38. brainstate/random/_rand_seed.py +169 -0
  39. brainstate/random/_rand_state.py +1498 -0
  40. brainstate/{_random_for_unit.py → random/_random_for_unit.py} +1 -1
  41. brainstate/{random_test.py → random/random_test.py} +208 -191
  42. brainstate/transform/_jit.py +1 -1
  43. brainstate/transform/_jit_test.py +19 -0
  44. brainstate/transform/_make_jaxpr.py +1 -1
  45. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241010.dist-info}/METADATA +1 -1
  46. brainstate-0.0.2.post20241010.dist-info/RECORD +87 -0
  47. brainstate-0.0.2.post20240913.dist-info/RECORD +0 -70
  48. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241010.dist-info}/LICENSE +0 -0
  49. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241010.dist-info}/WHEEL +0 -0
  50. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241010.dist-info}/top_level.txt +0 -0
brainstate/__init__.py CHANGED
@@ -30,6 +30,8 @@ from . import surrogate
30
30
  from . import transform
31
31
  from . import typing
32
32
  from . import util
33
+ from ._visualization import *
34
+ from ._visualization import __all__ as _visualization_all
33
35
  from ._module import *
34
36
  from ._module import __all__ as _module_all
35
37
  from ._state import *
@@ -39,6 +41,6 @@ __all__ = (
39
41
  ['environ', 'share', 'nn', 'optim', 'random',
40
42
  'surrogate', 'functional', 'init',
41
43
  'mixin', 'transform', 'util', 'typing'] +
42
- _module_all + _state_all
44
+ _module_all + _state_all + _visualization_all
43
45
  )
44
- del _module_all, _state_all
46
+ del _module_all, _state_all, _visualization_all
brainstate/_module.py CHANGED
@@ -56,14 +56,13 @@ import jax
56
56
  import jax.numpy as jnp
57
57
  import numpy as np
58
58
 
59
- from . import environ
60
- from ._state import State, StateDictManager, visible_state_dict
61
- from ._utils import set_module_as
62
- from .mixin import Mixin, Mode, DelayedInit, JointTypes, Batching, UpdateReturn
63
- from .transform import jit_error_if
64
- from .typing import Size, ArrayLike, PyTree
65
- from .util import unique_name, DictManager, get_unique_name
66
-
59
+ from brainstate import environ
60
+ from brainstate._state import State, StateDictManager, visible_state_dict
61
+ from brainstate._utils import set_module_as
62
+ from brainstate.mixin import Mixin, Mode, DelayedInit, JointTypes, Batching, UpdateReturn
63
+ from brainstate.transform import jit_error_if
64
+ from brainstate.typing import Size, ArrayLike, PyTree
65
+ from brainstate.util import unique_name, DictManager, get_unique_name
67
66
 
68
67
  delay_identifier = '_*_delay_of_'
69
68
  _DELAY_ROTATE = 'rotation'
@@ -94,7 +93,7 @@ __all__ = [
94
93
  ]
95
94
 
96
95
 
97
- class Module(object):
96
+ class Object:
98
97
  """
99
98
  The Module class for the whole ecosystem.
100
99
 
@@ -118,43 +117,9 @@ class Module(object):
118
117
  # the excluded nodes
119
118
  _invisible_nodes: Tuple[str, ...] = ()
120
119
 
121
- # # the supported computing modes
122
- # supported_modes: Optional[Sequence[Mode]] = None
123
-
124
- def __init__(self, name: str = None, mode: Mode = None):
125
- super().__init__()
126
-
127
- # check whether the object has a unique name.
128
- self._name = unique_name(self=self, name=name)
129
-
130
- # mode setting
131
- self._mode = None
132
- self.mode = mode if mode is not None else environ.get('mode')
133
-
134
120
  def __repr__(self):
135
121
  return f'{self.__class__.__name__}'
136
122
 
137
- @property
138
- def name(self):
139
- """Name of the model."""
140
- return self._name
141
-
142
- @name.setter
143
- def name(self, name: str = None):
144
- raise AttributeError('The name of the model is read-only.')
145
-
146
- @property
147
- def mode(self):
148
- """Mode of the model, which is useful to control the multiple behaviors of the model."""
149
- return self._mode
150
-
151
- @mode.setter
152
- def mode(self, value):
153
- if not isinstance(value, Mode):
154
- raise ValueError(f'Must be instance of {Mode.__name__}, '
155
- f'but we got {type(value)}: {value}')
156
- self._mode = value
157
-
158
123
  def states(
159
124
  self,
160
125
  method: str = 'absolute',
@@ -238,30 +203,6 @@ class Module(object):
238
203
  nodes = nodes.unique()
239
204
  return nodes
240
205
 
241
- def update(self, *args, **kwargs):
242
- """
243
- The function to specify the updating rule.
244
- """
245
- raise NotImplementedError(f'Subclass of {self.__class__.__name__} must '
246
- f'implement "update" function.')
247
-
248
- def __call__(self, *args, **kwargs):
249
- return self.update(*args, **kwargs)
250
-
251
- def __rrshift__(self, other):
252
- """
253
- Support using right shift operator to call modules.
254
-
255
- Examples
256
- --------
257
-
258
- >>> import brainstate as bst
259
- >>> x = bst.random.rand((10, 10))
260
- >>> l = bst.nn.Activation(jax.numpy.tanh)
261
- >>> y = x >> l
262
- """
263
- return self.__call__(other)
264
-
265
206
  def init_state(self, *args, **kwargs):
266
207
  """
267
208
  State initialization function.
@@ -289,6 +230,23 @@ class Module(object):
289
230
  missing_keys = list(keys2 - keys1)
290
231
  return unexpected_keys, missing_keys
291
232
 
233
+ def __treescope_repr__(self, path, subtree_renderer):
234
+ import treescope # type: ignore[import-not-found,import-untyped]
235
+ children = {}
236
+ for name, value in vars(self).items():
237
+ if name.startswith('_'):
238
+ continue
239
+ children[name] = value
240
+ return treescope.repr_lib.render_object_constructor(
241
+ object_type=type(self),
242
+ attributes=children,
243
+ path=path,
244
+ subtree_renderer=subtree_renderer,
245
+ color=treescope.formatting_util.color_from_string(
246
+ type(self).__qualname__
247
+ )
248
+ )
249
+
292
250
 
293
251
  def _find_nodes(self, method: str = 'absolute', level=-1, include_self=True, _lid=0, _edges=None) -> DictManager:
294
252
  if _edges is None:
@@ -374,6 +332,83 @@ def _add_node_relative(self, k, v, _paths, gather, nodes):
374
332
  nodes.append((k, v))
375
333
 
376
334
 
335
+ class Module(Object):
336
+ """
337
+ The Module class for the whole ecosystem.
338
+
339
+ The ``Module`` is the base class for all the objects in the ecosystem. It
340
+ provides the basic functionalities for the objects, including:
341
+
342
+ - ``states()``: Collect all states in this node and the children nodes.
343
+ - ``nodes()``: Collect all children nodes.
344
+ - ``update()``: The function to specify the updating rule.
345
+ - ``init_state()``: State initialization function.
346
+ - ``save_state()``: Save states as a dictionary.
347
+ - ``load_state()``: Load states from the external objects.
348
+
349
+ """
350
+
351
+ __module__ = 'brainstate'
352
+
353
+ def __init__(self, name: str = None, mode: Mode = None):
354
+ super().__init__()
355
+
356
+ # check whether the object has a unique name.
357
+ self._name = unique_name(self=self, name=name)
358
+
359
+ # mode setting
360
+ self._mode = None
361
+ self.mode = mode if mode is not None else environ.get('mode')
362
+
363
+ def __repr__(self):
364
+ return f'{self.__class__.__name__}'
365
+
366
+ @property
367
+ def name(self):
368
+ """Name of the model."""
369
+ return self._name
370
+
371
+ @name.setter
372
+ def name(self, name: str = None):
373
+ raise AttributeError('The name of the model is read-only.')
374
+
375
+ @property
376
+ def mode(self):
377
+ """Mode of the model, which is useful to control the multiple behaviors of the model."""
378
+ return self._mode
379
+
380
+ @mode.setter
381
+ def mode(self, value):
382
+ if not isinstance(value, Mode):
383
+ raise ValueError(f'Must be instance of {Mode.__name__}, '
384
+ f'but we got {type(value)}: {value}')
385
+ self._mode = value
386
+
387
+ def update(self, *args, **kwargs):
388
+ """
389
+ The function to specify the updating rule.
390
+ """
391
+ raise NotImplementedError(f'Subclass of {self.__class__.__name__} must '
392
+ f'implement "update" function.')
393
+
394
+ def __call__(self, *args, **kwargs):
395
+ return self.update(*args, **kwargs)
396
+
397
+ def __rrshift__(self, other):
398
+ """
399
+ Support using right shift operator to call modules.
400
+
401
+ Examples
402
+ --------
403
+
404
+ >>> import brainstate as bst
405
+ >>> x = bst.random.rand((10, 10))
406
+ >>> l = bst.nn.Dropout(0.5)
407
+ >>> y = x >> l
408
+ """
409
+ return self.__call__(other)
410
+
411
+
377
412
  class Projection(Module):
378
413
  """
379
414
  Base class to model synaptic projections.
brainstate/_state.py CHANGED
@@ -22,8 +22,8 @@ import numpy as np
22
22
  from jax.api_util import shaped_abstractify
23
23
  from jax.extend import source_info_util
24
24
 
25
- from .typing import ArrayLike, PyTree
26
- from .util import DictManager
25
+ from brainstate.typing import ArrayLike, PyTree
26
+ from brainstate.util import DictManager
27
27
 
28
28
  __all__ = [
29
29
  'State', 'ShortTermState', 'LongTermState', 'ParamState',
@@ -0,0 +1,47 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ __all__ = [
18
+ 'display',
19
+ ]
20
+
21
+
22
+ import importlib.util
23
+
24
+ treescope_installed = importlib.util.find_spec('treescope') is not None
25
+ try:
26
+ from IPython import get_ipython
27
+
28
+ in_ipython = get_ipython() is not None
29
+ except ImportError:
30
+ in_ipython = False
31
+
32
+
33
+ def display(*args):
34
+ """Display the given objects using the Treescope pretty-printer.
35
+
36
+ If treescope is not installed or the code is not running in IPython,
37
+ ``display`` will print the objects instead.
38
+ """
39
+ if not treescope_installed or not in_ipython:
40
+ for x in args:
41
+ print(x)
42
+ return
43
+
44
+ import treescope # type: ignore[import-not-found,import-untyped]
45
+
46
+ for x in args:
47
+ treescope.display(x, ignore_exceptions=True, autovisualize=True)
brainstate/environ.py CHANGED
@@ -16,12 +16,14 @@ from .mixin import Mode
16
16
  from .util import MemScaling, IdMemScaling
17
17
 
18
18
  __all__ = [
19
- 'set', 'context', 'get', 'all',
20
- 'set_host_device_count', 'set_platform',
21
- 'get_host_device_count', 'get_platform',
22
- 'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
23
- 'tolerance', 'register_default_behavior',
19
+ # functions for environment settings
20
+ 'set', 'context', 'get', 'all', 'set_host_device_count', 'set_platform',
21
+ # functions for getting default behaviors
22
+ 'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
23
+ # functions for default data types
24
24
  'dftype', 'ditype', 'dutype', 'dctype',
25
+ # others
26
+ 'tolerance', 'register_default_behavior',
25
27
  ]
26
28
 
27
29
  # Default, there are several shared arguments in the global context.
@@ -57,8 +59,9 @@ def context(**kwargs):
57
59
 
58
60
  """
59
61
  if 'platform' in kwargs:
60
- raise ValueError('Cannot set platform in environment context. '
61
- 'Please use set_platform() or set() for the global setting.')
62
+ raise ValueError('\n'
63
+ 'Cannot set platform in "context" environment. \n'
64
+ 'You should set platform in the global environment by "set_platform()" or "set()".')
62
65
  if 'host_device_count' in kwargs:
63
66
  raise ValueError('Cannot set host_device_count in environment context. '
64
67
  'Please use set_host_device_count() or set() for the global setting.')
@@ -135,6 +138,11 @@ def get(key: str, default: Any = _NOT_PROVIDE, desc: str = None):
135
138
  def all() -> dict:
136
139
  """
137
140
  Get all the current default computation environment.
141
+
142
+ Returns
143
+ -------
144
+ r: dict
145
+ The current default computation environment.
138
146
  """
139
147
  r = dict()
140
148
  for k, v in _environment_contexts.items():
@@ -227,6 +235,8 @@ def set(
227
235
  """
228
236
  Set the global default computation environment.
229
237
 
238
+
239
+
230
240
  Args:
231
241
  platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
232
242
  host_device_count: int. The number of host devices.
@@ -290,6 +300,12 @@ def set_platform(platform: str):
290
300
  """
291
301
  Changes platform to CPU, GPU, or TPU. This utility only takes
292
302
  effect at the beginning of your program.
303
+
304
+ Args:
305
+ platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
306
+
307
+ Raises:
308
+ ValueError: If the platform is not in ['cpu', 'gpu', 'tpu'].
293
309
  """
294
310
  assert platform in ['cpu', 'gpu', 'tpu']
295
311
  config.update("jax_platform_name", platform)
@@ -369,6 +385,23 @@ def _get_complex(precision: int):
369
385
  def dftype() -> DTypeLike:
370
386
  """
371
387
  Default floating data type.
388
+
389
+ This function returns the default floating data type based on the current precision.
390
+ If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
391
+ you can use this function to get the default floating data type, and create the data by using ``dtype=dftype()``.
392
+
393
+ For example, if the precision is set to 32, the default floating data type is ``np.float32``.
394
+
395
+ >>> import brainstate as bst
396
+ >>> import numpy as np
397
+ >>> with bst.environ.context(precision=32):
398
+ ... a = np.zeros(1, dtype=bst.environ.dftype())
399
+ >>> print(a.dtype)
400
+
401
+ Returns
402
+ -------
403
+ float_dtype: DTypeLike
404
+ The default floating data type.
372
405
  """
373
406
  return _get_float(get_precision())
374
407
 
@@ -376,6 +409,24 @@ def dftype() -> DTypeLike:
376
409
  def ditype() -> DTypeLike:
377
410
  """
378
411
  Default integer data type.
412
+
413
+ This function returns the default integer data type based on the current precision.
414
+ If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
415
+ you can use this function to get the default integer data type, and create the data by using ``dtype=ditype()``.
416
+
417
+ For example, if the precision is set to 32, the default integer data type is ``np.int32``.
418
+
419
+ >>> import brainstate as bst
420
+ >>> import numpy as np
421
+ >>> with bst.environ.context(precision=32):
422
+ ... a = np.zeros(1, dtype=bst.environ.ditype())
423
+ >>> print(a.dtype)
424
+ int32
425
+
426
+ Returns
427
+ -------
428
+ int_dtype: DTypeLike
429
+ The default integer data type.
379
430
  """
380
431
  return _get_int(get_precision())
381
432
 
@@ -383,6 +434,25 @@ def ditype() -> DTypeLike:
383
434
  def dutype() -> DTypeLike:
384
435
  """
385
436
  Default unsigned integer data type.
437
+
438
+ This function returns the default unsigned integer data type based on the current precision.
439
+ If you want the data dtype is changed with the setting of the precision
440
+ by ``brainstate.environ.set(precision)``, you can use this function to get the default
441
+ unsigned integer data type, and create the data by using ``dtype=dutype()``.
442
+
443
+ For example, if the precision is set to 32, the default unsigned integer data type is ``np.uint32``.
444
+
445
+ >>> import brainstate as bst
446
+ >>> import numpy as np
447
+ >>> with bst.environ.context(precision=32):
448
+ ... a = np.zeros(1, dtype=bst.environ.dutype())
449
+ >>> print(a.dtype)
450
+ uint32
451
+
452
+ Returns
453
+ -------
454
+ uint_dtype: DTypeLike
455
+ The default unsigned integer data type.
386
456
  """
387
457
  return _get_uint(get_precision())
388
458
 
@@ -390,6 +460,24 @@ def dutype() -> DTypeLike:
390
460
  def dctype() -> DTypeLike:
391
461
  """
392
462
  Default complex data type.
463
+
464
+ This function returns the default complex data type based on the current precision.
465
+ If you want the data dtype is changed with the setting of the precision by ``brainstate.environ.set(precision)``,
466
+ you can use this function to get the default complex data type, and create the data by using ``dtype=dctype()``.
467
+
468
+ For example, if the precision is set to 32, the default complex data type is ``np.complex64``.
469
+
470
+ >>> import brainstate as bst
471
+ >>> import numpy as np
472
+ >>> with bst.environ.context(precision=32):
473
+ ... a = np.zeros(1, dtype=bst.environ.dctype())
474
+ >>> print(a.dtype)
475
+ complex64
476
+
477
+ Returns
478
+ -------
479
+ complex_dtype: DTypeLike
480
+ The default complex data type.
393
481
  """
394
482
  return _get_complex(get_precision())
395
483
 
@@ -405,7 +493,27 @@ def tolerance():
405
493
 
406
494
  def register_default_behavior(key: str, behavior: Callable, replace_if_exist: bool = False):
407
495
  """
408
- Register a default behavior for a specific key.
496
+ Register a default behavior for a specific global key parameter.
497
+
498
+ For example, you can register a default behavior for the key 'dt' by::
499
+
500
+ >>> import brainstate as bst
501
+ >>> def dt_behavior(dt):
502
+ ... print(f'Set the default dt to {dt}.')
503
+ ...
504
+ >>> bst.environ.register_default_behavior('dt', dt_behavior)
505
+
506
+ Then, when you set the default dt by `brainstate.environ.set(dt=0.1)`, the behavior
507
+ `dt_behavior` will be called with
508
+ `dt_behavior(0.1)`.
509
+
510
+ >>> bst.environ.set(dt=0.1)
511
+ Set the default dt to 0.1.
512
+ >>> with bst.environ.context(dt=0.2):
513
+ ... pass
514
+ Set the default dt to 0.2.
515
+ Set the default dt to 0.1.
516
+
409
517
 
410
518
  Args:
411
519
  key: str. The key to register.
@@ -421,4 +529,3 @@ def register_default_behavior(key: str, behavior: Callable, replace_if_exist: bo
421
529
 
422
530
 
423
531
  set(dt=0.1, precision=32, mode=Mode(), mem_scaling=IdMemScaling())
424
-
@@ -0,0 +1,56 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ import unittest
18
+
19
+ import jax.numpy as jnp
20
+
21
+ import brainstate as bst
22
+
23
+
24
+ class TestEnviron(unittest.TestCase):
25
+ def test_precision(self):
26
+ with bst.environ.context(precision=64):
27
+ a = bst.random.randn(1)
28
+ self.assertEqual(a.dtype, jnp.float64)
29
+
30
+ with bst.environ.context(precision=32):
31
+ a = bst.random.randn(1)
32
+ self.assertEqual(a.dtype, jnp.float32)
33
+
34
+ with bst.environ.context(precision=16):
35
+ a = bst.random.randn(1)
36
+ self.assertEqual(a.dtype, jnp.bfloat16)
37
+
38
+ def test_platform(self):
39
+ with self.assertRaises(ValueError):
40
+ with bst.environ.context(platform='cpu'):
41
+ a = bst.random.randn(1)
42
+ self.assertEqual(a.device(), 'cpu')
43
+
44
+ def test_register_default_behavior(self):
45
+ dt_ = 0.1
46
+
47
+ def dt_behavior(dt):
48
+ nonlocal dt_
49
+ dt_ = dt
50
+ print(f'dt: {dt}')
51
+
52
+ bst.environ.register_default_behavior('dt', dt_behavior)
53
+
54
+ with bst.environ.context(dt=0.2):
55
+ self.assertEqual(dt_, 0.2)
56
+ self.assertEqual(dt_, 0.1)