brainstate 0.2.0__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. brainstate/__init__.py +2 -4
  2. brainstate/_deprecation_test.py +2 -24
  3. brainstate/_state.py +540 -35
  4. brainstate/_state_test.py +1085 -8
  5. brainstate/graph/_operation.py +1 -5
  6. brainstate/mixin.py +14 -0
  7. brainstate/nn/__init__.py +42 -33
  8. brainstate/nn/_collective_ops.py +2 -0
  9. brainstate/nn/_common_test.py +0 -20
  10. brainstate/nn/_delay.py +1 -1
  11. brainstate/nn/_dropout_test.py +9 -6
  12. brainstate/nn/_dynamics.py +67 -464
  13. brainstate/nn/_dynamics_test.py +0 -14
  14. brainstate/nn/_embedding.py +7 -7
  15. brainstate/nn/_exp_euler.py +9 -9
  16. brainstate/nn/_linear.py +21 -21
  17. brainstate/nn/_module.py +25 -18
  18. brainstate/nn/_normalizations.py +27 -27
  19. brainstate/random/__init__.py +6 -6
  20. brainstate/random/{_rand_funs.py → _fun.py} +1 -1
  21. brainstate/random/{_rand_funs_test.py → _fun_test.py} +0 -2
  22. brainstate/random/_impl.py +672 -0
  23. brainstate/random/{_rand_seed.py → _seed.py} +1 -1
  24. brainstate/random/{_rand_state.py → _state.py} +121 -418
  25. brainstate/random/{_rand_state_test.py → _state_test.py} +7 -7
  26. brainstate/transform/__init__.py +6 -9
  27. brainstate/transform/_conditions.py +2 -2
  28. brainstate/transform/_find_state.py +200 -0
  29. brainstate/transform/_find_state_test.py +84 -0
  30. brainstate/transform/_make_jaxpr.py +221 -61
  31. brainstate/transform/_make_jaxpr_test.py +125 -1
  32. brainstate/transform/_mapping.py +287 -209
  33. brainstate/transform/_mapping_test.py +94 -184
  34. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/METADATA +1 -1
  35. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/RECORD +39 -39
  36. brainstate/transform/_eval_shape.py +0 -145
  37. brainstate/transform/_eval_shape_test.py +0 -38
  38. brainstate/transform/_random.py +0 -171
  39. brainstate/random/{_rand_seed_test.py → _seed_test.py} +0 -0
  40. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  41. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +0 -0
  42. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/graph/_operation.py CHANGED
@@ -1262,11 +1262,7 @@ def treefy_states(
1262
1262
  if len(filters) == 0:
1263
1263
  return state_mapping
1264
1264
  else:
1265
- state_mappings = state_mapping.filter(*filters)
1266
- if len(filters) == 1:
1267
- return state_mappings[0]
1268
- else:
1269
- return state_mappings
1265
+ return state_mapping.filter(*filters)
1270
1266
 
1271
1267
 
1272
1268
  def _graph_update_dynamic(node: Any, state: Mapping) -> None:
brainstate/mixin.py CHANGED
@@ -380,6 +380,20 @@ class ParamDescriber(metaclass=NoSubclassMeta):
380
380
  merged_kwargs = {**self.kwargs, **kwargs}
381
381
  return self.cls(*self.args, *args, **merged_kwargs)
382
382
 
383
+ def __repr__(self):
384
+ """
385
+ Return a string representation of the ParamDescriber.
386
+
387
+ Returns
388
+ -------
389
+ str
390
+ A string showing the class and stored parameters.
391
+ """
392
+ args_str = ', '.join(repr(a) for a in self.args)
393
+ kwargs_str = ', '.join(f'{k}={v!r}' for k, v in self.kwargs.items())
394
+ all_params = ', '.join(filter(None, [args_str, kwargs_str]))
395
+ return f'ParamDescriber({self.cls.__name__}({all_params}))'
396
+
383
397
  def init(self, *args, **kwargs):
384
398
  """
385
399
  Alias for __call__, explicitly named for clarity.
brainstate/nn/__init__.py CHANGED
@@ -87,43 +87,52 @@ del (
87
87
 
88
88
  # Deprecated names that redirect to brainpy
89
89
  _DEPRECATED_NAMES = {
90
- 'SpikeTime': 'brainpy.SpikeTime',
91
- 'PoissonSpike': 'brainpy.PoissonSpike',
92
- 'PoissonEncoder': 'brainpy.PoissonEncoder',
93
- 'PoissonInput': 'brainpy.PoissonInput',
94
- 'poisson_input': 'brainpy.poisson_input',
95
- 'Neuron': 'brainpy.Neuron',
96
- 'IF': 'brainpy.IF',
97
- 'LIF': 'brainpy.LIF',
98
- 'LIFRef': 'brainpy.LIFRef',
99
- 'ALIF': 'brainpy.ALIF',
100
- 'LeakyRateReadout': 'brainpy.LeakyRateReadout',
101
- 'LeakySpikeReadout': 'brainpy.LeakySpikeReadout',
102
- 'STP': 'brainpy.STP',
103
- 'STD': 'brainpy.STD',
104
- 'Synapse': 'brainpy.Synapse',
105
- 'Expon': 'brainpy.Expon',
106
- 'DualExpon': 'brainpy.DualExpon',
107
- 'Alpha': 'brainpy.Alpha',
108
- 'AMPA': 'brainpy.AMPA',
109
- 'GABAa': 'brainpy.GABAa',
110
- 'COBA': 'brainpy.COBA',
111
- 'CUBA': 'brainpy.CUBA',
112
- 'MgBlock': 'brainpy.MgBlock',
113
- 'SynOut': 'brainpy.SynOut',
114
- 'AlignPostProj': 'brainpy.AlignPostProj',
115
- 'DeltaProj': 'brainpy.DeltaProj',
116
- 'CurrentProj': 'brainpy.CurrentProj',
117
- 'align_pre_projection': 'brainpy.align_pre_projection',
118
- 'Projection': 'brainpy.Projection',
119
- 'SymmetryGapJunction': 'brainpy.SymmetryGapJunction',
120
- 'AsymmetryGapJunction': 'brainpy.AsymmetryGapJunction',
90
+ 'SpikeTime': 'brainpy.state.SpikeTime',
91
+ 'PoissonSpike': 'brainpy.state.PoissonSpike',
92
+ 'PoissonEncoder': 'brainpy.state.PoissonEncoder',
93
+ 'PoissonInput': 'brainpy.state.PoissonInput',
94
+ 'poisson_input': 'brainpy.state.poisson_input',
95
+ 'Neuron': 'brainpy.state.Neuron',
96
+ 'IF': 'brainpy.state.IF',
97
+ 'LIF': 'brainpy.state.LIF',
98
+ 'LIFRef': 'brainpy.state.LIFRef',
99
+ 'ALIF': 'brainpy.state.ALIF',
100
+ 'LeakyRateReadout': 'brainpy.state.LeakyRateReadout',
101
+ 'LeakySpikeReadout': 'brainpy.state.LeakySpikeReadout',
102
+ 'STP': 'brainpy.state.STP',
103
+ 'STD': 'brainpy.state.STD',
104
+ 'Synapse': 'brainpy.state.Synapse',
105
+ 'Expon': 'brainpy.state.Expon',
106
+ 'DualExpon': 'brainpy.state.DualExpon',
107
+ 'Alpha': 'brainpy.state.Alpha',
108
+ 'AMPA': 'brainpy.state.AMPA',
109
+ 'GABAa': 'brainpy.state.GABAa',
110
+ 'COBA': 'brainpy.state.COBA',
111
+ 'CUBA': 'brainpy.state.CUBA',
112
+ 'MgBlock': 'brainpy.state.MgBlock',
113
+ 'SynOut': 'brainpy.state.SynOut',
114
+ 'AlignPostProj': 'brainpy.state.AlignPostProj',
115
+ 'DeltaProj': 'brainpy.state.DeltaProj',
116
+ 'CurrentProj': 'brainpy.state.CurrentProj',
117
+ 'align_pre_projection': 'brainpy.state.align_pre_projection',
118
+ 'Projection': 'brainpy.state.Projection',
119
+ 'SymmetryGapJunction': 'brainpy.state.SymmetryGapJunction',
120
+ 'AsymmetryGapJunction': 'brainpy.state.AsymmetryGapJunction',
121
121
  }
122
122
 
123
123
 
124
124
  def __getattr__(name: str):
125
+ import warnings
126
+ if name == 'DynamicsGroup':
127
+ warnings.warn(
128
+ f"'brainstate.nn.{name}' is deprecated and will be removed in a future version. "
129
+ f"Please use 'brainstate.nn.Module' instead.",
130
+ DeprecationWarning,
131
+ stacklevel=2
132
+ )
133
+ return Module
134
+
125
135
  if name in _DEPRECATED_NAMES:
126
- import warnings
127
136
  new_name = _DEPRECATED_NAMES[name]
128
137
  warnings.warn(
129
138
  f"'brainstate.nn.{name}' is deprecated and will be removed in a future version. "
@@ -133,5 +142,5 @@ def __getattr__(name: str):
133
142
  )
134
143
  # Import and return the actual brainpy object
135
144
  import brainpy
136
- return getattr(brainpy, name)
145
+ return getattr(brainpy.state, name)
137
146
  raise AttributeError(f"module 'brainstate.nn' has no attribute '{name}'")
brainstate/nn/_collective_ops.py CHANGED
@@ -12,6 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+
16
+
15
17
  import warnings
16
18
  from collections.abc import Sequence, Mapping
17
19
  from typing import Callable, TypeVar, Any
brainstate/nn/_common_test.py CHANGED
@@ -132,23 +132,3 @@ class TestFilterStates(unittest.TestCase):
132
132
 
133
133
  self.mock_module.states.assert_called_once_with(filter_obj)
134
134
  self.assertEqual(result, ['test1', 'test2'])
135
-
136
- def test_filter_states_dict_filters(self):
137
- """Test _filter_states with dictionary of filters.
138
-
139
- Note: Current implementation expects dict to be iterable as tuples,
140
- which suggests it's meant to be passed as a dict that yields tuples when iterated.
141
- This is likely a bug - should use filters.items().
142
- """
143
- # Skip this test as the current implementation has a bug
144
- self.skipTest("Current implementation has a bug in dict iteration")
145
-
146
- def test_filter_states_dict_invalid_axis(self):
147
- """Test _filter_states with non-integer axis in dictionary."""
148
- # Skip this test as the current implementation has a bug in dict iteration
149
- self.skipTest("Current implementation has a bug in dict iteration")
150
-
151
- def test_filter_states_dict_multiple_filters_same_axis(self):
152
- """Test _filter_states with multiple filters for the same axis."""
153
- # Skip this test as the current implementation has a bug in dict iteration
154
- self.skipTest("Current implementation has a bug in dict iteration")
brainstate/nn/_delay.py CHANGED
@@ -285,7 +285,7 @@ class Delay(Module):
285
285
  Returns:
286
286
  DelayAccess: An object that provides access to the delay data for the specified entry and time.
287
287
  """
288
- return DelayAccess(self, delay_time, entry=entry)
288
+ return DelayAccess(self, *delay_time, entry=entry)
289
289
 
290
290
  def at(self, entry: str) -> ArrayLike:
291
291
  """
brainstate/nn/_dropout_test.py CHANGED
@@ -146,16 +146,19 @@ class TestDropout1d(parameterized.TestCase):
146
146
 
147
147
 
148
148
  class TestDropout2d(parameterized.TestCase):
149
+ def setUp(self):
150
+ brainstate.random.seed(0)
149
151
 
150
152
  def test_dropout2d_basic(self):
151
153
  """Test basic Dropout2d functionality."""
152
- dropout_layer = brainstate.nn.Dropout2d(prob=0.5)
153
- input_data = brainstate.random.randn(2, 3, 4, 5) # (N, C, H, W)
154
+ with brainstate.random.seed_context(42):
155
+ dropout_layer = brainstate.nn.Dropout2d(prob=0.5)
156
+ input_data = brainstate.random.randn(2, 3, 4, 5) # (N, C, H, W)
154
157
 
155
- with brainstate.environ.context(fit=True):
156
- output_data = dropout_layer(input_data)
157
- self.assertEqual(input_data.shape, output_data.shape)
158
- self.assertTrue(np.any(output_data == 0))
158
+ with brainstate.environ.context(fit=True):
159
+ output_data = dropout_layer(input_data)
160
+ self.assertEqual(input_data.shape, output_data.shape)
161
+ self.assertTrue(np.any(output_data == 0))
159
162
 
160
163
  def test_dropout2d_channel_wise(self):
161
164
  """Test that Dropout2d applies dropout."""