brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,696 @@
1
+ # Copyright 2024 BrainState Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Comprehensive tests for pretty_repr module.
17
+ """
18
+
19
+ import dataclasses
20
+ import unittest
21
+ from typing import Iterator, Union
22
+
23
+ from brainstate.util._pretty_repr import (
24
+ PrettyType,
25
+ PrettyAttr,
26
+ PrettyRepr,
27
+ pretty_repr_elem,
28
+ pretty_repr_object,
29
+ MappingReprMixin,
30
+ PrettyMapping,
31
+ PrettyReprContext,
32
+ yield_unique_pretty_repr_items,
33
+ _default_repr_object,
34
+ _default_repr_attr,
35
+ )
36
+
37
+
38
class TestPrettyType(unittest.TestCase):
    """Unit tests for the ``PrettyType`` header dataclass."""

    def test_default_values(self):
        """Only ``type`` is required; every other field falls back to its default."""
        header = PrettyType(type='MyClass')
        expected = {
            'type': 'MyClass',
            'start': '(',
            'end': ')',
            'value_sep': '=',
            'elem_indent': ' ',
            'empty_repr': '',
        }
        for field_name, want in expected.items():
            self.assertEqual(getattr(header, field_name), want)

    def test_custom_values(self):
        """Every field supplied explicitly is stored verbatim."""
        overrides = {
            'type': dict,
            'start': '{',
            'end': '}',
            'value_sep': ': ',
            'elem_indent': ' ',
            'empty_repr': '<empty>',
        }
        header = PrettyType(**overrides)
        for field_name, want in overrides.items():
            self.assertEqual(getattr(header, field_name), want)

    def test_type_can_be_string_or_class(self):
        """``type`` accepts either a plain name string or a class object."""
        by_name = PrettyType(type='StringType')
        self.assertIsInstance(by_name.type, str)

        by_class = PrettyType(type=list)
        self.assertEqual(by_class.type, list)
75
+
76
+
77
class TestPrettyAttr(unittest.TestCase):
    """Unit tests for the ``PrettyAttr`` key/value dataclass."""

    def test_default_values(self):
        """``start`` and ``end`` default to empty strings."""
        attr = PrettyAttr(key='name', value='test')
        self.assertEqual(attr.key, 'name')
        self.assertEqual(attr.value, 'test')
        for marker in (attr.start, attr.end):
            self.assertEqual(marker, '')

    def test_custom_values(self):
        """Explicit ``start``/``end`` markers are stored alongside key and value."""
        attr = PrettyAttr(key='count', value=42, start='[', end=']')
        self.assertEqual(
            (attr.key, attr.value, attr.start, attr.end),
            ('count', 42, '[', ']'),
        )

    def test_value_types(self):
        """``value`` is kept as-is regardless of its Python type."""
        cases = [
            ('str_value', 'string'),
            ('int_value', 123),
            ('list_value', [1, 2, 3]),
            ('dict_value', {'a': 1}),
        ]
        for key, payload in cases:
            self.assertEqual(PrettyAttr(key, payload).value, payload)
109
+
110
+
111
class SimplePrettyRepr(PrettyRepr):
    """Minimal ``PrettyRepr`` fixture shared by several test cases.

    Its pretty-repr stream is a ``PrettyType`` header named after ``name``
    followed by a single ``value`` attribute.
    """

    def __init__(self, value, name='SimpleObject'):
        self.name = name
        self.value = value

    def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
        header = PrettyType(type=self.name)
        yield header
        yield PrettyAttr('value', self.value)
121
+
122
+
123
class TestPrettyRepr(unittest.TestCase):
    """Test cases for PrettyRepr abstract class."""

    def test_simple_repr(self):
        """repr() of a PrettyRepr shows its type name and its attribute."""
        obj = SimplePrettyRepr(42)
        result = repr(obj)
        self.assertIn('SimpleObject', result)
        self.assertIn('value=42', result)

    def test_custom_type_config(self):
        """Custom start/end markers and value separator appear in the output."""

        class CustomRepr(PrettyRepr):
            def __init__(self, data):
                self.data = data

            def __pretty_repr__(self):
                yield PrettyType(type='CustomObject', start='<', end='>', value_sep=' -> ')
                yield PrettyAttr('data', self.data)

        obj = CustomRepr({'key': 'value'})
        result = repr(obj)
        self.assertIn('CustomObject<', result)
        self.assertIn('data -> ', result)
        # The separator ' -> ' already contains '>', so a bare
        # assertIn('>', result) could never fail. Check instead that the
        # closing marker actually terminates the output.
        self.assertTrue(result.endswith('>'), msg=result)

    def test_multiple_attributes(self):
        """Every yielded attribute is rendered; a class object works as ``type``."""

        class MultiAttrRepr(PrettyRepr):
            def __init__(self, a, b, c):
                self.a = a
                self.b = b
                self.c = c

            def __pretty_repr__(self):
                yield PrettyType(type=self.__class__)
                yield PrettyAttr('a', self.a)
                yield PrettyAttr('b', self.b)
                yield PrettyAttr('c', self.c)

        obj = MultiAttrRepr(1, 'two', [3])
        result = repr(obj)
        self.assertIn('MultiAttrRepr', result)
        self.assertIn('a=1', result)
        self.assertIn("b=two", result)  # String value is not re-quoted
        self.assertIn('c=[3]', result)

    def test_empty_object(self):
        """With no attributes, ``empty_repr`` is used in the output."""

        class EmptyRepr(PrettyRepr):
            def __pretty_repr__(self):
                yield PrettyType(type='EmptyObject', empty_repr='<no data>')

        obj = EmptyRepr()
        result = repr(obj)
        self.assertIn('EmptyObject', result)
        self.assertIn('<no data>', result)
183
+
184
+
185
class TestPrettyReprElem(unittest.TestCase):
    """Test cases for pretty_repr_elem function."""

    def test_basic_elem(self):
        """Test basic element formatting."""
        pt = PrettyType(type='Test')
        elem = PrettyAttr('key', 'value')
        result = pretty_repr_elem(pt, elem)
        # Value is already a string, so it's not quoted again
        self.assertEqual(result, " key=value")

    def test_elem_with_custom_indent(self):
        """Test element with custom indentation."""
        # Use a visible, non-default indent. An indent equal to the default
        # would let this test pass even if ``elem_indent`` were ignored.
        pt = PrettyType(type='Test', elem_indent='--')
        elem = PrettyAttr('key', 123)
        result = pretty_repr_elem(pt, elem)
        self.assertEqual(result, "--key=123")

    def test_elem_with_custom_separator(self):
        """Test element with custom value separator."""
        pt = PrettyType(type='Test', value_sep=': ')
        elem = PrettyAttr('key', 'value')
        result = pretty_repr_elem(pt, elem)
        self.assertEqual(result, " key: value")

    def test_elem_with_start_end(self):
        """Test element with start and end markers."""
        pt = PrettyType(type='Test')
        elem = PrettyAttr('key', 'value', start='[', end=']')
        result = pretty_repr_elem(pt, elem)
        self.assertEqual(result, " [key=value]")

    def test_elem_with_multiline_value(self):
        """Continuation lines of a multiline value are re-indented."""
        pt = PrettyType(type='Test')
        elem = PrettyAttr('key', 'line1\nline2\nline3')
        result = pretty_repr_elem(pt, elem)
        expected = " key=line1\n line2\n line3"
        self.assertEqual(result, expected)

    def test_elem_invalid_type(self):
        """Test that non-PrettyAttr raises TypeError."""
        pt = PrettyType(type='Test')
        with self.assertRaises(TypeError) as cm:
            pretty_repr_elem(pt, "not a PrettyAttr")
        self.assertIn("Item must be Elem", str(cm.exception))
231
+
232
+
233
class TestPrettyReprObject(unittest.TestCase):
    """Tests for the ``pretty_repr_object`` renderer."""

    def test_valid_object(self):
        """A well-formed ``PrettyRepr`` renders its type name and attribute."""
        text = pretty_repr_object(SimplePrettyRepr(42, 'TestObject'))
        for fragment in ('TestObject', 'value=42'):
            self.assertIn(fragment, text)

    def test_invalid_object(self):
        """Arguments that are not ``PrettyRepr`` instances are rejected."""
        with self.assertRaises(TypeError) as raised:
            pretty_repr_object("not a PrettyRepr")
        self.assertIn("is not representable", str(raised.exception))

    def test_invalid_first_item(self):
        """The pretty-repr stream must open with a ``PrettyType`` header."""

        class InvalidRepr(PrettyRepr):
            def __pretty_repr__(self):
                # Deliberately skips the leading PrettyType header.
                yield PrettyAttr('key', 'value')

        offender = InvalidRepr()
        with self.assertRaises(TypeError) as raised:
            pretty_repr_object(offender)
        self.assertIn("First item must be PrettyType", str(raised.exception))

    def test_empty_representation(self):
        """With no attributes, ``empty_repr`` sits between the start and end markers."""

        class EmptyRepr(PrettyRepr):
            def __pretty_repr__(self):
                yield PrettyType(type='Empty', empty_repr='∅')

        self.assertEqual(pretty_repr_object(EmptyRepr()), 'Empty(∅)')

    def test_complex_nested_formatting(self):
        """Multi-line markers and a custom separator combine correctly."""

        class ComplexRepr(PrettyRepr):
            def __pretty_repr__(self):
                yield PrettyType(
                    type='Complex',
                    start='{\n',
                    end='\n}',
                    elem_indent=' ',
                    value_sep=' => ',
                )
                yield PrettyAttr('first', 'value1')
                yield PrettyAttr('second', {'nested': 'dict'})

        text = pretty_repr_object(ComplexRepr())
        for fragment in ('Complex', 'first => ', 'second => '):
            self.assertIn(fragment, text)
292
+
293
+
294
class TestMappingReprMixin(unittest.TestCase):
    """Tests for the ``MappingReprMixin`` helper."""

    @staticmethod
    def _make_mapping_cls():
        # The mixin only contributes __pretty_repr__; pairing it with dict
        # supplies the mapping contents it renders.
        class MyMapping(dict, MappingReprMixin):
            pass

        return MyMapping

    def test_basic_mapping(self):
        """Mixed into a dict, the mixin yields a brace header plus one attr per key."""
        mapping = self._make_mapping_cls()({'a': 1, 'b': 2})
        items = list(mapping.__pretty_repr__())

        header = items[0]
        self.assertIsInstance(header, PrettyType)
        self.assertEqual(header.value_sep, ': ')
        self.assertEqual(header.start, '{')
        self.assertEqual(header.end, '}')

        # One header followed by two attributes.
        self.assertEqual(len(items), 3)

        attrs = [entry for entry in items[1:] if isinstance(entry, PrettyAttr)]
        self.assertEqual(len(attrs), 2)

        # Keys are rendered with repr(), so string keys keep their quotes.
        rendered_keys = [entry.key for entry in attrs]
        for expected_key in ("'a'", "'b'"):
            self.assertIn(expected_key, rendered_keys)

    def test_empty_mapping(self):
        """An empty mapping yields only the PrettyType header."""
        mapping = self._make_mapping_cls()()
        self.assertEqual(len(list(mapping.__pretty_repr__())), 1)
335
+
336
+
337
class TestPrettyMapping(unittest.TestCase):
    """Tests for the ``PrettyMapping`` wrapper."""

    def _assert_contains_all(self, text, fragments):
        # Check each expected fragment in order against the rendered text.
        for fragment in fragments:
            self.assertIn(fragment, text)

    def test_basic_pretty_mapping(self):
        """Each entry is rendered as ``repr(key): value``."""
        text = repr(PrettyMapping({'x': 10, 'y': 20}))
        self._assert_contains_all(text, ["'x': 10", "'y': 20"])

    def test_pretty_mapping_with_type_name(self):
        """A ``type_name`` override is shown in the output."""
        text = repr(PrettyMapping({'a': 1}, type_name='MyDict'))
        self._assert_contains_all(text, ['MyDict', "'a': 1"])

    def test_empty_pretty_mapping(self):
        """An empty mapping still renders its braces."""
        text = repr(PrettyMapping({}))
        self._assert_contains_all(text, ['{', '}'])

    def test_nested_mapping(self):
        """Nested dicts and lists are rendered alongside scalar entries."""
        source = {
            'simple': 42,
            'nested': {'inner': 'value'},
            'list': [1, 2, 3],
        }
        text = repr(PrettyMapping(source))
        self._assert_contains_all(
            text, ["'simple': 42", "'nested':", "'list': [1, 2, 3]"]
        )
372
+
373
+
374
class TestPrettyReprContext(unittest.TestCase):
    """Test cases for PrettyReprContext."""

    def test_initial_state(self):
        """A freshly created context has no seen-objects registry."""
        ctx = PrettyReprContext()
        self.assertIsNone(ctx.seen_modules_repr)

    def test_thread_local_behavior(self):
        """Test that context is thread-local (and per-instance)."""
        import threading

        ctx1 = PrettyReprContext()
        ctx2 = PrettyReprContext()

        # Distinct instances never share state.
        ctx1.seen_modules_repr = {'test': 1}
        self.assertIsNone(ctx2.seen_modules_repr)

        # The same instance read from another thread must not see the value
        # assigned in this thread. The previous version never used a second
        # thread, so it did not exercise thread-locality at all.
        observed = []
        worker = threading.Thread(
            target=lambda: observed.append(ctx1.seen_modules_repr)
        )
        worker.start()
        worker.join()
        self.assertEqual(observed, [None])

        # The value assigned in this thread is left untouched.
        self.assertEqual(ctx1.seen_modules_repr, {'test': 1})
389
+
390
+
391
class TestYieldUniquePrettyReprItems(unittest.TestCase):
    """Test cases for yield_unique_pretty_repr_items function."""

    def setUp(self):
        # These tests read and mutate the module-level CONTEXT object, so
        # start every test from a clean registry...
        from brainstate.util._pretty_repr import CONTEXT
        self._context = CONTEXT
        self._context.seen_modules_repr = None

    def tearDown(self):
        # ...and reset it even when a test fails part-way, so a leaked
        # registry cannot change how later tests render objects.
        self._context.seen_modules_repr = None

    def test_basic_usage(self):
        """Test basic usage with simple object."""

        @dataclasses.dataclass
        class SimpleObject:
            value: int

        obj = SimpleObject(42)
        items = list(yield_unique_pretty_repr_items(obj))

        # Should yield PrettyType first
        self.assertIsInstance(items[0], PrettyType)
        self.assertEqual(items[0].type, SimpleObject)

        # Should yield attribute
        attr_items = [item for item in items[1:] if isinstance(item, PrettyAttr)]
        self.assertTrue(any(item.key == 'value' for item in attr_items))

    def test_custom_repr_functions(self):
        """Test with custom repr functions."""

        def custom_repr_object(node):
            yield PrettyType(type='CustomType', start='<', end='>')

        def custom_repr_attr(node):
            yield PrettyAttr('custom_attr', 'custom_value')

        obj = object()
        items = list(yield_unique_pretty_repr_items(
            obj,
            repr_object=custom_repr_object,
            repr_attr=custom_repr_attr,
        ))

        self.assertIsInstance(items[0], PrettyType)
        self.assertEqual(items[0].type, 'CustomType')

        attr_items = [item for item in items[1:] if isinstance(item, PrettyAttr)]
        self.assertTrue(any(item.key == 'custom_attr' for item in attr_items))

    def test_circular_reference_handling(self):
        """An object already recorded in the context is flagged with '...'."""

        class Node:
            def __init__(self, value):
                self.value = value
                self.next = None

        # Create circular reference
        node1 = Node(1)
        node2 = Node(2)
        node1.next = node2
        node2.next = node1

        # Pre-create the registry so it is shared by both calls below. When
        # the registry starts as None, a call cleans it up afterwards (see
        # test_context_cleanup).
        self._context.seen_modules_repr = {}

        # First pass - node1 will be added to seen
        first_items = list(yield_unique_pretty_repr_items(node1))
        self.assertIsInstance(first_items[0], PrettyType)

        # Second pass - should detect node1 is already seen
        second_items = list(yield_unique_pretty_repr_items(node1))
        type_items = [item for item in second_items if isinstance(item, PrettyType)]
        self.assertTrue(type_items)
        self.assertTrue(any(item.empty_repr == '...' for item in type_items))

    def test_context_cleanup(self):
        """Test that context is properly cleaned up."""

        # Use a class instance that has __dict__
        class TestObj:
            def __init__(self):
                self.test = 'value'

        obj = TestObj()
        list(yield_unique_pretty_repr_items(obj))

        # The call started from an empty context, so it must clear it after.
        self.assertIsNone(self._context.seen_modules_repr)

    def test_nested_calls(self):
        """Test nested calls don't recreate context."""

        class Outer:
            def __init__(self):
                self.inner = Inner()

        class Inner:
            def __init__(self):
                self.value = 42

        def repr_outer_attr(node):
            # Consume a nested yield_unique_pretty_repr_items while the outer
            # call's context is still active.
            list(yield_unique_pretty_repr_items(node.inner))
            yield PrettyAttr('inner', node.inner)

        outer = Outer()

        # This should not error and should handle nested calls properly
        items = list(yield_unique_pretty_repr_items(
            outer,
            repr_attr=repr_outer_attr,
        ))
        self.assertTrue(any(
            isinstance(item, PrettyAttr) and item.key == 'inner' for item in items
        ))

        # Clean up should happen once the outermost call finishes.
        self.assertIsNone(self._context.seen_modules_repr)
520
+
521
+
522
class TestDefaultReprFunctions(unittest.TestCase):
    """Test cases for default repr functions."""

    def test_default_repr_object(self):
        """_default_repr_object yields a single PrettyType for the object's class."""

        class MyClass:
            pass

        obj = MyClass()
        items = list(_default_repr_object(obj))

        self.assertEqual(len(items), 1)
        self.assertIsInstance(items[0], PrettyType)
        self.assertEqual(items[0].type, MyClass)

    def test_default_repr_attr(self):
        """_default_repr_attr yields repr'd public attributes only."""

        class MyClass:
            def __init__(self):
                self.public_attr = 'public'
                self._private_attr = 'private'
                # Name-mangled by Python to ``_MyClass__dunder_attr``.
                self.__dunder_attr = 'dunder'
                self.number = 42
                self.list_attr = [1, 2, 3]

        obj = MyClass()
        # Sanity-check the premise: the instance dict holds the mangled name.
        self.assertIn('_MyClass__dunder_attr', vars(obj))

        items = list(_default_repr_attr(obj))

        # Should include public attributes
        attr_keys = {item.key for item in items}
        self.assertIn('public_attr', attr_keys)
        self.assertIn('number', attr_keys)
        self.assertIn('list_attr', attr_keys)

        # Should exclude private attributes. The dunder attribute is stored
        # under its mangled name, so checking only '__dunder_attr' could
        # never fail; check the key that actually exists.
        self.assertNotIn('_private_attr', attr_keys)
        self.assertNotIn('_MyClass__dunder_attr', attr_keys)
        self.assertNotIn('__dunder_attr', attr_keys)

        # Check values are repr'd
        public_item = next(item for item in items if item.key == 'public_attr')
        self.assertEqual(public_item.value, "'public'")

        number_item = next(item for item in items if item.key == 'number')
        self.assertEqual(number_item.value, '42')

    def test_default_repr_attr_no_vars(self):
        """Test _default_repr_attr with object that has no __dict__."""

        class NoVars:
            __slots__ = ('x', 'y')

            def __init__(self):
                self.x = 1
                self.y = 2

        obj = NoVars()
        # vars() will raise TypeError for objects without __dict__
        with self.assertRaises(TypeError):
            list(_default_repr_attr(obj))
583
+
584
+
585
class TestIntegration(unittest.TestCase):
    """Integration tests for the pretty_repr module."""

    def test_complex_nested_structure(self):
        """Test complex nested structure representation."""

        class Container(PrettyRepr):
            def __init__(self, name, children=None):
                self.name = name
                self.children = children or []

            def __pretty_repr__(self):
                yield PrettyType(type=self.__class__.__name__, start='[', end=']')
                yield PrettyAttr('name', self.name)
                # Leaf containers omit the ``children`` attribute entirely.
                if self.children:
                    yield PrettyAttr('children', self.children)

        # Create nested structure: root -> branch -> (leaf1, leaf2)
        leaf1 = Container('leaf1')
        leaf2 = Container('leaf2')
        branch = Container('branch', [leaf1, leaf2])
        root = Container('root', [branch])

        result = repr(root)
        self.assertIn('Container', result)
        self.assertIn('name=', result)
        self.assertIn('root', result)
        self.assertIn('children=', result)
        # Nested containers are rendered through repr() of the children list.
        self.assertIn('branch', result)
        self.assertIn('leaf1', result)
        self.assertIn('leaf2', result)

    def test_mixed_types_representation(self):
        """Test representation with mixed types."""

        class MixedTypes(PrettyRepr):
            def __init__(self):
                self.string = "hello"
                self.number = 42
                self.float_num = 3.14
                self.bool_val = True
                self.none_val = None
                self.list_val = [1, 2, 3]
                self.dict_val = {'key': 'value'}
                self.tuple_val = (1, 2)
                self.set_val = {1, 2, 3}

            def __pretty_repr__(self):
                yield PrettyType(type='MixedTypes')
                for key, value in vars(self).items():
                    yield PrettyAttr(key, value)

        obj = MixedTypes()
        result = repr(obj)

        # Check all types are represented correctly
        # Note: string values passed to PrettyAttr are not re-quoted
        self.assertIn("string=hello", result)
        self.assertIn("number=42", result)
        self.assertIn("float_num=3.14", result)
        self.assertIn("bool_val=True", result)
        self.assertIn("none_val=None", result)
        self.assertIn("list_val=[1, 2, 3]", result)
        self.assertIn("dict_val={'key': 'value'}", result)
        # These attributes were previously set but never checked.
        self.assertIn("tuple_val=(1, 2)", result)
        self.assertIn("set_val={1, 2, 3}", result)

    def test_custom_formatting_styles(self):
        """Test various custom formatting styles."""

        class XMLStyle(PrettyRepr):
            def __init__(self, tag, content):
                self.tag = tag
                self.content = content

            def __pretty_repr__(self):
                yield PrettyType(
                    type='',
                    start=f'<{self.tag}>',
                    end=f'</{self.tag}>',
                    value_sep='',
                    elem_indent='',
                    empty_repr='',
                )
                yield PrettyAttr('', self.content)

        obj = XMLStyle('div', 'Hello World')
        result = repr(obj)
        self.assertIn('<div>', result)
        self.assertIn('</div>', result)
        self.assertIn('Hello World', result)

    def test_unicode_handling(self):
        """Test handling of unicode characters."""

        class UnicodeObj(PrettyRepr):
            def __init__(self):
                self.emoji = "🎉"
                self.chinese = "你好"
                self.special = "café"

            def __pretty_repr__(self):
                yield PrettyType(type='UnicodeObj')
                for key, value in vars(self).items():
                    yield PrettyAttr(key, value)

        obj = UnicodeObj()
        result = repr(obj)

        # Unicode should be preserved (string values are not re-quoted)
        self.assertIn("emoji=🎉", result)
        self.assertIn("chinese=你好", result)
        self.assertIn("special=café", result)
693
+
694
+
695
if __name__ == "__main__":
    # Run this module's test cases when the file is executed as a script.
    unittest.main()