brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/nn/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -13,8 +13,9 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
-
17
- from . import metrics
16
+ from . import init
17
+ from ._activations import *
18
+ from ._activations import __all__ as activation_all
18
19
  from ._collective_ops import *
19
20
  from ._collective_ops import __all__ as collective_ops_all
20
21
  from ._common import *
@@ -31,105 +32,106 @@ from ._elementwise import *
31
32
  from ._elementwise import __all__ as elementwise_all
32
33
  from ._embedding import *
33
34
  from ._embedding import __all__ as embed_all
35
+ from ._event_fixedprob import *
36
+ from ._event_fixedprob import __all__ as fixedprob_all
37
+ from ._event_linear import *
38
+ from ._event_linear import __all__ as linear_mv_all
34
39
  from ._exp_euler import *
35
40
  from ._exp_euler import __all__ as exp_euler_all
36
- from ._fixedprob import *
37
- from ._fixedprob import __all__ as fixedprob_all
38
- from ._inputs import *
39
- from ._inputs import __all__ as inputs_all
40
41
  from ._linear import *
41
42
  from ._linear import __all__ as linear_all
42
- from ._linear_mv import *
43
- from ._linear_mv import __all__ as linear_mv_all
44
- from ._ltp import *
45
- from ._ltp import __all__ as ltp_all
43
+ from ._metrics import *
44
+ from ._metrics import __all__ as metrics_all
46
45
  from ._module import *
47
46
  from ._module import __all__ as module_all
48
- from ._neuron import *
49
- from ._neuron import __all__ as dyn_neuron_all
50
47
  from ._normalizations import *
51
48
  from ._normalizations import __all__ as normalizations_all
52
- from ._others import *
53
- from ._others import __all__ as _others_all
49
+ from ._paddings import *
50
+ from ._paddings import __all__ as paddings_all
54
51
  from ._poolings import *
55
52
  from ._poolings import __all__ as poolings_all
56
- from ._projection import *
57
- from ._projection import __all__ as projection_all
58
- from ._rate_rnns import *
59
- from ._rate_rnns import __all__ as rate_rnns
60
- from ._readout import *
61
- from ._readout import __all__ as readout_all
62
- from ._stp import *
63
- from ._stp import __all__ as stp_all
64
- from ._synapse import *
65
- from ._synapse import __all__ as dyn_synapse_all
66
- from ._synaptic_projection import *
67
- from ._synaptic_projection import __all__ as _syn_proj_all
68
- from ._synouts import *
69
- from ._synouts import __all__ as synouts_all
53
+ from ._rnns import *
54
+ from ._rnns import __all__ as rate_rnns
70
55
  from ._utils import *
71
56
  from ._utils import __all__ as utils_all
72
57
 
73
- __all__ = (
74
- [
75
- 'metrics',
76
- ]
77
- + collective_ops_all
78
- + common_all
79
- + elementwise_all
80
- + module_all
81
- + exp_euler_all
82
- + utils_all
83
- + dyn_all
84
- + projection_all
85
- + state_delay_all
86
- + synouts_all
87
- + conv_all
88
- + linear_all
89
- + normalizations_all
90
- + poolings_all
91
- + fixedprob_all
92
- + linear_mv_all
93
- + embed_all
94
- + dropout_all
95
- + elementwise_all
96
- + dyn_neuron_all
97
- + dyn_synapse_all
98
- + inputs_all
99
- + rate_rnns
100
- + readout_all
101
- + stp_all
102
- + ltp_all
103
- + _syn_proj_all
104
- + _others_all
105
- )
58
+ __all__ = ['init'] + activation_all + metrics_all
59
+ __all__ = __all__ + collective_ops_all + common_all + elementwise_all + module_all + exp_euler_all
60
+ __all__ = __all__ + utils_all + dyn_all + state_delay_all + conv_all
61
+ __all__ = __all__ + linear_all + normalizations_all + paddings_all + poolings_all + fixedprob_all + linear_mv_all
62
+ __all__ = __all__ + embed_all + dropout_all + elementwise_all
63
+ __all__ = __all__ + rate_rnns
106
64
 
107
65
  del (
66
+ metrics_all,
67
+ activation_all,
108
68
  collective_ops_all,
109
69
  common_all,
110
70
  module_all,
111
71
  exp_euler_all,
112
72
  utils_all,
113
73
  dyn_all,
114
- projection_all,
115
74
  state_delay_all,
116
- synouts_all,
117
75
  conv_all,
118
76
  linear_all,
119
77
  normalizations_all,
78
+ paddings_all,
120
79
  poolings_all,
121
80
  embed_all,
122
81
  fixedprob_all,
123
82
  linear_mv_all,
124
83
  dropout_all,
125
84
  elementwise_all,
126
- dyn_neuron_all,
127
- dyn_synapse_all,
128
- inputs_all,
129
- readout_all,
130
85
  rate_rnns,
131
- stp_all,
132
- ltp_all,
133
- _syn_proj_all,
134
- _others_all,
135
86
  )
87
+
88
+ # Deprecated names that redirect to brainpy
89
+ _DEPRECATED_NAMES = {
90
+ 'SpikeTime': 'brainpy.SpikeTime',
91
+ 'PoissonSpike': 'brainpy.PoissonSpike',
92
+ 'PoissonEncoder': 'brainpy.PoissonEncoder',
93
+ 'PoissonInput': 'brainpy.PoissonInput',
94
+ 'poisson_input': 'brainpy.poisson_input',
95
+ 'Neuron': 'brainpy.Neuron',
96
+ 'IF': 'brainpy.IF',
97
+ 'LIF': 'brainpy.LIF',
98
+ 'LIFRef': 'brainpy.LIFRef',
99
+ 'ALIF': 'brainpy.ALIF',
100
+ 'LeakyRateReadout': 'brainpy.LeakyRateReadout',
101
+ 'LeakySpikeReadout': 'brainpy.LeakySpikeReadout',
102
+ 'STP': 'brainpy.STP',
103
+ 'STD': 'brainpy.STD',
104
+ 'Synapse': 'brainpy.Synapse',
105
+ 'Expon': 'brainpy.Expon',
106
+ 'DualExpon': 'brainpy.DualExpon',
107
+ 'Alpha': 'brainpy.Alpha',
108
+ 'AMPA': 'brainpy.AMPA',
109
+ 'GABAa': 'brainpy.GABAa',
110
+ 'COBA': 'brainpy.COBA',
111
+ 'CUBA': 'brainpy.CUBA',
112
+ 'MgBlock': 'brainpy.MgBlock',
113
+ 'SynOut': 'brainpy.SynOut',
114
+ 'AlignPostProj': 'brainpy.AlignPostProj',
115
+ 'DeltaProj': 'brainpy.DeltaProj',
116
+ 'CurrentProj': 'brainpy.CurrentProj',
117
+ 'align_pre_projection': 'brainpy.align_pre_projection',
118
+ 'Projection': 'brainpy.Projection',
119
+ 'SymmetryGapJunction': 'brainpy.SymmetryGapJunction',
120
+ 'AsymmetryGapJunction': 'brainpy.AsymmetryGapJunction',
121
+ }
122
+
123
+
124
+ def __getattr__(name: str):
125
+ if name in _DEPRECATED_NAMES:
126
+ import warnings
127
+ new_name = _DEPRECATED_NAMES[name]
128
+ warnings.warn(
129
+ f"'brainstate.nn.{name}' is deprecated and will be removed in a future version. "
130
+ f"Please use '{new_name}' instead.",
131
+ DeprecationWarning,
132
+ stacklevel=2
133
+ )
134
+ # Import and return the actual brainpy object
135
+ import brainpy
136
+ return getattr(brainpy, name)
137
+ raise AttributeError(f"module 'brainstate.nn' has no attribute '{name}'")