brainstate 0.0.2.post20240826__py2.py3-none-any.whl → 0.0.2.post20240910__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_module.py +1 -1
- brainstate/_state.py +2 -0
- brainstate/nn/_projection/_align_post.py +28 -10
- brainstate/surrogate.py +54 -0
- brainstate/transform/_loop_collect_return.py +1 -1
- {brainstate-0.0.2.post20240826.dist-info → brainstate-0.0.2.post20240910.dist-info}/METADATA +3 -10
- {brainstate-0.0.2.post20240826.dist-info → brainstate-0.0.2.post20240910.dist-info}/RECORD +10 -10
- {brainstate-0.0.2.post20240826.dist-info → brainstate-0.0.2.post20240910.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20240826.dist-info → brainstate-0.0.2.post20240910.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20240826.dist-info → brainstate-0.0.2.post20240910.dist-info}/top_level.txt +0 -0
brainstate/_module.py
CHANGED
@@ -1597,6 +1597,6 @@ def _get_delay(delay_time, delay_step):
|
|
1597
1597
|
delay_time = delay_step * environ.get_dt()
|
1598
1598
|
else:
|
1599
1599
|
assert delay_step is None, '"delay_step" should be None if "delay_time" is given.'
|
1600
|
-
assert isinstance(delay_time, (int, float))
|
1600
|
+
# assert isinstance(delay_time, (int, float))
|
1601
1601
|
delay_step = math.ceil(delay_time / environ.get_dt())
|
1602
1602
|
return delay_time, delay_step
|
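brainstate/_state.py
CHANGED
(+2 -0; the hunk for this file is not present in this extract)
brainstate/nn/_projection/_align_post.py
CHANGED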
@@ -15,6 +15,8 @@
|
|
15
15
|
|
16
16
|
from typing import Optional, Union
|
17
17
|
|
18
|
+
|
19
|
+
import brainunit as u
|
18
20
|
from brainstate._module import (register_delay_of_target,
|
19
21
|
Projection,
|
20
22
|
Module,
|
@@ -278,11 +280,19 @@ class FullProjAlignPostMg(Projection):
|
|
278
280
|
self.comm = comm
|
279
281
|
|
280
282
|
# delay initialization
|
281
|
-
if delay is not None
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
283
|
+
if delay is not None:
|
284
|
+
if isinstance(delay, u.Quantity):
|
285
|
+
has_delay = delay.mantissa > 0.
|
286
|
+
else:
|
287
|
+
has_delay = delay > 0.
|
288
|
+
if has_delay:
|
289
|
+
delay_cls = register_delay_of_target(pre)
|
290
|
+
delay_cls.register_entry(self.name, delay)
|
291
|
+
self.delay = delay_cls
|
292
|
+
self.has_delay = True
|
293
|
+
else:
|
294
|
+
self.delay = None
|
295
|
+
self.has_delay = False
|
286
296
|
else:
|
287
297
|
self.delay = None
|
288
298
|
self.has_delay = False
|
@@ -502,11 +512,19 @@ class FullProjAlignPost(Projection):
|
|
502
512
|
self.out = out
|
503
513
|
|
504
514
|
# delay initialization
|
505
|
-
if delay is not None
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
515
|
+
if delay is not None:
|
516
|
+
if isinstance(delay, u.Quantity):
|
517
|
+
has_delay = delay.mantissa > 0.
|
518
|
+
else:
|
519
|
+
has_delay = delay > 0.
|
520
|
+
if has_delay:
|
521
|
+
delay_cls = register_delay_of_target(pre)
|
522
|
+
delay_cls.register_entry(self.name, delay)
|
523
|
+
self.delay = delay_cls
|
524
|
+
self.has_delay = True
|
525
|
+
else:
|
526
|
+
self.delay = None
|
527
|
+
self.has_delay = False
|
510
528
|
else:
|
511
529
|
self.delay = None
|
512
530
|
self.has_delay = False
|
brainstate/surrogate.py
CHANGED
@@ -158,6 +158,9 @@ class Sigmoid(Surrogate):
|
|
158
158
|
def __repr__(self):
|
159
159
|
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
160
160
|
|
161
|
+
def __hash__(self):
|
162
|
+
return hash((self.__class__, self.alpha))
|
163
|
+
|
161
164
|
|
162
165
|
def sigmoid(
|
163
166
|
x: jax.Array,
|
@@ -243,6 +246,9 @@ class PiecewiseQuadratic(Surrogate):
|
|
243
246
|
def __repr__(self):
|
244
247
|
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
245
248
|
|
249
|
+
def __hash__(self):
|
250
|
+
return hash((self.__class__, self.alpha))
|
251
|
+
|
246
252
|
|
247
253
|
def piecewise_quadratic(
|
248
254
|
x: jax.Array,
|
@@ -339,6 +345,9 @@ class PiecewiseExp(Surrogate):
|
|
339
345
|
def __repr__(self):
|
340
346
|
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
341
347
|
|
348
|
+
def __hash__(self):
|
349
|
+
return hash((self.__class__, self.alpha))
|
350
|
+
|
342
351
|
|
343
352
|
def piecewise_exp(
|
344
353
|
x: jax.Array,
|
@@ -426,6 +435,9 @@ class SoftSign(Surrogate):
|
|
426
435
|
def __repr__(self):
|
427
436
|
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
428
437
|
|
438
|
+
def __hash__(self):
|
439
|
+
return hash((self.__class__, self.alpha))
|
440
|
+
|
429
441
|
|
430
442
|
def soft_sign(
|
431
443
|
x: jax.Array,
|
@@ -508,6 +520,9 @@ class Arctan(Surrogate):
|
|
508
520
|
def __repr__(self):
|
509
521
|
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
510
522
|
|
523
|
+
def __hash__(self):
|
524
|
+
return hash((self.__class__, self.alpha))
|
525
|
+
|
511
526
|
|
512
527
|
def arctan(
|
513
528
|
x: jax.Array,
|
@@ -589,6 +604,9 @@ class NonzeroSignLog(Surrogate):
|
|
589
604
|
def __repr__(self):
|
590
605
|
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
591
606
|
|
607
|
+
def __hash__(self):
|
608
|
+
return hash((self.__class__, self.alpha))
|
609
|
+
|
592
610
|
|
593
611
|
def nonzero_sign_log(
|
594
612
|
x: jax.Array,
|
@@ -683,6 +701,9 @@ class ERF(Surrogate):
|
|
683
701
|
def __repr__(self):
|
684
702
|
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
685
703
|
|
704
|
+
def __hash__(self):
|
705
|
+
return hash((self.__class__, self.alpha))
|
706
|
+
|
686
707
|
|
687
708
|
def erf(
|
688
709
|
x: jax.Array,
|
@@ -780,6 +801,9 @@ class PiecewiseLeakyRelu(Surrogate):
|
|
780
801
|
def __repr__(self):
|
781
802
|
return f'{self.__class__.__name__}(c={self.c}, w={self.w})'
|
782
803
|
|
804
|
+
def __hash__(self):
|
805
|
+
return hash((self.__class__, self.c, self.w))
|
806
|
+
|
783
807
|
|
784
808
|
def piecewise_leaky_relu(
|
785
809
|
x: jax.Array,
|
@@ -898,6 +922,9 @@ class SquarewaveFourierSeries(Surrogate):
|
|
898
922
|
def __repr__(self):
|
899
923
|
return f'{self.__class__.__name__}(n={self.n}, t_period={self.t_period})'
|
900
924
|
|
925
|
+
def __hash__(self):
|
926
|
+
return hash((self.__class__, self.n, self.t_period))
|
927
|
+
|
901
928
|
|
902
929
|
def squarewave_fourier_series(
|
903
930
|
x: jax.Array,
|
@@ -988,6 +1015,9 @@ class S2NN(Surrogate):
|
|
988
1015
|
def __repr__(self):
|
989
1016
|
return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta}, epsilon={self.epsilon})'
|
990
1017
|
|
1018
|
+
def __hash__(self):
|
1019
|
+
return hash((self.__class__, self.alpha, self.beta, self.epsilon))
|
1020
|
+
|
991
1021
|
|
992
1022
|
def s2nn(
|
993
1023
|
x: jax.Array,
|
@@ -1089,6 +1119,9 @@ class QPseudoSpike(Surrogate):
|
|
1089
1119
|
def __repr__(self):
|
1090
1120
|
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
1091
1121
|
|
1122
|
+
def __hash__(self):
|
1123
|
+
return hash((self.__class__, self.alpha))
|
1124
|
+
|
1092
1125
|
|
1093
1126
|
def q_pseudo_spike(
|
1094
1127
|
x: jax.Array,
|
@@ -1178,6 +1211,9 @@ class LeakyRelu(Surrogate):
|
|
1178
1211
|
def __repr__(self):
|
1179
1212
|
return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta})'
|
1180
1213
|
|
1214
|
+
def __hash__(self):
|
1215
|
+
return hash((self.__class__, self.alpha, self.beta))
|
1216
|
+
|
1181
1217
|
|
1182
1218
|
def leaky_relu(
|
1183
1219
|
x: jax.Array,
|
@@ -1277,6 +1313,9 @@ class LogTailedRelu(Surrogate):
|
|
1277
1313
|
def __repr__(self):
|
1278
1314
|
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
1279
1315
|
|
1316
|
+
def __hash__(self):
|
1317
|
+
return hash((self.__class__, self.alpha))
|
1318
|
+
|
1280
1319
|
|
1281
1320
|
def log_tailed_relu(
|
1282
1321
|
x: jax.Array,
|
@@ -1368,6 +1407,9 @@ class ReluGrad(Surrogate):
|
|
1368
1407
|
def __repr__(self):
|
1369
1408
|
return f'{self.__class__.__name__}(alpha={self.alpha}, width={self.width})'
|
1370
1409
|
|
1410
|
+
def __hash__(self):
|
1411
|
+
return hash((self.__class__, self.alpha, self.width))
|
1412
|
+
|
1371
1413
|
|
1372
1414
|
def relu_grad(
|
1373
1415
|
x: jax.Array,
|
@@ -1446,6 +1488,9 @@ class GaussianGrad(Surrogate):
|
|
1446
1488
|
def __repr__(self):
|
1447
1489
|
return f'{self.__class__.__name__}(alpha={self.alpha}, sigma={self.sigma})'
|
1448
1490
|
|
1491
|
+
def __hash__(self):
|
1492
|
+
return hash((self.__class__, self.alpha, self.sigma))
|
1493
|
+
|
1449
1494
|
|
1450
1495
|
def gaussian_grad(
|
1451
1496
|
x: jax.Array,
|
@@ -1530,6 +1575,9 @@ class MultiGaussianGrad(Surrogate):
|
|
1530
1575
|
def __repr__(self):
|
1531
1576
|
return f'{self.__class__.__name__}(h={self.h}, s={self.s}, sigma={self.sigma}, scale={self.scale})'
|
1532
1577
|
|
1578
|
+
def __hash__(self):
|
1579
|
+
return hash((self.__class__, self.h, self.s, self.sigma, self.scale))
|
1580
|
+
|
1533
1581
|
|
1534
1582
|
def multi_gaussian_grad(
|
1535
1583
|
x: jax.Array,
|
@@ -1615,6 +1663,9 @@ class InvSquareGrad(Surrogate):
|
|
1615
1663
|
def __repr__(self):
|
1616
1664
|
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
1617
1665
|
|
1666
|
+
def __hash__(self):
|
1667
|
+
return hash((self.__class__, self.alpha))
|
1668
|
+
|
1618
1669
|
|
1619
1670
|
def inv_square_grad(
|
1620
1671
|
x: jax.Array,
|
@@ -1685,6 +1736,9 @@ class SlayerGrad(Surrogate):
|
|
1685
1736
|
def __repr__(self):
|
1686
1737
|
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
1687
1738
|
|
1739
|
+
def __hash__(self):
|
1740
|
+
return hash((self.__class__, self.alpha))
|
1741
|
+
|
1688
1742
|
|
1689
1743
|
def slayer_grad(
|
1690
1744
|
x: jax.Array,
|
brainstate/transform/_loop_collect_return.py
CHANGED
@@ -418,7 +418,7 @@ def checkpointed_for_loop(
|
|
418
418
|
pbar: Optional[ProgressBar] = None,
|
419
419
|
):
|
420
420
|
"""
|
421
|
-
``for-loop`` control flow with :py:class:`~.State` with a checkpointed version
|
421
|
+
``for-loop`` control flow with :py:class:`~.State` with a checkpointed version, similar to :py:func:`for_loop`.
|
422
422
|
|
423
423
|
Args:
|
424
424
|
f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
|
{brainstate-0.0.2.post20240826.dist-info → brainstate-0.0.2.post20240910.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.0.2.post20240826
|
3
|
+
Version: 0.0.2.post20240910
|
4
4
|
Summary: A State-based Transformation System for Brain Dynamics Programming.
|
5
5
|
Home-page: https://github.com/brainpy/brainstate
|
6
6
|
Author: BDP
|
@@ -78,16 +78,9 @@ pip install brainstate --upgrade
|
|
78
78
|
The official documentation is hosted on Read the Docs: [https://brainstate.readthedocs.io/](https://brainstate.readthedocs.io/)
|
79
79
|
|
80
80
|
|
81
|
-
## See also the BDP ecosystem
|
82
|
-
|
83
|
-
- [``brainstate``](https://github.com/brainpy/brainstate): A ``State``-based transformation system for brain dynamics programming.
|
84
|
-
|
85
|
-
- [``brainunit``](https://github.com/brainpy/brainunit): The unit system for brain dynamics programming.
|
86
81
|
|
87
|
-
|
88
|
-
|
89
|
-
- [``brainscale``](https://github.com/brainpy/brainscale): The scalable online learning framework for biological neural networks.
|
82
|
+
## See also the BDP ecosystem
|
90
83
|
|
91
|
-
|
84
|
+
We are building the BDP ecosystem: https://ecosystem-for-brain-dynamics.readthedocs.io/
|
92
85
|
|
93
86
|
|
{brainstate-0.0.2.post20240826.dist-info → brainstate-0.0.2.post20240910.dist-info}/RECORD
RENAMED
@@ -1,8 +1,8 @@
|
|
1
1
|
brainstate/__init__.py,sha256=zipNSih9Tyvi4-5cXqNPGsDF7VeestkLp-lcjJ4-dA0,1408
|
2
|
-
brainstate/_module.py,sha256=
|
2
|
+
brainstate/_module.py,sha256=ULfItiqiQoIK1YUYfkasmyh8Rj4PoYJP7cxyuphEnIo,52463
|
3
3
|
brainstate/_module_test.py,sha256=oQaoaZBTo1o3wHrMEJTInQCc7RdcVs1gcfQGvdSb1SI,7843
|
4
4
|
brainstate/_random_for_unit.py,sha256=1rHr7gfH_bYrJfpxbDhQUk_j00Yosx-GzyZCXrLxsd0,2007
|
5
|
-
brainstate/_state.py,sha256=
|
5
|
+
brainstate/_state.py,sha256=J5j5NujvqU3Ftd_m_u_3mz4xWw81mdYgdzltrdJSy8o,12162
|
6
6
|
brainstate/_state_test.py,sha256=HDdipndRLhEHWEdTmyT1ayEBkbv6qJKykfCWKI6yJ_E,1253
|
7
7
|
brainstate/_utils.py,sha256=RLorgGJkt2BhbX4C-ygd-PPG0wfcGCghjSP93sRvzqM,833
|
8
8
|
brainstate/environ.py,sha256=k0p1oyi9jbsPfuvqrPL-_zgSd7VW3LRs0LboxlaaIfc,11806
|
@@ -10,7 +10,7 @@ brainstate/mixin.py,sha256=OumTTSVyYSbtudjfS_MRThsBaeVJ_0JggeMClY7xtBA,10758
|
|
10
10
|
brainstate/mixin_test.py,sha256=-Ej9oUOu8O1M4oy37SVMj7xNRYhHHyAHwrjS_aISayo,2923
|
11
11
|
brainstate/random.py,sha256=pTZvTH06hv08_TpwzAWCqAjy-8oNGmB6-Jp6MKfkLaY,188087
|
12
12
|
brainstate/random_test.py,sha256=cCeuYvlZkCS2_RgG0vipZFNSHG8b-uJ7SXM9SZDCYQM,17866
|
13
|
-
brainstate/surrogate.py,sha256=
|
13
|
+
brainstate/surrogate.py,sha256=6AO79JOOs-X5x0FT0EDqO9lNtjJZAs26H4mljgpTvAw,45197
|
14
14
|
brainstate/typing.py,sha256=szCYee9R15YQfsEAQOx95_LqfrD9AYuE5dfTBTPd8sg,9165
|
15
15
|
brainstate/util.py,sha256=y-6eX1z3EMyg6pfZt4YdDalOnJ3HDAT1IPBCJDp-gQI,19876
|
16
16
|
brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
|
@@ -38,7 +38,7 @@ brainstate/nn/_rate_rnns.py,sha256=Cebhy57UWzfwrCfq0v2qLDegmb__mXL5ht750y4aTro,1
|
|
38
38
|
brainstate/nn/_readout.py,sha256=jsQwhVnrJICKw4wFq-Du2AORPb_XXz_tZ4cURcckU-E,4240
|
39
39
|
brainstate/nn/_synouts.py,sha256=gi3EyKlzt4UoyghwvNIr03r7YabZyl1idbq9aYG8zYM,4379
|
40
40
|
brainstate/nn/_projection/__init__.py,sha256=L6svNHTb8BDh2rdX2eYmcx_NdscSdKykkQbzpdCSkTA,1207
|
41
|
-
brainstate/nn/_projection/_align_post.py,sha256=
|
41
|
+
brainstate/nn/_projection/_align_post.py,sha256=S1huNBq3NkOfwrr7SXgTU6JvQk7KPVv86XwJ5iyvaBI,21106
|
42
42
|
brainstate/nn/_projection/_align_pre.py,sha256=_wjdj8muuv2_fSW9m3KBUVjNkBg28BUmz3qZ9IA1rUM,24597
|
43
43
|
brainstate/nn/_projection/_delta.py,sha256=KT8ySo3n_Q_7swzOH-ISDf0x9rjMkiv99H-vqeQZDR8,7122
|
44
44
|
brainstate/nn/_projection/_utils.py,sha256=UcmELOqsINgqJr7eC5BSNNteyZ--1lyGjhUTJfxyMmA,813
|
@@ -56,15 +56,15 @@ brainstate/transform/_error_if.py,sha256=0JThfFqt9B3K3H6mS84qecBS22yTi3-FPzviaYa
|
|
56
56
|
brainstate/transform/_error_if_test.py,sha256=kQZujlgr9bYnL-Vf7x4Zfc7jJk7rCLNVu-bsiry40dQ,1874
|
57
57
|
brainstate/transform/_jit.py,sha256=sjQHFV8Tt75fpdl12jjPRDPT92_IZxBBJAG4gapdbNQ,11471
|
58
58
|
brainstate/transform/_jit_test.py,sha256=5ltT7izh_OS9dcHnRymmVhq01QomjwZGdA8XzwJRLb4,2868
|
59
|
-
brainstate/transform/_loop_collect_return.py,sha256=
|
59
|
+
brainstate/transform/_loop_collect_return.py,sha256=8X6-3T3YoL_Buph9LiGASdrqPnRhsgsH9GQg1wcRos0,20800
|
60
60
|
brainstate/transform/_loop_no_collection.py,sha256=p2vHoNNesDH2cM7b5LgLzSv90M8iNQPkRZEl0jhf7yA,6476
|
61
61
|
brainstate/transform/_make_jaxpr.py,sha256=ZkrOZu4_0xcILuPUA3RFEkorJ-xbDuDtXorJI_qVThE,30450
|
62
62
|
brainstate/transform/_make_jaxpr_test.py,sha256=K3vRUBroDTCCx0lnmhgHtgrlWvWglJO2f1K2phTvU70,3819
|
63
63
|
brainstate/transform/_mapping.py,sha256=G9XUsD1xKLCprwwE0wv0gSXS0NYZ-ZIsv-PKKRlOoTA,3821
|
64
64
|
brainstate/transform/_progress_bar.py,sha256=VGoRZPRBmB8ELNwLc6c7S8QhUUTvn0FY46IbBm9cuYM,3502
|
65
65
|
brainstate/transform/_unvmap.py,sha256=8Se_23QrwDdcJpFcUnnMgD6EP-4XylbhP9K5TDhW358,3311
|
66
|
-
brainstate-0.0.2.post20240826.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
67
|
-
brainstate-0.0.2.
|
68
|
-
brainstate-0.0.2.post20240826.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
69
|
-
brainstate-0.0.2.post20240826.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
70
|
-
brainstate-0.0.2.post20240826.dist-info/RECORD,,
|
66
|
+
brainstate-0.0.2.post20240910.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
67
|
+
brainstate-0.0.2.post20240910.dist-info/METADATA,sha256=gAdKRqW3BdiBR-xvAOdqwKiex6eXywKX0bieOW2_ZZQ,3311
|
68
|
+
brainstate-0.0.2.post20240910.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
69
|
+
brainstate-0.0.2.post20240910.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
70
|
+
brainstate-0.0.2.post20240910.dist-info/RECORD,,
|
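{brainstate-0.0.2.post20240826.dist-info → brainstate-0.0.2.post20240910.dist-info}/LICENSE
RENAMED
File without changes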
|
{brainstate-0.0.2.post20240826.dist-info → brainstate-0.0.2.post20240910.dist-info}/WHEEL
RENAMED
File without changes
|
{brainstate-0.0.2.post20240826.dist-info → brainstate-0.0.2.post20240910.dist-info}/top_level.txt
RENAMED
File without changes
|