brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,696 @@
|
|
1
|
+
# Copyright 2024 BrainState Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""
|
16
|
+
Comprehensive tests for pretty_repr module.
|
17
|
+
"""
|
18
|
+
|
19
|
+
import dataclasses
|
20
|
+
import unittest
|
21
|
+
from typing import Iterator, Union
|
22
|
+
|
23
|
+
from brainstate.util._pretty_repr import (
|
24
|
+
PrettyType,
|
25
|
+
PrettyAttr,
|
26
|
+
PrettyRepr,
|
27
|
+
pretty_repr_elem,
|
28
|
+
pretty_repr_object,
|
29
|
+
MappingReprMixin,
|
30
|
+
PrettyMapping,
|
31
|
+
PrettyReprContext,
|
32
|
+
yield_unique_pretty_repr_items,
|
33
|
+
_default_repr_object,
|
34
|
+
_default_repr_attr,
|
35
|
+
)
|
36
|
+
|
37
|
+
|
38
|
+
class TestPrettyType(unittest.TestCase):
    """Test cases for the PrettyType dataclass."""

    def test_default_values(self):
        """Only ``type`` is required; every other field has a default."""
        pt = PrettyType(type='MyClass')
        self.assertEqual(pt.type, 'MyClass')
        self.assertEqual(pt.start, '(')
        self.assertEqual(pt.end, ')')
        self.assertEqual(pt.value_sep, '=')
        self.assertEqual(pt.elem_indent, ' ')
        self.assertEqual(pt.empty_repr, '')

    def test_custom_values(self):
        """Every optional field can be overridden at construction time."""
        # Each custom value must differ from its default; otherwise the test
        # could not detect a keyword argument that is silently ignored.
        pt = PrettyType(
            type=dict,
            start='{',
            end='}',
            value_sep=': ',
            elem_indent='\t',
            empty_repr='<empty>',
        )
        self.assertEqual(pt.type, dict)
        self.assertEqual(pt.start, '{')
        self.assertEqual(pt.end, '}')
        self.assertEqual(pt.value_sep, ': ')
        self.assertEqual(pt.elem_indent, '\t')
        self.assertEqual(pt.empty_repr, '<empty>')

        # Guard the fixture itself: all overrides really differ from defaults.
        default = PrettyType(type=dict)
        for field_name in ('start', 'end', 'value_sep', 'elem_indent', 'empty_repr'):
            self.assertNotEqual(
                getattr(pt, field_name), getattr(default, field_name), field_name
            )

    def test_type_can_be_string_or_class(self):
        """Test that type can be either a string or a class."""
        pt1 = PrettyType(type='StringType')
        self.assertIsInstance(pt1.type, str)
        self.assertEqual(pt1.type, 'StringType')

        pt2 = PrettyType(type=list)
        self.assertEqual(pt2.type, list)
|
75
|
+
|
76
|
+
|
77
|
+
class TestPrettyAttr(unittest.TestCase):
    """Construction tests for the PrettyAttr dataclass."""

    def test_default_values(self):
        """Only ``key`` and ``value`` are required; delimiters default to ''."""
        attr = PrettyAttr(key='name', value='test')
        self.assertEqual(
            (attr.key, attr.value, attr.start, attr.end),
            ('name', 'test', '', ''),
        )

    def test_custom_values(self):
        """Explicit ``start``/``end`` delimiters are kept as given."""
        attr = PrettyAttr(key='count', value=42, start='[', end=']')
        self.assertEqual(
            (attr.key, attr.value, attr.start, attr.end),
            ('count', 42, '[', ']'),
        )

    def test_value_types(self):
        """Values of any type compare equal to what was passed in."""
        samples = (
            ('str_value', 'string'),
            ('int_value', 123),
            ('list_value', [1, 2, 3]),
            ('dict_value', {'a': 1}),
        )
        for key, value in samples:
            self.assertEqual(PrettyAttr(key, value).value, value)
|
109
|
+
|
110
|
+
|
111
|
+
class SimplePrettyRepr(PrettyRepr):
    """Minimal concrete PrettyRepr used as a fixture by the tests below.

    Its pretty representation is a PrettyType header named ``name``,
    followed by a single ``value`` attribute.
    """

    def __init__(self, value, name='SimpleObject'):
        self.name = name
        self.value = value

    def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
        header = PrettyType(type=self.name)
        yield header
        yield PrettyAttr('value', self.value)
|
121
|
+
|
122
|
+
|
123
|
+
class TestPrettyRepr(unittest.TestCase):
    """Test cases for the PrettyRepr abstract class, rendered through repr()."""

    def test_simple_repr(self):
        """Test simple pretty representation."""
        obj = SimplePrettyRepr(42)
        result = repr(obj)
        self.assertIn('SimpleObject', result)
        self.assertIn('value=42', result)

    def test_custom_type_config(self):
        """Test PrettyRepr with custom type configuration."""

        class CustomRepr(PrettyRepr):
            def __init__(self, data):
                self.data = data

            def __pretty_repr__(self):
                yield PrettyType(type='CustomObject', start='<', end='>', value_sep=' -> ')
                yield PrettyAttr('data', self.data)

        obj = CustomRepr({'key': 'value'})
        result = repr(obj)
        self.assertIn('CustomObject<', result)
        self.assertIn('data -> ', result)
        # The separator ' -> ' already contains '>', so a bare membership
        # check for '>' could never fail. Verify that the custom ``end``
        # delimiter actually closes the rendering instead.
        self.assertTrue(result.endswith('>'), result)

    def test_multiple_attributes(self):
        """Test PrettyRepr with multiple attributes."""

        class MultiAttrRepr(PrettyRepr):
            def __init__(self, a, b, c):
                self.a = a
                self.b = b
                self.c = c

            def __pretty_repr__(self):
                yield PrettyType(type=self.__class__)
                yield PrettyAttr('a', self.a)
                yield PrettyAttr('b', self.b)
                yield PrettyAttr('c', self.c)

        obj = MultiAttrRepr(1, 'two', [3])
        result = repr(obj)
        self.assertIn('MultiAttrRepr', result)
        self.assertIn('a=1', result)
        self.assertIn("b=two", result)  # String value is not re-quoted
        self.assertIn('c=[3]', result)

    def test_empty_object(self):
        """Test PrettyRepr with no attributes."""

        class EmptyRepr(PrettyRepr):
            def __pretty_repr__(self):
                yield PrettyType(type='EmptyObject', empty_repr='<no data>')

        obj = EmptyRepr()
        result = repr(obj)
        self.assertIn('EmptyObject', result)
        self.assertIn('<no data>', result)
|
183
|
+
|
184
|
+
|
185
|
+
class TestPrettyReprElem(unittest.TestCase):
    """Test cases for the pretty_repr_elem function.

    Expected strings are built from ``pt.elem_indent`` rather than repeating
    the default indent literal; the default itself is pinned by
    ``TestPrettyType.test_default_values``.
    """

    def test_basic_elem(self):
        """Test basic element formatting."""
        pt = PrettyType(type='Test')
        elem = PrettyAttr('key', 'value')
        result = pretty_repr_elem(pt, elem)
        # Value is already a string, so it's not quoted again
        self.assertEqual(result, f"{pt.elem_indent}key=value")

    def test_elem_with_custom_indent(self):
        """Test element with custom indentation."""
        # Use an indent that differs from the default so the test actually
        # proves ``elem_indent`` is honoured.
        pt = PrettyType(type='Test', elem_indent='>>')
        elem = PrettyAttr('key', 123)
        result = pretty_repr_elem(pt, elem)
        self.assertEqual(result, ">>key=123")

    def test_elem_with_custom_separator(self):
        """Test element with custom value separator."""
        pt = PrettyType(type='Test', value_sep=': ')
        elem = PrettyAttr('key', 'value')
        result = pretty_repr_elem(pt, elem)
        self.assertEqual(result, f"{pt.elem_indent}key: value")

    def test_elem_with_start_end(self):
        """Test element with start and end markers."""
        pt = PrettyType(type='Test')
        elem = PrettyAttr('key', 'value', start='[', end=']')
        result = pretty_repr_elem(pt, elem)
        self.assertEqual(result, f"{pt.elem_indent}[key=value]")

    def test_elem_with_multiline_value(self):
        """Continuation lines of a multiline value are re-indented."""
        pt = PrettyType(type='Test')
        elem = PrettyAttr('key', 'line1\nline2\nline3')
        result = pretty_repr_elem(pt, elem)
        indent = pt.elem_indent
        expected = f"{indent}key=line1\n{indent}line2\n{indent}line3"
        self.assertEqual(result, expected)

    def test_elem_invalid_type(self):
        """Test that non-PrettyAttr raises TypeError."""
        pt = PrettyType(type='Test')
        with self.assertRaises(TypeError) as cm:
            pretty_repr_elem(pt, "not a PrettyAttr")
        self.assertIn("Item must be Elem", str(cm.exception))
|
231
|
+
|
232
|
+
|
233
|
+
class TestPrettyReprObject(unittest.TestCase):
    """Exercises pretty_repr_object on well-formed and malformed inputs."""

    def test_valid_object(self):
        """A well-formed PrettyRepr renders its type name and attribute."""
        rendered = pretty_repr_object(SimplePrettyRepr(42, 'TestObject'))
        for fragment in ('TestObject', 'value=42'):
            self.assertIn(fragment, rendered)

    def test_invalid_object(self):
        """Objects that are not PrettyRepr instances are rejected."""
        with self.assertRaises(TypeError) as ctx:
            pretty_repr_object("not a PrettyRepr")
        self.assertIn("is not representable", str(ctx.exception))

    def test_invalid_first_item(self):
        """The generator must start with a PrettyType header."""

        class InvalidRepr(PrettyRepr):
            def __pretty_repr__(self):
                # Deliberately omits the leading PrettyType header.
                yield PrettyAttr('key', 'value')

        broken = InvalidRepr()
        with self.assertRaises(TypeError) as ctx:
            pretty_repr_object(broken)
        self.assertIn("First item must be PrettyType", str(ctx.exception))

    def test_empty_representation(self):
        """With no attributes, ``empty_repr`` sits between the delimiters."""

        class EmptyRepr(PrettyRepr):
            def __pretty_repr__(self):
                yield PrettyType(type='Empty', empty_repr='∅')

        self.assertEqual(pretty_repr_object(EmptyRepr()), 'Empty(∅)')

    def test_complex_nested_formatting(self):
        """Custom delimiters, indent and separator all show up in the output."""

        class ComplexRepr(PrettyRepr):
            def __pretty_repr__(self):
                yield PrettyType(
                    type='Complex',
                    start='{\n',
                    end='\n}',
                    elem_indent=' ',
                    value_sep=' => ',
                )
                yield PrettyAttr('first', 'value1')
                yield PrettyAttr('second', {'nested': 'dict'})

        rendered = pretty_repr_object(ComplexRepr())
        for fragment in ('Complex', 'first => ', 'second => '):
            self.assertIn(fragment, rendered)
|
292
|
+
|
293
|
+
|
294
|
+
class TestMappingReprMixin(unittest.TestCase):
    """Checks the items produced by MappingReprMixin.__pretty_repr__."""

    def test_basic_mapping(self):
        """A two-key mapping yields a brace-style header plus one attr per key."""

        class MyMapping(dict, MappingReprMixin):
            pass

        # MappingReprMixin only supplies __pretty_repr__; the dict base class
        # provides the actual mapping behaviour.
        produced = list(MyMapping({'a': 1, 'b': 2}).__pretty_repr__())

        header = produced[0]
        self.assertIsInstance(header, PrettyType)
        self.assertEqual(header.value_sep, ': ')
        self.assertEqual(header.start, '{')
        self.assertEqual(header.end, '}')

        # Header followed by exactly one entry per key.
        self.assertEqual(len(produced), 3)

        attrs = [entry for entry in produced[1:] if isinstance(entry, PrettyAttr)]
        self.assertEqual(len(attrs), 2)

        # Keys are rendered with repr(), so string keys carry their quotes.
        rendered_keys = [attr.key for attr in attrs]
        for expected_key in ("'a'", "'b'"):
            self.assertIn(expected_key, rendered_keys)

    def test_empty_mapping(self):
        """An empty mapping yields only the PrettyType header."""

        class MyMapping(dict, MappingReprMixin):
            pass

        produced = list(MyMapping().__pretty_repr__())
        self.assertEqual(len(produced), 1)
|
335
|
+
|
336
|
+
|
337
|
+
class TestPrettyMapping(unittest.TestCase):
    """Rendering tests for the PrettyMapping wrapper."""

    def _assert_rendered(self, mapping, *fragments, **kwargs):
        """Render ``PrettyMapping(mapping, **kwargs)`` and check each fragment occurs."""
        text = repr(PrettyMapping(mapping, **kwargs))
        for fragment in fragments:
            self.assertIn(fragment, text)

    def test_basic_pretty_mapping(self):
        """Keys are repr'd and joined to values with ': '."""
        self._assert_rendered({'x': 10, 'y': 20}, "'x': 10", "'y': 20")

    def test_pretty_mapping_with_type_name(self):
        """A custom type name appears in the rendering."""
        self._assert_rendered({'a': 1}, 'MyDict', "'a': 1", type_name='MyDict')

    def test_empty_pretty_mapping(self):
        """An empty mapping still renders its braces."""
        self._assert_rendered({}, '{', '}')

    def test_nested_mapping(self):
        """Nested containers are rendered alongside scalar values."""
        self._assert_rendered(
            {
                'simple': 42,
                'nested': {'inner': 'value'},
                'list': [1, 2, 3],
            },
            "'simple': 42",
            "'nested':",
            "'list': [1, 2, 3]",
        )
|
372
|
+
|
373
|
+
|
374
|
+
class TestPrettyReprContext(unittest.TestCase):
    """Test cases for PrettyReprContext."""

    def test_initial_state(self):
        """A fresh context starts with no seen objects."""
        ctx = PrettyReprContext()
        self.assertIsNone(ctx.seen_modules_repr)

    def test_thread_local_behavior(self):
        """Test that context state is per-instance and per-thread."""
        import threading

        ctx1 = PrettyReprContext()
        ctx2 = PrettyReprContext()

        ctx1.seen_modules_repr = {'test': 1}
        # Separate instances do not share state.
        self.assertIsNone(ctx2.seen_modules_repr)

        # The property this test is named after: a value assigned in this
        # thread must not be visible from a different thread.
        seen_in_worker = []
        worker = threading.Thread(
            target=lambda: seen_in_worker.append(ctx1.seen_modules_repr)
        )
        worker.start()
        worker.join()
        self.assertEqual(seen_in_worker, [None])

        # The owning thread still sees its own value.
        self.assertEqual(ctx1.seen_modules_repr, {'test': 1})
|
389
|
+
|
390
|
+
|
391
|
+
class TestYieldUniquePrettyReprItems(unittest.TestCase):
    """Test cases for yield_unique_pretty_repr_items function."""

    def setUp(self):
        # These tests read and write the module-global CONTEXT. Reset it
        # before each test, and register a cleanup so it is reset again even
        # when an assertion fails mid-test. Without this, a failure could
        # leave stale "seen" entries behind and break later tests.
        from brainstate.util._pretty_repr import CONTEXT
        self.context = CONTEXT
        CONTEXT.seen_modules_repr = None
        self.addCleanup(setattr, CONTEXT, 'seen_modules_repr', None)

    def test_basic_usage(self):
        """Test basic usage with simple object."""

        @dataclasses.dataclass
        class SimpleObject:
            value: int

        obj = SimpleObject(42)
        items = list(yield_unique_pretty_repr_items(obj))

        # Should yield PrettyType first
        self.assertIsInstance(items[0], PrettyType)
        self.assertEqual(items[0].type, SimpleObject)

        # Should yield attribute
        attr_items = [item for item in items[1:] if isinstance(item, PrettyAttr)]
        self.assertTrue(any(item.key == 'value' for item in attr_items))

    def test_custom_repr_functions(self):
        """Test with custom repr functions."""

        def custom_repr_object(node):
            yield PrettyType(type='CustomType', start='<', end='>')

        def custom_repr_attr(node):
            yield PrettyAttr('custom_attr', 'custom_value')

        # A bare object() has no __dict__; the custom repr_attr means the
        # default attribute representer is never consulted.
        obj = object()
        items = list(yield_unique_pretty_repr_items(
            obj,
            repr_object=custom_repr_object,
            repr_attr=custom_repr_attr,
        ))

        self.assertIsInstance(items[0], PrettyType)
        self.assertEqual(items[0].type, 'CustomType')

        attr_items = [item for item in items[1:] if isinstance(item, PrettyAttr)]
        self.assertTrue(any(item.key == 'custom_attr' for item in attr_items))

    def test_circular_reference_handling(self):
        """An object already recorded in the active context is elided as '...'."""

        class Node:
            def __init__(self, value):
                self.value = value
                self.next = None

        # Create circular reference
        node1 = Node(1)
        node2 = Node(2)
        node1.next = node2
        node2.next = node1

        # Install an outer context so seen objects persist across both calls.
        # A call that creates the context itself clears it on exit.
        self.context.seen_modules_repr = {}

        # First pass - node1 is recorded as seen.
        list(yield_unique_pretty_repr_items(node1))

        # Second pass - should detect that node1 was already seen.
        items = list(yield_unique_pretty_repr_items(node1))

        type_items = [item for item in items if isinstance(item, PrettyType)]
        self.assertGreater(len(type_items), 0)
        self.assertTrue(any(item.empty_repr == '...' for item in type_items))

    def test_context_cleanup(self):
        """A top-level call resets the context to None once exhausted."""

        # Use a class instance that has __dict__ (the default attribute
        # representer reads attributes via vars()).
        class TestObj:
            def __init__(self):
                self.test = 'value'

        list(yield_unique_pretty_repr_items(TestObj()))

        # Context should be cleaned up after
        self.assertIsNone(self.context.seen_modules_repr)

    def test_nested_calls(self):
        """Test nested calls don't recreate context."""

        class Outer:
            def __init__(self):
                self.inner = Inner()

        class Inner:
            def __init__(self):
                self.value = 42

        # (context before nested call, context after nested call)
        observed = []

        def repr_outer_attr(node):
            # Trigger a nested yield_unique_pretty_repr_items call while the
            # outer call is active, and record the context around it.
            before = self.context.seen_modules_repr
            list(yield_unique_pretty_repr_items(node.inner))
            observed.append((before, self.context.seen_modules_repr))
            yield PrettyAttr('inner', node.inner)

        outer = Outer()

        # This should not error and should handle nested calls properly
        list(yield_unique_pretty_repr_items(outer, repr_attr=repr_outer_attr))

        # The nested call must reuse, not reset, the outer call's context.
        self.assertEqual(len(observed), 1)
        before, after = observed[0]
        self.assertIsNotNone(before)
        self.assertIs(after, before)

        # The outermost call still clears the context when it finishes.
        self.assertIsNone(self.context.seen_modules_repr)
|
520
|
+
|
521
|
+
|
522
|
+
class TestDefaultReprFunctions(unittest.TestCase):
    """Test cases for default repr functions."""

    def test_default_repr_object(self):
        """Test _default_repr_object function."""

        class MyClass:
            pass

        obj = MyClass()
        items = list(_default_repr_object(obj))

        self.assertEqual(len(items), 1)
        self.assertIsInstance(items[0], PrettyType)
        self.assertEqual(items[0].type, MyClass)

    def test_default_repr_attr(self):
        """Test _default_repr_attr function."""

        class MyClass:
            def __init__(self):
                self.public_attr = 'public'
                self._private_attr = 'private'
                # Private name mangling stores this as '_MyClass__dunder_attr'.
                self.__dunder_attr = 'dunder'
                self.number = 42
                self.list_attr = [1, 2, 3]

        obj = MyClass()
        # Guard the fixture: the dunder attribute really is stored mangled,
        # so the literal key '__dunder_attr' never exists in vars(obj).
        self.assertIn('_MyClass__dunder_attr', vars(obj))

        items = list(_default_repr_attr(obj))
        attr_keys = {item.key for item in items}

        # Should include public attributes
        self.assertIn('public_attr', attr_keys)
        self.assertIn('number', attr_keys)
        self.assertIn('list_attr', attr_keys)

        # Should exclude private attributes. Asserting that '__dunder_attr'
        # is absent would be vacuous, so check the mangled name instead.
        self.assertNotIn('_private_attr', attr_keys)
        self.assertNotIn('_MyClass__dunder_attr', attr_keys)
        self.assertEqual(attr_keys, {'public_attr', 'number', 'list_attr'})

        # Check values are repr'd
        public_item = next(item for item in items if item.key == 'public_attr')
        self.assertEqual(public_item.value, "'public'")

        number_item = next(item for item in items if item.key == 'number')
        self.assertEqual(number_item.value, '42')

    def test_default_repr_attr_no_vars(self):
        """Test _default_repr_attr with object that has no __dict__."""

        class NoVars:
            __slots__ = ('x', 'y')

            def __init__(self):
                self.x = 1
                self.y = 2

        obj = NoVars()
        # vars() will raise TypeError for objects without __dict__
        with self.assertRaises(TypeError):
            list(_default_repr_attr(obj))
|
583
|
+
|
584
|
+
|
585
|
+
class TestIntegration(unittest.TestCase):
    """Integration tests for the pretty_repr module."""

    def test_complex_nested_structure(self):
        """Test complex nested structure representation."""

        class Container(PrettyRepr):
            def __init__(self, name, children=None):
                self.name = name
                self.children = children or []

            def __pretty_repr__(self):
                yield PrettyType(type=self.__class__.__name__, start='[', end=']')
                yield PrettyAttr('name', self.name)
                if self.children:
                    yield PrettyAttr('children', self.children)

        # Create nested structure
        leaf1 = Container('leaf1')
        leaf2 = Container('leaf2')
        branch = Container('branch', [leaf1, leaf2])
        root = Container('root', [branch])

        result = repr(root)
        self.assertIn('Container', result)
        self.assertIn('name=', result)
        self.assertIn('root', result)
        self.assertIn('children=', result)
        # Children are rendered through repr() of the list, so every nested
        # level must show up, not just the root.
        for child_name in ('branch', 'leaf1', 'leaf2'):
            self.assertIn(f'name={child_name}', result)

    def test_mixed_types_representation(self):
        """Test representation with mixed types."""

        class MixedTypes(PrettyRepr):
            def __init__(self):
                self.string = "hello"
                self.number = 42
                self.float_num = 3.14
                self.bool_val = True
                self.none_val = None
                self.list_val = [1, 2, 3]
                self.dict_val = {'key': 'value'}
                self.tuple_val = (1, 2)
                self.set_val = {1, 2, 3}

            def __pretty_repr__(self):
                yield PrettyType(type='MixedTypes')
                for key, value in vars(self).items():
                    yield PrettyAttr(key, value)

        obj = MixedTypes()
        result = repr(obj)

        # Check all types are represented correctly
        # Note: string values passed to PrettyAttr are not re-quoted
        self.assertIn("string=hello", result)
        self.assertIn("number=42", result)
        self.assertIn("float_num=3.14", result)
        self.assertIn("bool_val=True", result)
        self.assertIn("none_val=None", result)
        self.assertIn("list_val=[1, 2, 3]", result)
        self.assertIn("dict_val={'key': 'value'}", result)
        # Previously unchecked; derive the expected text from repr() so the
        # assertion does not depend on set iteration order.
        self.assertIn(f"tuple_val={obj.tuple_val!r}", result)
        self.assertIn(f"set_val={obj.set_val!r}", result)

    def test_custom_formatting_styles(self):
        """Test various custom formatting styles."""

        class XMLStyle(PrettyRepr):
            def __init__(self, tag, content):
                self.tag = tag
                self.content = content

            def __pretty_repr__(self):
                yield PrettyType(
                    type='',
                    start=f'<{self.tag}>',
                    end=f'</{self.tag}>',
                    value_sep='',
                    elem_indent='',
                    empty_repr=''
                )
                yield PrettyAttr('', self.content)

        obj = XMLStyle('div', 'Hello World')
        result = repr(obj)
        self.assertIn('<div>', result)
        self.assertIn('</div>', result)
        self.assertIn('Hello World', result)

    def test_unicode_handling(self):
        """Test handling of unicode characters."""

        class UnicodeObj(PrettyRepr):
            def __init__(self):
                self.emoji = "🎉"
                self.chinese = "你好"
                self.special = "café"

            def __pretty_repr__(self):
                yield PrettyType(type='UnicodeObj')
                for key, value in vars(self).items():
                    yield PrettyAttr(key, value)

        obj = UnicodeObj()
        result = repr(obj)

        # Unicode should be preserved (string values are not re-quoted)
        self.assertIn("emoji=🎉", result)
        self.assertIn("chinese=你好", result)
        self.assertIn("special=café", result)


if __name__ == '__main__':
    unittest.main()
|