brainstate 0.1.0.post20250101__py2.py3-none-any.whl → 0.1.0.post20250104__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
brainstate/environ.py CHANGED
@@ -31,13 +31,12 @@ from jax import config, devices, numpy as jnp
31
31
  from jax.typing import DTypeLike
32
32
 
33
33
  from .mixin import Mode
34
- from .util import MemScaling
35
34
 
36
35
  __all__ = [
37
36
  # functions for environment settings
38
37
  'set', 'context', 'get', 'all', 'set_host_device_count', 'set_platform',
39
38
  # functions for getting default behaviors
40
- 'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
39
+ 'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_precision',
41
40
  # functions for default data types
42
41
  'dftype', 'ditype', 'dutype', 'dctype',
43
42
  # others
@@ -94,7 +93,7 @@ def context(**kwargs):
94
93
  'Please use set_host_device_count() or set() for the global setting.')
95
94
 
96
95
  if 'precision' in kwargs:
97
- last_precision = get_precision()
96
+ last_precision = _get_precision()
98
97
  _set_jax_precision(kwargs['precision'])
99
98
 
100
99
  try:
@@ -203,17 +202,6 @@ def get_mode() -> Mode:
203
202
  return get('mode')
204
203
 
205
204
 
206
- def get_mem_scaling() -> MemScaling:
207
- """Get the default computing membrane_scaling.
208
-
209
- Returns
210
- -------
211
- membrane_scaling: MemScaling
212
- The default computing membrane_scaling.
213
- """
214
- return get('mem_scaling')
215
-
216
-
217
205
  def get_platform() -> str:
218
206
  """Get the computing platform.
219
207
 
@@ -239,7 +227,7 @@ def get_host_device_count():
239
227
  return int(match.group(1)) if match else 1
240
228
 
241
229
 
242
- def get_precision() -> int:
230
+ def _get_precision() -> int | str:
243
231
  """
244
232
  Get the default precision.
245
233
 
@@ -251,11 +239,23 @@ def get_precision() -> int:
251
239
  return get('precision')
252
240
 
253
241
 
242
+ def get_precision() -> int:
243
+ """
244
+ Get the default precision.
245
+
246
+ Returns
247
+ -------
248
+ precision: int
249
+ The default precision.
250
+ """
251
+ precision = get('precision')
252
+ return 16 if precision == 'bf16' else precision
253
+
254
+
254
255
  def set(
255
256
  platform: str = None,
256
257
  host_device_count: int = None,
257
- mem_scaling: MemScaling = None,
258
- precision: int = None,
258
+ precision: int | str = None,
259
259
  mode: Mode = None,
260
260
  **kwargs
261
261
  ):
@@ -267,8 +267,7 @@ def set(
267
267
  Args:
268
268
  platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
269
269
  host_device_count: int. The number of host devices.
270
- mem_scaling: MemScaling. The membrane scaling.
271
- precision: int. The default precision.
270
+ precision: int, str. The default precision.
272
271
  mode: Mode. The computing mode.
273
272
  **kwargs: dict. Other environment settings.
274
273
  """
@@ -276,9 +275,6 @@ def set(
276
275
  set_platform(platform)
277
276
  if host_device_count is not None:
278
277
  set_host_device_count(host_device_count)
279
- if mem_scaling is not None:
280
- assert isinstance(mem_scaling, MemScaling), 'mem_scaling must be a MemScaling instance.'
281
- kwargs['mem_scaling'] = mem_scaling
282
278
  if precision is not None:
283
279
  _set_jax_precision(precision)
284
280
  kwargs['precision'] = precision
@@ -342,14 +338,14 @@ def set_platform(platform: str):
342
338
  DFAULT.functions['platform'](platform)
343
339
 
344
340
 
345
- def _set_jax_precision(precision: int):
341
+ def _set_jax_precision(precision: int | str):
346
342
  """
347
343
  Set the default precision.
348
344
 
349
345
  Args:
350
346
  precision: int. The default precision.
351
347
  """
352
- assert precision in [64, 32, 16, 8], f'Precision must be in [64, 32, 16, 8]. But got {precision}.'
348
+ assert precision in [64, 32, 16, 'bf16', 8], f'Precision must be in [64, 32, 16, "bf16", 8]. But got {precision}.'
353
349
  if precision == 64:
354
350
  config.update("jax_enable_x64", True)
355
351
  else:
@@ -362,7 +358,7 @@ def _get_uint(precision: int):
362
358
  return np.uint64
363
359
  elif precision == 32:
364
360
  return np.uint32
365
- elif precision == 16:
361
+ elif precision in [16, 'bf16']:
366
362
  return np.uint16
367
363
  elif precision == 8:
368
364
  return np.uint8
@@ -376,7 +372,7 @@ def _get_int(precision: int):
376
372
  return np.int64
377
373
  elif precision == 32:
378
374
  return np.int32
379
- elif precision == 16:
375
+ elif precision in [16, 'bf16']:
380
376
  return np.int16
381
377
  elif precision == 8:
382
378
  return np.int8
@@ -391,8 +387,9 @@ def _get_float(precision: int):
391
387
  elif precision == 32:
392
388
  return np.float32
393
389
  elif precision == 16:
390
+ return np.float16
391
+ elif precision == 'bf16':
394
392
  return jnp.bfloat16
395
- # return np.float16
396
393
  else:
397
394
  raise ValueError(f'Unsupported precision: {precision}')
398
395
 
@@ -403,8 +400,8 @@ def _get_complex(precision: int):
403
400
  return np.complex128
404
401
  elif precision == 32:
405
402
  return np.complex64
406
- elif precision == 16:
407
- return np.complex32
403
+ elif precision in [16, 'bf16']:
404
+ return np.complex64
408
405
  else:
409
406
  raise ValueError(f'Unsupported precision: {precision}')
410
407
 
@@ -430,7 +427,7 @@ def dftype() -> DTypeLike:
430
427
  float_dtype: DTypeLike
431
428
  The default floating data type.
432
429
  """
433
- return _get_float(get_precision())
430
+ return _get_float(_get_precision())
434
431
 
435
432
 
436
433
  def ditype() -> DTypeLike:
@@ -455,7 +452,7 @@ def ditype() -> DTypeLike:
455
452
  int_dtype: DTypeLike
456
453
  The default integer data type.
457
454
  """
458
- return _get_int(get_precision())
455
+ return _get_int(_get_precision())
459
456
 
460
457
 
461
458
  def dutype() -> DTypeLike:
@@ -481,7 +478,7 @@ def dutype() -> DTypeLike:
481
478
  uint_dtype: DTypeLike
482
479
  The default unsigned integer data type.
483
480
  """
484
- return _get_uint(get_precision())
481
+ return _get_uint(_get_precision())
485
482
 
486
483
 
487
484
  def dctype() -> DTypeLike:
@@ -506,7 +503,7 @@ def dctype() -> DTypeLike:
506
503
  complex_dtype: DTypeLike
507
504
  The default complex data type.
508
505
  """
509
- return _get_complex(get_precision())
506
+ return _get_complex(_get_precision())
510
507
 
511
508
 
512
509
  def tolerance():
brainstate/nn/_dyn_impl/_dynamics_synapse.py CHANGED
@@ -29,7 +29,7 @@ from brainstate.nn._exp_euler import exp_euler_step
29
29
  from brainstate.typing import ArrayLike, Size
30
30
 
31
31
  __all__ = [
32
- 'Synapse', 'Expon', 'STP', 'STD', 'AMPA', 'GABAa',
32
+ 'Synapse', 'Expon', 'DualExpon', 'Alpha', 'STP', 'STD', 'AMPA', 'GABAa',
33
33
  ]
34
34
 
35
35
 
brainstate-0.1.0.post20250104.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.1.0.post20250101
3
+ Version: 0.1.0.post20250104
4
4
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
5
  Home-page: https://github.com/chaobrain/brainstate
6
6
  Author: BrainState Developers
brainstate-0.1.0.post20250104.dist-info/RECORD CHANGED
@@ -2,7 +2,7 @@ brainstate/__init__.py,sha256=A-QKdOvSalsCMxgk80Iz6_xMiUin6con6JaONHfciSY,1526
2
2
  brainstate/_state.py,sha256=4aDpLyHGr1VlPXeLSfM3USQG5K4o7orF7IlaBdYrtfE,29098
3
3
  brainstate/_state_test.py,sha256=1boTp1w8DiCFLsPwNtlLrlIqGRpkasAmLid5bv2fgP4,2223
4
4
  brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
5
- brainstate/environ.py,sha256=G6r_rqfbofRbjFFalRu_DHaL7ruFTeLRXBQDXM6P-tQ,17477
5
+ brainstate/environ.py,sha256=wCVlSjavJ9OXc1STOucg-VfXeK9KE443a1SDhkK9lA8,17270
6
6
  brainstate/environ_test.py,sha256=jXX3nR1CO74aow5YqfqSd73isj9MWgHQxrwSsEjTDY8,1901
7
7
  brainstate/mixin.py,sha256=g7uVUwZphZWsNs9pb48ozG2cDGaj0hs0g3lq8tDk-Sg,11310
8
8
  brainstate/mixin_test.py,sha256=Oq_0fwC9vpXDN4t4dTBhWzLdFDNlcYsrcip14F1yECI,3079
@@ -81,7 +81,7 @@ brainstate/nn/metrics.py,sha256=iupHjSRTHYY-HmEPBC4tXWrZfF4zh1ek2NwSAA0gnwE,1473
81
81
  brainstate/nn/_dyn_impl/__init__.py,sha256=Oazar7h89dp1WA2Vx4Tj7gCBhxJKH4LAUEABkBEG7vU,1462
82
82
  brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=cTbIn41EPYG0h3ICzKBXxpgB6wwA2K8k5FAcf3Pa5N8,10927
83
83
  brainstate/nn/_dyn_impl/_dynamics_neuron_test.py,sha256=Tfzrzu7udGrLJGnqItiLWe5WT0dgduvYOgzGCnaPJQg,6317
84
- brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=41P3qJf8fi4J7dZdDnjChJF6lYJjFAOkgy9aE3FReY4,15247
84
+ brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=MsbPyaiDyjungyzuK2b3exRGaMpZgmsmmNHNLjgxQKw,15269
85
85
  brainstate/nn/_dyn_impl/_dynamics_synapse_test.py,sha256=t5i-HV0ii9sUNzWTEv04o26QVtQ-mCdMJcFq2MD755A,4981
86
86
  brainstate/nn/_dyn_impl/_inputs.py,sha256=pkcAVt_o5kQF_BGCTZZ-NUQpHgjlFHHPwtYC0fJkAA0,9099
87
87
  brainstate/nn/_dyn_impl/_projection_alignpost.py,sha256=PNC1Tzx_SF2DHAHeJCufXzO_Q4qLoBpWABI45B3GRuc,876
@@ -137,8 +137,8 @@ brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7
137
137
  brainstate/util/_struct.py,sha256=0exv0oOiSt1hmx20Y4J2-pCGtCTx13WcAlEYSBkyung,17640
138
138
  brainstate/util/_tracers.py,sha256=0r5T4nhxMzI79NtqroqitsdMT4YfpgV5RdYJLS5uJ0w,2285
139
139
  brainstate/util/_visualization.py,sha256=n4ZVz10z7VBqA0cKO6vyHwEMprWJgPeEqtITzDMai2Y,1519
140
- brainstate-0.1.0.post20250101.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
141
- brainstate-0.1.0.post20250101.dist-info/METADATA,sha256=QFsDpwFlj0QnoV3srIVjxTiWxJEmVROOlhLRM3u-B44,3533
142
- brainstate-0.1.0.post20250101.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
143
- brainstate-0.1.0.post20250101.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
144
- brainstate-0.1.0.post20250101.dist-info/RECORD,,
140
+ brainstate-0.1.0.post20250104.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
141
+ brainstate-0.1.0.post20250104.dist-info/METADATA,sha256=pGOGZBCF8q5La6-DIO1hdldslPnnt4bfjVnOVWgHRU4,3533
142
+ brainstate-0.1.0.post20250104.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
143
+ brainstate-0.1.0.post20250104.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
144
+ brainstate-0.1.0.post20250104.dist-info/RECORD,,