brainstate 0.1.0.post20250102__py2.py3-none-any.whl → 0.1.0.post20250105__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/compile/_error_if_test.py +1 -0
- brainstate/compile/_loop_collect_return.py +2 -2
- brainstate/compile/_make_jaxpr_test.py +2 -1
- brainstate/compile/_unvmap.py +2 -1
- brainstate/environ.py +52 -45
- brainstate/environ_test.py +4 -0
- brainstate/event/_fixedprob_mv.py +4 -2
- brainstate/event/_fixedprob_mv_test.py +2 -1
- brainstate/event/_xla_custom_op.py +9 -2
- {brainstate-0.1.0.post20250102.dist-info → brainstate-0.1.0.post20250105.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20250102.dist-info → brainstate-0.1.0.post20250105.dist-info}/RECORD +14 -14
- {brainstate-0.1.0.post20250102.dist-info → brainstate-0.1.0.post20250105.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250102.dist-info → brainstate-0.1.0.post20250105.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250102.dist-info → brainstate-0.1.0.post20250105.dist-info}/top_level.txt +0 -0
@@ -206,7 +206,7 @@ def scan(
|
|
206
206
|
|
207
207
|
# evaluate jaxpr, get all states #
|
208
208
|
# ------------------------------ #
|
209
|
-
xs_avals = [jax.core.
|
209
|
+
xs_avals = [jax.core.get_aval(x) for x in xs_flat]
|
210
210
|
x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
|
211
211
|
stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
|
212
212
|
state_trace = stateful_fun.get_state_trace()
|
@@ -302,7 +302,7 @@ def checkpointed_scan(
|
|
302
302
|
pbar_runner = None
|
303
303
|
|
304
304
|
# evaluate jaxpr
|
305
|
-
xs_avals = [jax.core.
|
305
|
+
xs_avals = [jax.core.get_aval(x) for x in xs_flat]
|
306
306
|
x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
|
307
307
|
stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
|
308
308
|
state_trace = stateful_fun.get_state_trace()
|
@@ -18,6 +18,7 @@ from __future__ import annotations
|
|
18
18
|
import unittest
|
19
19
|
|
20
20
|
import jax
|
21
|
+
import jax.extend as je
|
21
22
|
import jax.numpy as jnp
|
22
23
|
import pytest
|
23
24
|
|
@@ -84,7 +85,7 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
84
85
|
print(jaxpr)
|
85
86
|
jaxpr, _ = bst.compile.make_jaxpr(f3)(jnp.zeros(1))
|
86
87
|
print(jaxpr)
|
87
|
-
self.assertTrue(jnp.allclose(
|
88
|
+
self.assertTrue(jnp.allclose(je.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
|
88
89
|
f3(jnp.zeros(1))))
|
89
90
|
|
90
91
|
def test_compar_jax_make_jaxpr2(self):
|
brainstate/compile/_unvmap.py
CHANGED
@@ -16,6 +16,7 @@ from __future__ import annotations
|
|
16
16
|
|
17
17
|
import jax
|
18
18
|
import jax.core
|
19
|
+
import jax.extend as je
|
19
20
|
import jax.interpreters.batching as batching
|
20
21
|
import jax.interpreters.mlir as mlir
|
21
22
|
import jax.numpy as jnp
|
@@ -43,7 +44,7 @@ def unvmap(x, op: str = 'any'):
|
|
43
44
|
|
44
45
|
# unvmap_all
|
45
46
|
|
46
|
-
unvmap_all_p =
|
47
|
+
unvmap_all_p = je.core.Primitive("unvmap_all")
|
47
48
|
|
48
49
|
|
49
50
|
def unvmap_all(x):
|
brainstate/environ.py
CHANGED
@@ -31,13 +31,12 @@ from jax import config, devices, numpy as jnp
|
|
31
31
|
from jax.typing import DTypeLike
|
32
32
|
|
33
33
|
from .mixin import Mode
|
34
|
-
from .util import MemScaling
|
35
34
|
|
36
35
|
__all__ = [
|
37
36
|
# functions for environment settings
|
38
37
|
'set', 'context', 'get', 'all', 'set_host_device_count', 'set_platform',
|
39
38
|
# functions for getting default behaviors
|
40
|
-
'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', '
|
39
|
+
'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_precision',
|
41
40
|
# functions for default data types
|
42
41
|
'dftype', 'ditype', 'dutype', 'dctype',
|
43
42
|
# others
|
@@ -94,7 +93,7 @@ def context(**kwargs):
|
|
94
93
|
'Please use set_host_device_count() or set() for the global setting.')
|
95
94
|
|
96
95
|
if 'precision' in kwargs:
|
97
|
-
last_precision =
|
96
|
+
last_precision = _get_precision()
|
98
97
|
_set_jax_precision(kwargs['precision'])
|
99
98
|
|
100
99
|
try:
|
@@ -203,17 +202,6 @@ def get_mode() -> Mode:
|
|
203
202
|
return get('mode')
|
204
203
|
|
205
204
|
|
206
|
-
def get_mem_scaling() -> MemScaling:
|
207
|
-
"""Get the default computing membrane_scaling.
|
208
|
-
|
209
|
-
Returns
|
210
|
-
-------
|
211
|
-
membrane_scaling: MemScaling
|
212
|
-
The default computing membrane_scaling.
|
213
|
-
"""
|
214
|
-
return get('mem_scaling')
|
215
|
-
|
216
|
-
|
217
205
|
def get_platform() -> str:
|
218
206
|
"""Get the computing platform.
|
219
207
|
|
@@ -239,7 +227,7 @@ def get_host_device_count():
|
|
239
227
|
return int(match.group(1)) if match else 1
|
240
228
|
|
241
229
|
|
242
|
-
def
|
230
|
+
def _get_precision() -> int | str:
|
243
231
|
"""
|
244
232
|
Get the default precision.
|
245
233
|
|
@@ -251,11 +239,29 @@ def get_precision() -> int:
|
|
251
239
|
return get('precision')
|
252
240
|
|
253
241
|
|
242
|
+
def get_precision() -> int:
|
243
|
+
"""
|
244
|
+
Get the default precision.
|
245
|
+
|
246
|
+
Returns
|
247
|
+
-------
|
248
|
+
precision: int
|
249
|
+
The default precision.
|
250
|
+
"""
|
251
|
+
precision = get('precision')
|
252
|
+
if precision == 'bf16':
|
253
|
+
return 16
|
254
|
+
if isinstance(precision, int):
|
255
|
+
return precision
|
256
|
+
if isinstance(precision, str):
|
257
|
+
return int(precision)
|
258
|
+
raise ValueError(f'Unsupported precision: {precision}')
|
259
|
+
|
260
|
+
|
254
261
|
def set(
|
255
262
|
platform: str = None,
|
256
263
|
host_device_count: int = None,
|
257
|
-
|
258
|
-
precision: int = None,
|
264
|
+
precision: int | str = None,
|
259
265
|
mode: Mode = None,
|
260
266
|
**kwargs
|
261
267
|
):
|
@@ -267,8 +273,7 @@ def set(
|
|
267
273
|
Args:
|
268
274
|
platform: str. The computing platform. Either 'cpu', 'gpu' or 'tpu'.
|
269
275
|
host_device_count: int. The number of host devices.
|
270
|
-
|
271
|
-
precision: int. The default precision.
|
276
|
+
precision: int, str. The default precision.
|
272
277
|
mode: Mode. The computing mode.
|
273
278
|
**kwargs: dict. Other environment settings.
|
274
279
|
"""
|
@@ -276,9 +281,6 @@ def set(
|
|
276
281
|
set_platform(platform)
|
277
282
|
if host_device_count is not None:
|
278
283
|
set_host_device_count(host_device_count)
|
279
|
-
if mem_scaling is not None:
|
280
|
-
assert isinstance(mem_scaling, MemScaling), 'mem_scaling must be a MemScaling instance.'
|
281
|
-
kwargs['mem_scaling'] = mem_scaling
|
282
284
|
if precision is not None:
|
283
285
|
_set_jax_precision(precision)
|
284
286
|
kwargs['precision'] = precision
|
@@ -342,15 +344,15 @@ def set_platform(platform: str):
|
|
342
344
|
DFAULT.functions['platform'](platform)
|
343
345
|
|
344
346
|
|
345
|
-
def _set_jax_precision(precision: int):
|
347
|
+
def _set_jax_precision(precision: int | str):
|
346
348
|
"""
|
347
349
|
Set the default precision.
|
348
350
|
|
349
351
|
Args:
|
350
352
|
precision: int. The default precision.
|
351
353
|
"""
|
352
|
-
assert precision in [64, 32, 16, 8], f'Precision must be in [64, 32, 16, 8]. But got {precision}.'
|
353
|
-
if precision
|
354
|
+
# assert precision in [64, 32, 16, 'bf16', 8], f'Precision must be in [64, 32, 16, "bf16", 8]. But got {precision}.'
|
355
|
+
if precision in [64, '64']:
|
354
356
|
config.update("jax_enable_x64", True)
|
355
357
|
else:
|
356
358
|
config.update("jax_enable_x64", False)
|
@@ -358,13 +360,13 @@ def _set_jax_precision(precision: int):
|
|
358
360
|
|
359
361
|
@functools.lru_cache()
|
360
362
|
def _get_uint(precision: int):
|
361
|
-
if precision
|
363
|
+
if precision in [64, '64']:
|
362
364
|
return np.uint64
|
363
|
-
elif precision
|
365
|
+
elif precision in [32, '32']:
|
364
366
|
return np.uint32
|
365
|
-
elif precision
|
367
|
+
elif precision in [16, '16', 'bf16']:
|
366
368
|
return np.uint16
|
367
|
-
elif precision
|
369
|
+
elif precision in [8, '8']:
|
368
370
|
return np.uint8
|
369
371
|
else:
|
370
372
|
raise ValueError(f'Unsupported precision: {precision}')
|
@@ -372,13 +374,13 @@ def _get_uint(precision: int):
|
|
372
374
|
|
373
375
|
@functools.lru_cache()
|
374
376
|
def _get_int(precision: int):
|
375
|
-
if precision
|
377
|
+
if precision in [64, '64']:
|
376
378
|
return np.int64
|
377
|
-
elif precision
|
379
|
+
elif precision in [32, '32']:
|
378
380
|
return np.int32
|
379
|
-
elif precision
|
381
|
+
elif precision in [16, '16', 'bf16']:
|
380
382
|
return np.int16
|
381
|
-
elif precision
|
383
|
+
elif precision in [8, '8']:
|
382
384
|
return np.int8
|
383
385
|
else:
|
384
386
|
raise ValueError(f'Unsupported precision: {precision}')
|
@@ -386,25 +388,30 @@ def _get_int(precision: int):
|
|
386
388
|
|
387
389
|
@functools.lru_cache()
|
388
390
|
def _get_float(precision: int):
|
389
|
-
if precision
|
391
|
+
if precision in [64, '64']:
|
390
392
|
return np.float64
|
391
|
-
elif precision
|
393
|
+
elif precision in [32, '32']:
|
392
394
|
return np.float32
|
393
|
-
elif precision
|
395
|
+
elif precision in [16, '16']:
|
396
|
+
return np.float16
|
397
|
+
elif precision in ['bf16']:
|
394
398
|
return jnp.bfloat16
|
395
|
-
|
399
|
+
elif precision in [8, '8']:
|
400
|
+
return jnp.float8_e5m2
|
396
401
|
else:
|
397
402
|
raise ValueError(f'Unsupported precision: {precision}')
|
398
403
|
|
399
404
|
|
400
405
|
@functools.lru_cache()
|
401
406
|
def _get_complex(precision: int):
|
402
|
-
if precision == 64:
|
407
|
+
if precision == [64, '64']:
|
403
408
|
return np.complex128
|
404
|
-
elif precision == 32:
|
409
|
+
elif precision == [32, '32']:
|
410
|
+
return np.complex64
|
411
|
+
elif precision in [16, '16', 'bf16']:
|
412
|
+
return np.complex64
|
413
|
+
elif precision == [8, '8']:
|
405
414
|
return np.complex64
|
406
|
-
elif precision == 16:
|
407
|
-
return np.complex32
|
408
415
|
else:
|
409
416
|
raise ValueError(f'Unsupported precision: {precision}')
|
410
417
|
|
@@ -430,7 +437,7 @@ def dftype() -> DTypeLike:
|
|
430
437
|
float_dtype: DTypeLike
|
431
438
|
The default floating data type.
|
432
439
|
"""
|
433
|
-
return _get_float(
|
440
|
+
return _get_float(_get_precision())
|
434
441
|
|
435
442
|
|
436
443
|
def ditype() -> DTypeLike:
|
@@ -455,7 +462,7 @@ def ditype() -> DTypeLike:
|
|
455
462
|
int_dtype: DTypeLike
|
456
463
|
The default integer data type.
|
457
464
|
"""
|
458
|
-
return _get_int(
|
465
|
+
return _get_int(_get_precision())
|
459
466
|
|
460
467
|
|
461
468
|
def dutype() -> DTypeLike:
|
@@ -481,7 +488,7 @@ def dutype() -> DTypeLike:
|
|
481
488
|
uint_dtype: DTypeLike
|
482
489
|
The default unsigned integer data type.
|
483
490
|
"""
|
484
|
-
return _get_uint(
|
491
|
+
return _get_uint(_get_precision())
|
485
492
|
|
486
493
|
|
487
494
|
def dctype() -> DTypeLike:
|
@@ -506,7 +513,7 @@ def dctype() -> DTypeLike:
|
|
506
513
|
complex_dtype: DTypeLike
|
507
514
|
The default complex data type.
|
508
515
|
"""
|
509
|
-
return _get_complex(
|
516
|
+
return _get_complex(_get_precision())
|
510
517
|
|
511
518
|
|
512
519
|
def tolerance():
|
brainstate/environ_test.py
CHANGED
@@ -32,6 +32,10 @@ class TestEnviron(unittest.TestCase):
|
|
32
32
|
self.assertEqual(a.dtype, jnp.float32)
|
33
33
|
|
34
34
|
with bst.environ.context(precision=16):
|
35
|
+
a = bst.random.randn(1)
|
36
|
+
self.assertEqual(a.dtype, jnp.float16)
|
37
|
+
|
38
|
+
with bst.environ.context(precision='bf16'):
|
35
39
|
a = bst.random.randn(1)
|
36
40
|
self.assertEqual(a.dtype, jnp.bfloat16)
|
37
41
|
|
@@ -24,6 +24,7 @@ import jax.numpy as jnp
|
|
24
24
|
import numpy as np
|
25
25
|
from jax.interpreters import ad
|
26
26
|
|
27
|
+
from brainstate import environ
|
27
28
|
from brainstate._state import ParamState
|
28
29
|
from brainstate.augment import vmap
|
29
30
|
from brainstate.init import param
|
@@ -111,7 +112,7 @@ class FixedProb(Module):
|
|
111
112
|
|
112
113
|
def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
|
113
114
|
if self.n_conn > 1:
|
114
|
-
|
115
|
+
r = event_fixed_prob(
|
115
116
|
spk,
|
116
117
|
self.weight.value,
|
117
118
|
self.indices,
|
@@ -123,7 +124,8 @@ class FixedProb(Module):
|
|
123
124
|
weight = self.weight.value
|
124
125
|
unit = u.get_unit(weight)
|
125
126
|
r = jnp.zeros(spk.shape[:-1] + (self.out_size[-1],), dtype=weight.dtype)
|
126
|
-
|
127
|
+
r = u.maybe_decimal(u.Quantity(r, unit=unit))
|
128
|
+
return u.math.asarray(r, dtype=environ.dftype())
|
127
129
|
|
128
130
|
|
129
131
|
def event_fixed_prob(
|
@@ -128,4 +128,5 @@ class TestFixedProbCSR(parameterized.TestCase):
|
|
128
128
|
|
129
129
|
o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
|
130
130
|
self.assertTrue(jnp.allclose(o1, o2))
|
131
|
-
|
131
|
+
# assert jnp.allclose(r1, r2), f'r1={r1}, r2={r2}'
|
132
|
+
self.assertTrue(jnp.allclose(r1, r2, rtol=1e-4, atol=1e-4))
|
@@ -12,9 +12,13 @@ from jax import tree_util
|
|
12
12
|
from jax.core import Primitive
|
13
13
|
from jax.interpreters import batching, ad
|
14
14
|
from jax.interpreters import xla, mlir
|
15
|
-
from jax.lib import xla_client
|
16
15
|
from jaxlib.hlo_helpers import custom_call
|
17
16
|
|
17
|
+
if jax.__version_info__ < (0, 4, 35):
|
18
|
+
from jax.lib import xla_client
|
19
|
+
else:
|
20
|
+
import jax.extend as je
|
21
|
+
|
18
22
|
numba_installed = importlib.util.find_spec('numba') is not None
|
19
23
|
|
20
24
|
__all__ = [
|
@@ -143,7 +147,10 @@ def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
|
|
143
147
|
xla_c_rule = cfunc(sig)(new_f)
|
144
148
|
target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
|
145
149
|
capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
|
146
|
-
|
150
|
+
if jax.__version_info__ < (0, 4, 35):
|
151
|
+
xla_client.register_custom_call_target(target_name, capsule, "cpu")
|
152
|
+
else:
|
153
|
+
je.ffi.register_ffi_target(target_name, capsule, "cpu", api_version=0)
|
147
154
|
|
148
155
|
# call
|
149
156
|
return custom_call(
|
{brainstate-0.1.0.post20250102.dist-info → brainstate-0.1.0.post20250105.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.
|
3
|
+
Version: 0.1.0.post20250105
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -2,8 +2,8 @@ brainstate/__init__.py,sha256=A-QKdOvSalsCMxgk80Iz6_xMiUin6con6JaONHfciSY,1526
|
|
2
2
|
brainstate/_state.py,sha256=4aDpLyHGr1VlPXeLSfM3USQG5K4o7orF7IlaBdYrtfE,29098
|
3
3
|
brainstate/_state_test.py,sha256=1boTp1w8DiCFLsPwNtlLrlIqGRpkasAmLid5bv2fgP4,2223
|
4
4
|
brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
|
5
|
-
brainstate/environ.py,sha256=
|
6
|
-
brainstate/environ_test.py,sha256=
|
5
|
+
brainstate/environ.py,sha256=PZnVFWPioUBuWmwCO8wwCKrHQfP3BR-5lYPRl5i5GDA,17698
|
6
|
+
brainstate/environ_test.py,sha256=QD6sPCKNtqemVCGwkdImjMazatrvvLr6YeAVcfUnVVY,2045
|
7
7
|
brainstate/mixin.py,sha256=g7uVUwZphZWsNs9pb48ozG2cDGaj0hs0g3lq8tDk-Sg,11310
|
8
8
|
brainstate/mixin_test.py,sha256=Oq_0fwC9vpXDN4t4dTBhWzLdFDNlcYsrcip14F1yECI,3079
|
9
9
|
brainstate/surrogate.py,sha256=YaY6RJ6kzpuPXWFjaWsxWt2MzJfdm5v_jeOR8V_jPoU,48369
|
@@ -23,17 +23,17 @@ brainstate/compile/_ad_checkpoint_test.py,sha256=R1I76nG4zIqb6g3M_VxWts7rUC1OHJC
|
|
23
23
|
brainstate/compile/_conditions.py,sha256=gApsHKGQrf1QBjoKXDVL7VsoeJ2zFtSc-hFz9nbYcF0,10113
|
24
24
|
brainstate/compile/_conditions_test.py,sha256=s9LF6h9LvigvgxUIugTqvgCHBIU8TXS1Ar1OlIxXfrw,8389
|
25
25
|
brainstate/compile/_error_if.py,sha256=TFvhqITKkRO9m30GdlUP4eEjJvLWQUhjkujXO9zvrWs,2689
|
26
|
-
brainstate/compile/_error_if_test.py,sha256=
|
26
|
+
brainstate/compile/_error_if_test.py,sha256=OdJG483IIdOrCHxtHd49OHfOxCSnSkk7GdAUOzSt8bE,2044
|
27
27
|
brainstate/compile/_jit.py,sha256=3WBXNTALWPYC9rQH0TPH6w4bjG0BpnZt3RAzUQF5kkc,14045
|
28
28
|
brainstate/compile/_jit_test.py,sha256=zD7kck9SQJGmUDolh9P4luKwQ21fBGje1Z4STTEXIuA,4135
|
29
|
-
brainstate/compile/_loop_collect_return.py,sha256=
|
29
|
+
brainstate/compile/_loop_collect_return.py,sha256=XwMnKkMH0xTWB1f6GE4NQNK1R2GXTXCiVgulpkdIpc4,23308
|
30
30
|
brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
|
31
31
|
brainstate/compile/_loop_no_collection.py,sha256=0i31gdQ7sI-d6pvnh08ttUUwdAtpx4uoYhGuf_CyL9s,7343
|
32
32
|
brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
|
33
33
|
brainstate/compile/_make_jaxpr.py,sha256=S5O9KUB3bsxoKcfptlV0MRfKA__Ija37WxkakIRL3z0,33010
|
34
|
-
brainstate/compile/_make_jaxpr_test.py,sha256=
|
34
|
+
brainstate/compile/_make_jaxpr_test.py,sha256=3XaX8LUuG6UjolcD83qDVo5odf8FCDppdr9Q6V0NBs4,4303
|
35
35
|
brainstate/compile/_progress_bar.py,sha256=eInZPjiqzYE6PWxl_or_lBthDNDO0Ov60Uz0DbuBbZQ,4620
|
36
|
-
brainstate/compile/_unvmap.py,sha256=
|
36
|
+
brainstate/compile/_unvmap.py,sha256=0i-NvCLDAUe-effJIIEPVsK4WTPbCDBTgw6AqRvq7mE,4163
|
37
37
|
brainstate/compile/_util.py,sha256=aCvkTV--g4NsqcodTdBAISt4EwgezCbKzNUV58n-Q_Y,6304
|
38
38
|
brainstate/event/__init__.py,sha256=W0ZxbgrcFuYhWTl-GZ0UDoMGfsWmesvG4J_LbTBkex8,937
|
39
39
|
brainstate/event/_csr.py,sha256=QDccbgXUklE2iq1w6WdyaFspXY1165uA9UlPltX16OU,30365
|
@@ -41,14 +41,14 @@ brainstate/event/_csr_mv.py,sha256=HStHvK3KyEMfLsIUslZjgbdU6OsD1yKGrzQOzBXG36M,1
|
|
41
41
|
brainstate/event/_csr_mv_benchmark.py,sha256=xrj2DSWzw0pUHAE1jRBeSRhMW7ogXvDHEdeaZGioNE4,702
|
42
42
|
brainstate/event/_csr_mv_test.py,sha256=WQfAvp_3UeCUGAZjr3_aqQvrB-eYZcFEN4v1PBe9fUQ,4012
|
43
43
|
brainstate/event/_csr_test.py,sha256=v59rnwTy8jrvqjdGzN75kvLg0wLBmRbthaVRKY2f0Uw,2945
|
44
|
-
brainstate/event/_fixedprob_mv.py,sha256=
|
44
|
+
brainstate/event/_fixedprob_mv.py,sha256=nR3lhd87t1Vge435QHnFuDp-UBbWoW0Qk1kbsjRHQyc,25541
|
45
45
|
brainstate/event/_fixedprob_mv_benchmark.py,sha256=_F_8fH5MNMJZHeSqnq9DYMI9OgYr6JIxBKjbsgeWRv4,4720
|
46
|
-
brainstate/event/_fixedprob_mv_test.py,sha256=
|
46
|
+
brainstate/event/_fixedprob_mv_test.py,sha256=pVEarvGbqTjnAbxgMVRTAhkyYbvDnlyCJdeOdDD927w,4283
|
47
47
|
brainstate/event/_linear_mv.py,sha256=O5qbY31GNV1qEDrZ5kvPbA8Ae-bY5JpUgGtqDFNAeV0,11794
|
48
48
|
brainstate/event/_linear_mv_benckmark.py,sha256=hu0WqYMIa3jMoH7Fq9dgxcBjjXGFhghPx9vztyCo1KY,2411
|
49
49
|
brainstate/event/_linear_mv_test.py,sha256=V9w41ZP2vu95CyCdCkm-j9Eftqs2kqmeBY809N1-syY,3736
|
50
50
|
brainstate/event/_misc.py,sha256=8IpPooXjF2m0-tuo3pGHqThq2yLSNmYziy_zdurZ3NI,1040
|
51
|
-
brainstate/event/_xla_custom_op.py,sha256=
|
51
|
+
brainstate/event/_xla_custom_op.py,sha256=f0OrO6CjsJOUNUCRPpHIRmsb_wgNEym0xBl1tcz8ij4,11016
|
52
52
|
brainstate/event/_xla_custom_op_test.py,sha256=rnkGMleXzLfJj4y5QqwfBvCCLTAHe_uabwBDniY-URM,1745
|
53
53
|
brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
|
54
54
|
brainstate/functional/_activations.py,sha256=S0Ok7sq5FTbmJWSejpOCHo1jpKX0gYOLy_TO2IUXM8s,21726
|
@@ -137,8 +137,8 @@ brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7
|
|
137
137
|
brainstate/util/_struct.py,sha256=0exv0oOiSt1hmx20Y4J2-pCGtCTx13WcAlEYSBkyung,17640
|
138
138
|
brainstate/util/_tracers.py,sha256=0r5T4nhxMzI79NtqroqitsdMT4YfpgV5RdYJLS5uJ0w,2285
|
139
139
|
brainstate/util/_visualization.py,sha256=n4ZVz10z7VBqA0cKO6vyHwEMprWJgPeEqtITzDMai2Y,1519
|
140
|
-
brainstate-0.1.0.
|
141
|
-
brainstate-0.1.0.
|
142
|
-
brainstate-0.1.0.
|
143
|
-
brainstate-0.1.0.
|
144
|
-
brainstate-0.1.0.
|
140
|
+
brainstate-0.1.0.post20250105.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
141
|
+
brainstate-0.1.0.post20250105.dist-info/METADATA,sha256=Xec1GNBlHcignyvym-EHzU-JIOUuo3T-IUU2LoCO0sk,3533
|
142
|
+
brainstate-0.1.0.post20250105.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
143
|
+
brainstate-0.1.0.post20250105.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
144
|
+
brainstate-0.1.0.post20250105.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{brainstate-0.1.0.post20250102.dist-info → brainstate-0.1.0.post20250105.dist-info}/top_level.txt
RENAMED
File without changes
|