brainstate 0.1.0.post20250102__py2.py3-none-any.whl → 0.1.0.post20250104__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/environ.py +30 -33
- {brainstate-0.1.0.post20250102.dist-info → brainstate-0.1.0.post20250104.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20250102.dist-info → brainstate-0.1.0.post20250104.dist-info}/RECORD +6 -6
- {brainstate-0.1.0.post20250102.dist-info → brainstate-0.1.0.post20250104.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250102.dist-info → brainstate-0.1.0.post20250104.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250102.dist-info → brainstate-0.1.0.post20250104.dist-info}/top_level.txt +0 -0
brainstate/environ.py
CHANGED
@@ -31,13 +31,12 @@ from jax import config, devices, numpy as jnp
|
|
31
31
|
from jax.typing import DTypeLike
|
32
32
|
|
33
33
|
from .mixin import Mode
|
34
|
-
from .util import MemScaling
|
35
34
|
|
36
35
|
__all__ = [
|
37
36
|
# functions for environment settings
|
38
37
|
'set', 'context', 'get', 'all', 'set_host_device_count', 'set_platform',
|
39
38
|
# functions for getting default behaviors
|
40
|
-
'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', '
|
39
|
+
'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_precision',
|
41
40
|
# functions for default data types
|
42
41
|
'dftype', 'ditype', 'dutype', 'dctype',
|
43
42
|
# others
|
@@ -94,7 +93,7 @@ def context(**kwargs):
|
|
94
93
|
'Please use set_host_device_count() or set() for the global setting.')
|
95
94
|
|
96
95
|
if 'precision' in kwargs:
|
97
|
-
last_precision =
|
96
|
+
last_precision = _get_precision()
|
98
97
|
_set_jax_precision(kwargs['precision'])
|
99
98
|
|
100
99
|
try:
|
@@ -203,17 +202,6 @@ def get_mode() -> Mode:
|
|
203
202
|
return get('mode')
|
204
203
|
|
205
204
|
|
206
|
-
def get_mem_scaling() -> MemScaling:
|
207
|
-
"""Get the default computing membrane_scaling.
|
208
|
-
|
209
|
-
Returns
|
210
|
-
-------
|
211
|
-
membrane_scaling: MemScaling
|
212
|
-
The default computing membrane_scaling.
|
213
|
-
"""
|
214
|
-
return get('mem_scaling')
|
215
|
-
|
216
|
-
|
217
205
|
def get_platform() -> str:
|
218
206
|
"""Get the computing platform.
|
219
207
|
|
@@ -239,7 +227,7 @@ def get_host_device_count():
|
|
239
227
|
return int(match.group(1)) if match else 1
|
240
228
|
|
241
229
|
|
242
|
-
def
|
230
|
+
def _get_precision() -> int | str:
|
243
231
|
"""
|
244
232
|
Get the default precision.
|
245
233
|
|
@@ -251,11 +239,23 @@ def get_precision() -> int:
|
|
251
239
|
return get('precision')
|
252
240
|
|
253
241
|
|
242
|
+
def get_precision() -> int:
|
243
|
+
"""
|
244
|
+
Get the default precision.
|
245
|
+
|
246
|
+
Returns
|
247
|
+
-------
|
248
|
+
precision: int
|
249
|
+
The default precision.
|
250
|
+
"""
|
251
|
+
precision = get('precision')
|
252
|
+
return 16 if precision == 'bf16' else precision
|
253
|
+
|
254
|
+
|
254
255
|
def set(
|
255
256
|
platform: str = None,
|
256
257
|
host_device_count: int = None,
|
257
|
-
|
258
|
-
precision: int = None,
|
258
|
+
precision: int | str = None,
|
259
259
|
mode: Mode = None,
|
260
260
|
**kwargs
|
261
261
|
):
|
@@ -267,8 +267,7 @@ def set(
|
|
267
267
|
Args:
|
268
268
|
platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
|
269
269
|
host_device_count: int. The number of host devices.
|
270
|
-
|
271
|
-
precision: int. The default precision.
|
270
|
+
precision: int, str. The default precision.
|
272
271
|
mode: Mode. The computing mode.
|
273
272
|
**kwargs: dict. Other environment settings.
|
274
273
|
"""
|
@@ -276,9 +275,6 @@ def set(
|
|
276
275
|
set_platform(platform)
|
277
276
|
if host_device_count is not None:
|
278
277
|
set_host_device_count(host_device_count)
|
279
|
-
if mem_scaling is not None:
|
280
|
-
assert isinstance(mem_scaling, MemScaling), 'mem_scaling must be a MemScaling instance.'
|
281
|
-
kwargs['mem_scaling'] = mem_scaling
|
282
278
|
if precision is not None:
|
283
279
|
_set_jax_precision(precision)
|
284
280
|
kwargs['precision'] = precision
|
@@ -342,14 +338,14 @@ def set_platform(platform: str):
|
|
342
338
|
DFAULT.functions['platform'](platform)
|
343
339
|
|
344
340
|
|
345
|
-
def _set_jax_precision(precision: int):
|
341
|
+
def _set_jax_precision(precision: int | str):
|
346
342
|
"""
|
347
343
|
Set the default precision.
|
348
344
|
|
349
345
|
Args:
|
350
346
|
precision: int. The default precision.
|
351
347
|
"""
|
352
|
-
assert precision in [64, 32, 16, 8], f'Precision must be in [64, 32, 16, 8]. But got {precision}.'
|
348
|
+
assert precision in [64, 32, 16, 'bf16', 8], f'Precision must be in [64, 32, 16, "bf16", 8]. But got {precision}.'
|
353
349
|
if precision == 64:
|
354
350
|
config.update("jax_enable_x64", True)
|
355
351
|
else:
|
@@ -362,7 +358,7 @@ def _get_uint(precision: int):
|
|
362
358
|
return np.uint64
|
363
359
|
elif precision == 32:
|
364
360
|
return np.uint32
|
365
|
-
elif precision
|
361
|
+
elif precision in [16, 'bf16']:
|
366
362
|
return np.uint16
|
367
363
|
elif precision == 8:
|
368
364
|
return np.uint8
|
@@ -376,7 +372,7 @@ def _get_int(precision: int):
|
|
376
372
|
return np.int64
|
377
373
|
elif precision == 32:
|
378
374
|
return np.int32
|
379
|
-
elif precision
|
375
|
+
elif precision in [16, 'bf16']:
|
380
376
|
return np.int16
|
381
377
|
elif precision == 8:
|
382
378
|
return np.int8
|
@@ -391,8 +387,9 @@ def _get_float(precision: int):
|
|
391
387
|
elif precision == 32:
|
392
388
|
return np.float32
|
393
389
|
elif precision == 16:
|
390
|
+
return np.float16
|
391
|
+
elif precision == 'bf16':
|
394
392
|
return jnp.bfloat16
|
395
|
-
# return np.float16
|
396
393
|
else:
|
397
394
|
raise ValueError(f'Unsupported precision: {precision}')
|
398
395
|
|
@@ -403,8 +400,8 @@ def _get_complex(precision: int):
|
|
403
400
|
return np.complex128
|
404
401
|
elif precision == 32:
|
405
402
|
return np.complex64
|
406
|
-
elif precision
|
407
|
-
return np.
|
403
|
+
elif precision in [16, 'bf16']:
|
404
|
+
return np.complex64
|
408
405
|
else:
|
409
406
|
raise ValueError(f'Unsupported precision: {precision}')
|
410
407
|
|
@@ -430,7 +427,7 @@ def dftype() -> DTypeLike:
|
|
430
427
|
float_dtype: DTypeLike
|
431
428
|
The default floating data type.
|
432
429
|
"""
|
433
|
-
return _get_float(
|
430
|
+
return _get_float(_get_precision())
|
434
431
|
|
435
432
|
|
436
433
|
def ditype() -> DTypeLike:
|
@@ -455,7 +452,7 @@ def ditype() -> DTypeLike:
|
|
455
452
|
int_dtype: DTypeLike
|
456
453
|
The default integer data type.
|
457
454
|
"""
|
458
|
-
return _get_int(
|
455
|
+
return _get_int(_get_precision())
|
459
456
|
|
460
457
|
|
461
458
|
def dutype() -> DTypeLike:
|
@@ -481,7 +478,7 @@ def dutype() -> DTypeLike:
|
|
481
478
|
uint_dtype: DTypeLike
|
482
479
|
The default unsigned integer data type.
|
483
480
|
"""
|
484
|
-
return _get_uint(
|
481
|
+
return _get_uint(_get_precision())
|
485
482
|
|
486
483
|
|
487
484
|
def dctype() -> DTypeLike:
|
@@ -506,7 +503,7 @@ def dctype() -> DTypeLike:
|
|
506
503
|
complex_dtype: DTypeLike
|
507
504
|
The default complex data type.
|
508
505
|
"""
|
509
|
-
return _get_complex(
|
506
|
+
return _get_complex(_get_precision())
|
510
507
|
|
511
508
|
|
512
509
|
def tolerance():
|
{brainstate-0.1.0.post20250102.dist-info → brainstate-0.1.0.post20250104.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.post20250102
|
3
|
+
Version: 0.1.0.post20250104
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -2,7 +2,7 @@ brainstate/__init__.py,sha256=A-QKdOvSalsCMxgk80Iz6_xMiUin6con6JaONHfciSY,1526
|
|
2
2
|
brainstate/_state.py,sha256=4aDpLyHGr1VlPXeLSfM3USQG5K4o7orF7IlaBdYrtfE,29098
|
3
3
|
brainstate/_state_test.py,sha256=1boTp1w8DiCFLsPwNtlLrlIqGRpkasAmLid5bv2fgP4,2223
|
4
4
|
brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
|
5
|
-
brainstate/environ.py,sha256=
|
5
|
+
brainstate/environ.py,sha256=wCVlSjavJ9OXc1STOucg-VfXeK9KE443a1SDhkK9lA8,17270
|
6
6
|
brainstate/environ_test.py,sha256=jXX3nR1CO74aow5YqfqSd73isj9MWgHQxrwSsEjTDY8,1901
|
7
7
|
brainstate/mixin.py,sha256=g7uVUwZphZWsNs9pb48ozG2cDGaj0hs0g3lq8tDk-Sg,11310
|
8
8
|
brainstate/mixin_test.py,sha256=Oq_0fwC9vpXDN4t4dTBhWzLdFDNlcYsrcip14F1yECI,3079
|
@@ -137,8 +137,8 @@ brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7
|
|
137
137
|
brainstate/util/_struct.py,sha256=0exv0oOiSt1hmx20Y4J2-pCGtCTx13WcAlEYSBkyung,17640
|
138
138
|
brainstate/util/_tracers.py,sha256=0r5T4nhxMzI79NtqroqitsdMT4YfpgV5RdYJLS5uJ0w,2285
|
139
139
|
brainstate/util/_visualization.py,sha256=n4ZVz10z7VBqA0cKO6vyHwEMprWJgPeEqtITzDMai2Y,1519
|
140
|
-
brainstate-0.1.0.
|
141
|
-
brainstate-0.1.0.
|
142
|
-
brainstate-0.1.0.
|
143
|
-
brainstate-0.1.0.
|
144
|
-
brainstate-0.1.0.
|
140
|
+
brainstate-0.1.0.post20250104.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
141
|
+
brainstate-0.1.0.post20250104.dist-info/METADATA,sha256=pGOGZBCF8q5La6-DIO1hdldslPnnt4bfjVnOVWgHRU4,3533
|
142
|
+
brainstate-0.1.0.post20250104.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
143
|
+
brainstate-0.1.0.post20250104.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
144
|
+
brainstate-0.1.0.post20250104.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{brainstate-0.1.0.post20250102.dist-info → brainstate-0.1.0.post20250104.dist-info}/top_level.txt
RENAMED
File without changes
|