brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. brainstate/__init__.py +169 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2319 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +1652 -1652
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1624 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1433 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +137 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +633 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +154 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +477 -477
  32. brainstate/nn/_dynamics.py +1267 -1267
  33. brainstate/nn/_dynamics_test.py +67 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +384 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/_rand_funs.py +3938 -3938
  64. brainstate/random/_rand_funs_test.py +640 -640
  65. brainstate/random/_rand_seed.py +675 -675
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1617
  68. brainstate/random/_rand_state_test.py +551 -551
  69. brainstate/transform/__init__.py +59 -59
  70. brainstate/transform/_ad_checkpoint.py +176 -176
  71. brainstate/transform/_ad_checkpoint_test.py +49 -49
  72. brainstate/transform/_autograd.py +1025 -1025
  73. brainstate/transform/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -316
  75. brainstate/transform/_conditions_test.py +220 -220
  76. brainstate/transform/_error_if.py +94 -94
  77. brainstate/transform/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -145
  79. brainstate/transform/_eval_shape_test.py +38 -38
  80. brainstate/transform/_jit.py +399 -399
  81. brainstate/transform/_jit_test.py +143 -143
  82. brainstate/transform/_loop_collect_return.py +675 -675
  83. brainstate/transform/_loop_collect_return_test.py +58 -58
  84. brainstate/transform/_loop_no_collection.py +283 -283
  85. brainstate/transform/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -2016
  87. brainstate/transform/_make_jaxpr_test.py +1510 -1510
  88. brainstate/transform/_mapping.py +529 -529
  89. brainstate/transform/_mapping_test.py +194 -194
  90. brainstate/transform/_progress_bar.py +255 -255
  91. brainstate/transform/_random.py +171 -171
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate-0.2.0.dist-info/RECORD +0 -111
  111. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  112. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,696 +1,696 @@
1
- # Copyright 2024 BrainState Authors.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- Comprehensive tests for pretty_repr module.
17
- """
18
-
19
- import dataclasses
20
- import unittest
21
- from typing import Iterator, Union
22
-
23
- from brainstate.util._pretty_repr import (
24
- PrettyType,
25
- PrettyAttr,
26
- PrettyRepr,
27
- pretty_repr_elem,
28
- pretty_repr_object,
29
- MappingReprMixin,
30
- PrettyMapping,
31
- PrettyReprContext,
32
- yield_unique_pretty_repr_items,
33
- _default_repr_object,
34
- _default_repr_attr,
35
- )
36
-
37
-
38
- class TestPrettyType(unittest.TestCase):
39
- """Test cases for PrettyType dataclass."""
40
-
41
- def test_default_values(self):
42
- """Test PrettyType with default values."""
43
- pt = PrettyType(type='MyClass')
44
- self.assertEqual(pt.type, 'MyClass')
45
- self.assertEqual(pt.start, '(')
46
- self.assertEqual(pt.end, ')')
47
- self.assertEqual(pt.value_sep, '=')
48
- self.assertEqual(pt.elem_indent, ' ')
49
- self.assertEqual(pt.empty_repr, '')
50
-
51
- def test_custom_values(self):
52
- """Test PrettyType with custom values."""
53
- pt = PrettyType(
54
- type=dict,
55
- start='{',
56
- end='}',
57
- value_sep=': ',
58
- elem_indent=' ',
59
- empty_repr='<empty>'
60
- )
61
- self.assertEqual(pt.type, dict)
62
- self.assertEqual(pt.start, '{')
63
- self.assertEqual(pt.end, '}')
64
- self.assertEqual(pt.value_sep, ': ')
65
- self.assertEqual(pt.elem_indent, ' ')
66
- self.assertEqual(pt.empty_repr, '<empty>')
67
-
68
- def test_type_can_be_string_or_class(self):
69
- """Test that type can be either string or class."""
70
- pt1 = PrettyType(type='StringType')
71
- self.assertIsInstance(pt1.type, str)
72
-
73
- pt2 = PrettyType(type=list)
74
- self.assertEqual(pt2.type, list)
75
-
76
-
77
- class TestPrettyAttr(unittest.TestCase):
78
- """Test cases for PrettyAttr dataclass."""
79
-
80
- def test_default_values(self):
81
- """Test PrettyAttr with default values."""
82
- pa = PrettyAttr(key='name', value='test')
83
- self.assertEqual(pa.key, 'name')
84
- self.assertEqual(pa.value, 'test')
85
- self.assertEqual(pa.start, '')
86
- self.assertEqual(pa.end, '')
87
-
88
- def test_custom_values(self):
89
- """Test PrettyAttr with custom values."""
90
- pa = PrettyAttr(key='count', value=42, start='[', end=']')
91
- self.assertEqual(pa.key, 'count')
92
- self.assertEqual(pa.value, 42)
93
- self.assertEqual(pa.start, '[')
94
- self.assertEqual(pa.end, ']')
95
-
96
- def test_value_types(self):
97
- """Test PrettyAttr with various value types."""
98
- pa1 = PrettyAttr('str_value', 'string')
99
- self.assertEqual(pa1.value, 'string')
100
-
101
- pa2 = PrettyAttr('int_value', 123)
102
- self.assertEqual(pa2.value, 123)
103
-
104
- pa3 = PrettyAttr('list_value', [1, 2, 3])
105
- self.assertEqual(pa3.value, [1, 2, 3])
106
-
107
- pa4 = PrettyAttr('dict_value', {'a': 1})
108
- self.assertEqual(pa4.value, {'a': 1})
109
-
110
-
111
- class SimplePrettyRepr(PrettyRepr):
112
- """Simple implementation of PrettyRepr for testing."""
113
-
114
- def __init__(self, value, name='SimpleObject'):
115
- self.value = value
116
- self.name = name
117
-
118
- def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
119
- yield PrettyType(type=self.name)
120
- yield PrettyAttr('value', self.value)
121
-
122
-
123
- class TestPrettyRepr(unittest.TestCase):
124
- """Test cases for PrettyRepr abstract class."""
125
-
126
- def test_simple_repr(self):
127
- """Test simple pretty representation."""
128
- obj = SimplePrettyRepr(42)
129
- result = repr(obj)
130
- self.assertIn('SimpleObject', result)
131
- self.assertIn('value=42', result)
132
-
133
- def test_custom_type_config(self):
134
- """Test PrettyRepr with custom type configuration."""
135
-
136
- class CustomRepr(PrettyRepr):
137
- def __init__(self, data):
138
- self.data = data
139
-
140
- def __pretty_repr__(self):
141
- yield PrettyType(type='CustomObject', start='<', end='>', value_sep=' -> ')
142
- yield PrettyAttr('data', self.data)
143
-
144
- obj = CustomRepr({'key': 'value'})
145
- result = repr(obj)
146
- self.assertIn('CustomObject<', result)
147
- self.assertIn('data -> ', result)
148
- self.assertIn('>', result)
149
-
150
- def test_multiple_attributes(self):
151
- """Test PrettyRepr with multiple attributes."""
152
-
153
- class MultiAttrRepr(PrettyRepr):
154
- def __init__(self, a, b, c):
155
- self.a = a
156
- self.b = b
157
- self.c = c
158
-
159
- def __pretty_repr__(self):
160
- yield PrettyType(type=self.__class__)
161
- yield PrettyAttr('a', self.a)
162
- yield PrettyAttr('b', self.b)
163
- yield PrettyAttr('c', self.c)
164
-
165
- obj = MultiAttrRepr(1, 'two', [3])
166
- result = repr(obj)
167
- self.assertIn('MultiAttrRepr', result)
168
- self.assertIn('a=1', result)
169
- self.assertIn("b=two", result) # String value is not re-quoted
170
- self.assertIn('c=[3]', result)
171
-
172
- def test_empty_object(self):
173
- """Test PrettyRepr with no attributes."""
174
-
175
- class EmptyRepr(PrettyRepr):
176
- def __pretty_repr__(self):
177
- yield PrettyType(type='EmptyObject', empty_repr='<no data>')
178
-
179
- obj = EmptyRepr()
180
- result = repr(obj)
181
- self.assertIn('EmptyObject', result)
182
- self.assertIn('<no data>', result)
183
-
184
-
185
- class TestPrettyReprElem(unittest.TestCase):
186
- """Test cases for pretty_repr_elem function."""
187
-
188
- def test_basic_elem(self):
189
- """Test basic element formatting."""
190
- pt = PrettyType(type='Test')
191
- elem = PrettyAttr('key', 'value')
192
- result = pretty_repr_elem(pt, elem)
193
- # Value is already a string, so it's not quoted again
194
- self.assertEqual(result, " key=value")
195
-
196
- def test_elem_with_custom_indent(self):
197
- """Test element with custom indentation."""
198
- pt = PrettyType(type='Test', elem_indent=' ')
199
- elem = PrettyAttr('key', 123)
200
- result = pretty_repr_elem(pt, elem)
201
- self.assertEqual(result, " key=123")
202
-
203
- def test_elem_with_custom_separator(self):
204
- """Test element with custom value separator."""
205
- pt = PrettyType(type='Test', value_sep=': ')
206
- elem = PrettyAttr('key', 'value')
207
- result = pretty_repr_elem(pt, elem)
208
- self.assertEqual(result, " key: value")
209
-
210
- def test_elem_with_start_end(self):
211
- """Test element with start and end markers."""
212
- pt = PrettyType(type='Test')
213
- elem = PrettyAttr('key', 'value', start='[', end=']')
214
- result = pretty_repr_elem(pt, elem)
215
- self.assertEqual(result, " [key=value]")
216
-
217
- def test_elem_with_multiline_value(self):
218
- """Test element with multiline value."""
219
- pt = PrettyType(type='Test')
220
- elem = PrettyAttr('key', 'line1\nline2\nline3')
221
- result = pretty_repr_elem(pt, elem)
222
- expected = " key=line1\n line2\n line3"
223
- self.assertEqual(result, expected)
224
-
225
- def test_elem_invalid_type(self):
226
- """Test that non-PrettyAttr raises TypeError."""
227
- pt = PrettyType(type='Test')
228
- with self.assertRaises(TypeError) as cm:
229
- pretty_repr_elem(pt, "not a PrettyAttr")
230
- self.assertIn("Item must be Elem", str(cm.exception))
231
-
232
-
233
- class TestPrettyReprObject(unittest.TestCase):
234
- """Test cases for pretty_repr_object function."""
235
-
236
- def test_valid_object(self):
237
- """Test with valid PrettyRepr object."""
238
- obj = SimplePrettyRepr(42, 'TestObject')
239
- result = pretty_repr_object(obj)
240
- self.assertIn('TestObject', result)
241
- self.assertIn('value=42', result)
242
-
243
- def test_invalid_object(self):
244
- """Test that non-PrettyRepr object raises TypeError."""
245
- with self.assertRaises(TypeError) as cm:
246
- pretty_repr_object("not a PrettyRepr")
247
- self.assertIn("is not representable", str(cm.exception))
248
-
249
- def test_invalid_first_item(self):
250
- """Test that invalid first item raises TypeError."""
251
-
252
- class InvalidRepr(PrettyRepr):
253
- def __pretty_repr__(self):
254
- yield PrettyAttr('key', 'value') # Should yield PrettyType first
255
-
256
- obj = InvalidRepr()
257
- with self.assertRaises(TypeError) as cm:
258
- pretty_repr_object(obj)
259
- self.assertIn("First item must be PrettyType", str(cm.exception))
260
-
261
- def test_empty_representation(self):
262
- """Test object with no attributes."""
263
-
264
- class EmptyRepr(PrettyRepr):
265
- def __pretty_repr__(self):
266
- yield PrettyType(type='Empty', empty_repr='∅')
267
-
268
- obj = EmptyRepr()
269
- result = pretty_repr_object(obj)
270
- self.assertEqual(result, 'Empty(∅)')
271
-
272
- def test_complex_nested_formatting(self):
273
- """Test complex nested formatting."""
274
-
275
- class ComplexRepr(PrettyRepr):
276
- def __pretty_repr__(self):
277
- yield PrettyType(
278
- type='Complex',
279
- start='{\n',
280
- end='\n}',
281
- elem_indent=' ',
282
- value_sep=' => '
283
- )
284
- yield PrettyAttr('first', 'value1')
285
- yield PrettyAttr('second', {'nested': 'dict'})
286
-
287
- obj = ComplexRepr()
288
- result = pretty_repr_object(obj)
289
- self.assertIn('Complex', result)
290
- self.assertIn('first => ', result)
291
- self.assertIn('second => ', result)
292
-
293
-
294
- class TestMappingReprMixin(unittest.TestCase):
295
- """Test cases for MappingReprMixin."""
296
-
297
- def test_basic_mapping(self):
298
- """Test basic mapping representation."""
299
-
300
- class MyMapping(dict, MappingReprMixin):
301
- pass
302
-
303
- m = MyMapping({'a': 1, 'b': 2})
304
- # Get the pretty repr items - MappingReprMixin only provides __pretty_repr__
305
- # but needs to be mixed with a dict-like class
306
- items = list(m.__pretty_repr__())
307
-
308
- # Check first item is PrettyType
309
- self.assertIsInstance(items[0], PrettyType)
310
- self.assertEqual(items[0].value_sep, ': ')
311
- self.assertEqual(items[0].start, '{')
312
- self.assertEqual(items[0].end, '}')
313
-
314
- # Check that we have the expected number of items
315
- self.assertEqual(len(items), 3) # PrettyType + 2 attrs
316
-
317
- # Check attributes
318
- attr_items = [item for item in items[1:] if isinstance(item, PrettyAttr)]
319
- self.assertEqual(len(attr_items), 2)
320
-
321
- # Keys should be repr'd (with quotes for strings)
322
- keys = [item.key for item in attr_items]
323
- self.assertIn("'a'", keys)
324
- self.assertIn("'b'", keys)
325
-
326
- def test_empty_mapping(self):
327
- """Test empty mapping representation."""
328
-
329
- class MyMapping(dict, MappingReprMixin):
330
- pass
331
-
332
- m = MyMapping()
333
- items = list(m.__pretty_repr__())
334
- self.assertEqual(len(items), 1) # Only PrettyType, no attributes
335
-
336
-
337
- class TestPrettyMapping(unittest.TestCase):
338
- """Test cases for PrettyMapping class."""
339
-
340
- def test_basic_pretty_mapping(self):
341
- """Test basic PrettyMapping."""
342
- pm = PrettyMapping({'x': 10, 'y': 20})
343
- result = repr(pm)
344
- self.assertIn("'x': 10", result)
345
- self.assertIn("'y': 20", result)
346
-
347
- def test_pretty_mapping_with_type_name(self):
348
- """Test PrettyMapping with custom type name."""
349
- pm = PrettyMapping({'a': 1}, type_name='MyDict')
350
- result = repr(pm)
351
- self.assertIn('MyDict', result)
352
- self.assertIn("'a': 1", result)
353
-
354
- def test_empty_pretty_mapping(self):
355
- """Test empty PrettyMapping."""
356
- pm = PrettyMapping({})
357
- result = repr(pm)
358
- self.assertIn('{', result)
359
- self.assertIn('}', result)
360
-
361
- def test_nested_mapping(self):
362
- """Test PrettyMapping with nested values."""
363
- pm = PrettyMapping({
364
- 'simple': 42,
365
- 'nested': {'inner': 'value'},
366
- 'list': [1, 2, 3]
367
- })
368
- result = repr(pm)
369
- self.assertIn("'simple': 42", result)
370
- self.assertIn("'nested':", result)
371
- self.assertIn("'list': [1, 2, 3]", result)
372
-
373
-
374
- class TestPrettyReprContext(unittest.TestCase):
375
- """Test cases for PrettyReprContext."""
376
-
377
- def test_initial_state(self):
378
- """Test initial state of context."""
379
- ctx = PrettyReprContext()
380
- self.assertIsNone(ctx.seen_modules_repr)
381
-
382
- def test_thread_local_behavior(self):
383
- """Test that context is thread-local."""
384
- ctx1 = PrettyReprContext()
385
- ctx2 = PrettyReprContext()
386
-
387
- ctx1.seen_modules_repr = {'test': 1}
388
- self.assertIsNone(ctx2.seen_modules_repr)
389
-
390
-
391
- class TestYieldUniquePrettyReprItems(unittest.TestCase):
392
- """Test cases for yield_unique_pretty_repr_items function."""
393
-
394
- def test_basic_usage(self):
395
- """Test basic usage with simple object."""
396
-
397
- @dataclasses.dataclass
398
- class SimpleObject:
399
- value: int
400
-
401
- obj = SimpleObject(42)
402
- items = list(yield_unique_pretty_repr_items(obj))
403
-
404
- # Should yield PrettyType first
405
- self.assertIsInstance(items[0], PrettyType)
406
- self.assertEqual(items[0].type, SimpleObject)
407
-
408
- # Should yield attribute
409
- attr_items = [item for item in items[1:] if isinstance(item, PrettyAttr)]
410
- self.assertTrue(any(item.key == 'value' for item in attr_items))
411
-
412
- def test_custom_repr_functions(self):
413
- """Test with custom repr functions."""
414
-
415
- def custom_repr_object(node):
416
- yield PrettyType(type='CustomType', start='<', end='>')
417
-
418
- def custom_repr_attr(node):
419
- yield PrettyAttr('custom_attr', 'custom_value')
420
-
421
- obj = object()
422
- items = list(yield_unique_pretty_repr_items(
423
- obj,
424
- repr_object=custom_repr_object,
425
- repr_attr=custom_repr_attr
426
- ))
427
-
428
- self.assertIsInstance(items[0], PrettyType)
429
- self.assertEqual(items[0].type, 'CustomType')
430
-
431
- attr_items = [item for item in items[1:] if isinstance(item, PrettyAttr)]
432
- self.assertTrue(any(item.key == 'custom_attr' for item in attr_items))
433
-
434
- def test_circular_reference_handling(self):
435
- """Test handling of circular references."""
436
-
437
- class Node:
438
- def __init__(self, value):
439
- self.value = value
440
- self.next = None
441
-
442
- # Create circular reference
443
- node1 = Node(1)
444
- node2 = Node(2)
445
- node1.next = node2
446
- node2.next = node1
447
-
448
- # Test that within same context, circular reference is detected
449
- from brainstate.util._pretty_repr import CONTEXT
450
-
451
- # Clean start
452
- CONTEXT.seen_modules_repr = None
453
-
454
- # Set up context to track seen objects
455
- CONTEXT.seen_modules_repr = {}
456
-
457
- # First pass - node1 will be added to seen
458
- items1 = list(yield_unique_pretty_repr_items(node1))
459
-
460
- # Second pass - should detect node1 is already seen
461
- items2 = list(yield_unique_pretty_repr_items(node1))
462
-
463
- # Check that second pass detected circular reference
464
- type_items = [item for item in items2 if isinstance(item, PrettyType)]
465
- self.assertTrue(len(type_items) > 0)
466
- self.assertTrue(any(item.empty_repr == '...' for item in type_items))
467
-
468
- # Clean up
469
- CONTEXT.seen_modules_repr = None
470
-
471
- def test_context_cleanup(self):
472
- """Test that context is properly cleaned up."""
473
- from brainstate.util._pretty_repr import CONTEXT
474
-
475
- # Clean up any previous state
476
- CONTEXT.seen_modules_repr = None
477
-
478
- # Use a class instance that has __dict__
479
- class TestObj:
480
- def __init__(self):
481
- self.test = 'value'
482
-
483
- obj = TestObj()
484
- list(yield_unique_pretty_repr_items(obj))
485
-
486
- # Context should be cleaned up after
487
- self.assertIsNone(CONTEXT.seen_modules_repr)
488
-
489
- def test_nested_calls(self):
490
- """Test nested calls don't recreate context."""
491
- from brainstate.util._pretty_repr import CONTEXT
492
-
493
- # Clean up any previous state
494
- CONTEXT.seen_modules_repr = None
495
-
496
- class Outer:
497
- def __init__(self):
498
- self.inner = Inner()
499
-
500
- class Inner:
501
- def __init__(self):
502
- self.value = 42
503
-
504
- def repr_outer_attr(node):
505
- # This will trigger nested yield_unique_pretty_repr_items
506
- for item in yield_unique_pretty_repr_items(node.inner):
507
- pass
508
- yield PrettyAttr('inner', node.inner)
509
-
510
- outer = Outer()
511
-
512
- # This should not error and should handle nested calls properly
513
- items = list(yield_unique_pretty_repr_items(
514
- outer,
515
- repr_attr=repr_outer_attr
516
- ))
517
-
518
- # Clean up should happen
519
- self.assertIsNone(CONTEXT.seen_modules_repr)
520
-
521
-
522
- class TestDefaultReprFunctions(unittest.TestCase):
523
- """Test cases for default repr functions."""
524
-
525
- def test_default_repr_object(self):
526
- """Test _default_repr_object function."""
527
-
528
- class MyClass:
529
- pass
530
-
531
- obj = MyClass()
532
- items = list(_default_repr_object(obj))
533
-
534
- self.assertEqual(len(items), 1)
535
- self.assertIsInstance(items[0], PrettyType)
536
- self.assertEqual(items[0].type, MyClass)
537
-
538
- def test_default_repr_attr(self):
539
- """Test _default_repr_attr function."""
540
-
541
- class MyClass:
542
- def __init__(self):
543
- self.public_attr = 'public'
544
- self._private_attr = 'private'
545
- self.__dunder_attr = 'dunder'
546
- self.number = 42
547
- self.list_attr = [1, 2, 3]
548
-
549
- obj = MyClass()
550
- items = list(_default_repr_attr(obj))
551
-
552
- # Should include public attributes
553
- attr_keys = {item.key for item in items}
554
- self.assertIn('public_attr', attr_keys)
555
- self.assertIn('number', attr_keys)
556
- self.assertIn('list_attr', attr_keys)
557
-
558
- # Should exclude private attributes
559
- self.assertNotIn('_private_attr', attr_keys)
560
- self.assertNotIn('__dunder_attr', attr_keys)
561
-
562
- # Check values are repr'd
563
- public_item = next(item for item in items if item.key == 'public_attr')
564
- self.assertEqual(public_item.value, "'public'")
565
-
566
- number_item = next(item for item in items if item.key == 'number')
567
- self.assertEqual(number_item.value, '42')
568
-
569
- def test_default_repr_attr_no_vars(self):
570
- """Test _default_repr_attr with object that has no __dict__."""
571
-
572
- class NoVars:
573
- __slots__ = ('x', 'y')
574
-
575
- def __init__(self):
576
- self.x = 1
577
- self.y = 2
578
-
579
- obj = NoVars()
580
- # vars() will raise TypeError for objects without __dict__
581
- with self.assertRaises(TypeError):
582
- list(_default_repr_attr(obj))
583
-
584
-
585
- class TestIntegration(unittest.TestCase):
586
- """Integration tests for the pretty_repr module."""
587
-
588
- def test_complex_nested_structure(self):
589
- """Test complex nested structure representation."""
590
-
591
- class Container(PrettyRepr):
592
- def __init__(self, name, children=None):
593
- self.name = name
594
- self.children = children or []
595
-
596
- def __pretty_repr__(self):
597
- yield PrettyType(type=self.__class__.__name__, start='[', end=']')
598
- yield PrettyAttr('name', self.name)
599
- if self.children:
600
- yield PrettyAttr('children', self.children)
601
-
602
- # Create nested structure
603
- leaf1 = Container('leaf1')
604
- leaf2 = Container('leaf2')
605
- branch = Container('branch', [leaf1, leaf2])
606
- root = Container('root', [branch])
607
-
608
- result = repr(root)
609
- self.assertIn('Container', result)
610
- self.assertIn('name=', result)
611
- self.assertIn('root', result)
612
- self.assertIn('children=', result)
613
-
614
- def test_mixed_types_representation(self):
615
- """Test representation with mixed types."""
616
-
617
- class MixedTypes(PrettyRepr):
618
- def __init__(self):
619
- self.string = "hello"
620
- self.number = 42
621
- self.float_num = 3.14
622
- self.bool_val = True
623
- self.none_val = None
624
- self.list_val = [1, 2, 3]
625
- self.dict_val = {'key': 'value'}
626
- self.tuple_val = (1, 2)
627
- self.set_val = {1, 2, 3}
628
-
629
- def __pretty_repr__(self):
630
- yield PrettyType(type='MixedTypes')
631
- for key, value in vars(self).items():
632
- yield PrettyAttr(key, value)
633
-
634
- obj = MixedTypes()
635
- result = repr(obj)
636
-
637
- # Check all types are represented correctly
638
- # Note: string values passed to PrettyAttr are not re-quoted
639
- self.assertIn("string=hello", result)
640
- self.assertIn("number=42", result)
641
- self.assertIn("float_num=3.14", result)
642
- self.assertIn("bool_val=True", result)
643
- self.assertIn("none_val=None", result)
644
- self.assertIn("list_val=[1, 2, 3]", result)
645
- self.assertIn("dict_val={'key': 'value'}", result)
646
-
647
- def test_custom_formatting_styles(self):
648
- """Test various custom formatting styles."""
649
-
650
- class XMLStyle(PrettyRepr):
651
- def __init__(self, tag, content):
652
- self.tag = tag
653
- self.content = content
654
-
655
- def __pretty_repr__(self):
656
- yield PrettyType(
657
- type='',
658
- start=f'<{self.tag}>',
659
- end=f'</{self.tag}>',
660
- value_sep='',
661
- elem_indent='',
662
- empty_repr=''
663
- )
664
- yield PrettyAttr('', self.content)
665
-
666
- obj = XMLStyle('div', 'Hello World')
667
- result = repr(obj)
668
- self.assertIn('<div>', result)
669
- self.assertIn('</div>', result)
670
- self.assertIn('Hello World', result)
671
-
672
- def test_unicode_handling(self):
673
- """Test handling of unicode characters."""
674
-
675
- class UnicodeObj(PrettyRepr):
676
- def __init__(self):
677
- self.emoji = "🎉"
678
- self.chinese = "你好"
679
- self.special = "café"
680
-
681
- def __pretty_repr__(self):
682
- yield PrettyType(type='UnicodeObj')
683
- for key, value in vars(self).items():
684
- yield PrettyAttr(key, value)
685
-
686
- obj = UnicodeObj()
687
- result = repr(obj)
688
-
689
- # Unicode should be preserved (string values are not re-quoted)
690
- self.assertIn("emoji=🎉", result)
691
- self.assertIn("chinese=你好", result)
692
- self.assertIn("special=café", result)
693
-
694
-
695
- if __name__ == '__main__':
696
- unittest.main()
1
+ # Copyright 2024 BrainState Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Comprehensive tests for pretty_repr module.
17
+ """
18
+
19
+ import dataclasses
20
+ import unittest
21
+ from typing import Iterator, Union
22
+
23
+ from brainstate.util._pretty_repr import (
24
+ PrettyType,
25
+ PrettyAttr,
26
+ PrettyRepr,
27
+ pretty_repr_elem,
28
+ pretty_repr_object,
29
+ MappingReprMixin,
30
+ PrettyMapping,
31
+ PrettyReprContext,
32
+ yield_unique_pretty_repr_items,
33
+ _default_repr_object,
34
+ _default_repr_attr,
35
+ )
36
+
37
+
38
+ class TestPrettyType(unittest.TestCase):
39
+ """Test cases for PrettyType dataclass."""
40
+
41
+ def test_default_values(self):
42
+ """Test PrettyType with default values."""
43
+ pt = PrettyType(type='MyClass')
44
+ self.assertEqual(pt.type, 'MyClass')
45
+ self.assertEqual(pt.start, '(')
46
+ self.assertEqual(pt.end, ')')
47
+ self.assertEqual(pt.value_sep, '=')
48
+ self.assertEqual(pt.elem_indent, ' ')
49
+ self.assertEqual(pt.empty_repr, '')
50
+
51
+ def test_custom_values(self):
52
+ """Test PrettyType with custom values."""
53
+ pt = PrettyType(
54
+ type=dict,
55
+ start='{',
56
+ end='}',
57
+ value_sep=': ',
58
+ elem_indent=' ',
59
+ empty_repr='<empty>'
60
+ )
61
+ self.assertEqual(pt.type, dict)
62
+ self.assertEqual(pt.start, '{')
63
+ self.assertEqual(pt.end, '}')
64
+ self.assertEqual(pt.value_sep, ': ')
65
+ self.assertEqual(pt.elem_indent, ' ')
66
+ self.assertEqual(pt.empty_repr, '<empty>')
67
+
68
+ def test_type_can_be_string_or_class(self):
69
+ """Test that type can be either string or class."""
70
+ pt1 = PrettyType(type='StringType')
71
+ self.assertIsInstance(pt1.type, str)
72
+
73
+ pt2 = PrettyType(type=list)
74
+ self.assertEqual(pt2.type, list)
75
+
76
+
77
+ class TestPrettyAttr(unittest.TestCase):
78
+ """Test cases for PrettyAttr dataclass."""
79
+
80
+ def test_default_values(self):
81
+ """Test PrettyAttr with default values."""
82
+ pa = PrettyAttr(key='name', value='test')
83
+ self.assertEqual(pa.key, 'name')
84
+ self.assertEqual(pa.value, 'test')
85
+ self.assertEqual(pa.start, '')
86
+ self.assertEqual(pa.end, '')
87
+
88
+ def test_custom_values(self):
89
+ """Test PrettyAttr with custom values."""
90
+ pa = PrettyAttr(key='count', value=42, start='[', end=']')
91
+ self.assertEqual(pa.key, 'count')
92
+ self.assertEqual(pa.value, 42)
93
+ self.assertEqual(pa.start, '[')
94
+ self.assertEqual(pa.end, ']')
95
+
96
+ def test_value_types(self):
97
+ """Test PrettyAttr with various value types."""
98
+ pa1 = PrettyAttr('str_value', 'string')
99
+ self.assertEqual(pa1.value, 'string')
100
+
101
+ pa2 = PrettyAttr('int_value', 123)
102
+ self.assertEqual(pa2.value, 123)
103
+
104
+ pa3 = PrettyAttr('list_value', [1, 2, 3])
105
+ self.assertEqual(pa3.value, [1, 2, 3])
106
+
107
+ pa4 = PrettyAttr('dict_value', {'a': 1})
108
+ self.assertEqual(pa4.value, {'a': 1})
109
+
110
+
111
class SimplePrettyRepr(PrettyRepr):
    """Minimal concrete ``PrettyRepr`` used as a fixture by the tests in this module.

    It renders under the configurable type name ``name`` and exposes a single
    ``value`` attribute.
    """

    def __init__(self, value, name='SimpleObject'):
        self.value, self.name = value, name

    def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
        # The header describing the type must come first, then the attribute.
        header = PrettyType(type=self.name)
        yield header
        yield PrettyAttr('value', self.value)
121
+
122
+
123
class TestPrettyRepr(unittest.TestCase):
    """Test cases for the ``PrettyRepr`` abstract class."""

    def test_simple_repr(self):
        """``repr`` of a PrettyRepr shows the type name and its attributes."""
        obj = SimplePrettyRepr(42)
        result = repr(obj)
        self.assertIn('SimpleObject', result)
        self.assertIn('value=42', result)

    def test_custom_type_config(self):
        """Custom ``start``/``end``/``value_sep`` settings are honoured by ``repr``."""

        class CustomRepr(PrettyRepr):
            def __init__(self, data):
                self.data = data

            def __pretty_repr__(self):
                yield PrettyType(type='CustomObject', start='<', end='>', value_sep=' -> ')
                yield PrettyAttr('data', self.data)

        obj = CustomRepr({'key': 'value'})
        result = repr(obj)
        self.assertIn('CustomObject<', result)
        self.assertIn('data -> ', result)
        # The separator ' -> ' already contains '>', so checking that '>' occurs
        # anywhere in the result could never fail. Verify instead that the
        # configured end marker actually closes the representation.
        self.assertTrue(result.endswith('>'), msg=result)

    def test_multiple_attributes(self):
        """Every yielded attribute appears in the representation."""

        class MultiAttrRepr(PrettyRepr):
            def __init__(self, a, b, c):
                self.a = a
                self.b = b
                self.c = c

            def __pretty_repr__(self):
                # A class object (not a string) is accepted as the type.
                yield PrettyType(type=self.__class__)
                yield PrettyAttr('a', self.a)
                yield PrettyAttr('b', self.b)
                yield PrettyAttr('c', self.c)

        obj = MultiAttrRepr(1, 'two', [3])
        result = repr(obj)
        self.assertIn('MultiAttrRepr', result)
        self.assertIn('a=1', result)
        self.assertIn("b=two", result)  # String value is not re-quoted
        self.assertIn('c=[3]', result)

    def test_empty_object(self):
        """With no attributes, ``empty_repr`` is shown instead."""

        class EmptyRepr(PrettyRepr):
            def __pretty_repr__(self):
                yield PrettyType(type='EmptyObject', empty_repr='<no data>')

        obj = EmptyRepr()
        result = repr(obj)
        self.assertIn('EmptyObject', result)
        self.assertIn('<no data>', result)
183
+
184
+
185
class TestPrettyReprElem(unittest.TestCase):
    """Tests for ``pretty_repr_elem``, which renders a single attribute line."""

    @staticmethod
    def _render(elem, **type_options):
        """Render ``elem`` under a PrettyType named 'Test' built with ``type_options``."""
        return pretty_repr_elem(PrettyType(type='Test', **type_options), elem)

    def test_basic_elem(self):
        """Default settings: indent, key, '=', value; string values are not re-quoted."""
        self.assertEqual(self._render(PrettyAttr('key', 'value')), " key=value")

    def test_elem_with_custom_indent(self):
        """``elem_indent`` prefixes the rendered line."""
        rendered = self._render(PrettyAttr('key', 123), elem_indent=' ')
        self.assertEqual(rendered, " key=123")

    def test_elem_with_custom_separator(self):
        """``value_sep`` replaces the default '=' between key and value."""
        rendered = self._render(PrettyAttr('key', 'value'), value_sep=': ')
        self.assertEqual(rendered, " key: value")

    def test_elem_with_start_end(self):
        """The attribute's own ``start``/``end`` wrap the key/value pair."""
        rendered = self._render(PrettyAttr('key', 'value', start='[', end=']'))
        self.assertEqual(rendered, " [key=value]")

    def test_elem_with_multiline_value(self):
        """Continuation lines of a multiline value are re-indented."""
        rendered = self._render(PrettyAttr('key', 'line1\nline2\nline3'))
        self.assertEqual(rendered, " key=line1\n line2\n line3")

    def test_elem_invalid_type(self):
        """Anything other than a PrettyAttr is rejected with TypeError."""
        header = PrettyType(type='Test')
        with self.assertRaises(TypeError) as raised:
            pretty_repr_elem(header, "not a PrettyAttr")
        self.assertIn("Item must be Elem", str(raised.exception))
231
+
232
+
233
class TestPrettyReprObject(unittest.TestCase):
    """Tests for ``pretty_repr_object``, which renders a whole PrettyRepr."""

    def test_valid_object(self):
        """A well-formed PrettyRepr renders its type name and attributes."""
        rendered = pretty_repr_object(SimplePrettyRepr(42, 'TestObject'))
        for fragment in ('TestObject', 'value=42'):
            self.assertIn(fragment, rendered)

    def test_invalid_object(self):
        """A non-PrettyRepr argument is rejected with TypeError."""
        with self.assertRaises(TypeError) as raised:
            pretty_repr_object("not a PrettyRepr")
        self.assertIn("is not representable", str(raised.exception))

    def test_invalid_first_item(self):
        """``__pretty_repr__`` must start by yielding a PrettyType."""

        class InvalidRepr(PrettyRepr):
            def __pretty_repr__(self):
                # Deliberately omits the leading PrettyType header.
                yield PrettyAttr('key', 'value')

        candidate = InvalidRepr()
        with self.assertRaises(TypeError) as raised:
            pretty_repr_object(candidate)
        self.assertIn("First item must be PrettyType", str(raised.exception))

    def test_empty_representation(self):
        """With no attributes, ``empty_repr`` appears between the delimiters."""

        class EmptyRepr(PrettyRepr):
            def __pretty_repr__(self):
                yield PrettyType(type='Empty', empty_repr='∅')

        self.assertEqual(pretty_repr_object(EmptyRepr()), 'Empty(∅)')

    def test_complex_nested_formatting(self):
        """Multi-character delimiters and separators are applied to every attribute."""

        class ComplexRepr(PrettyRepr):
            def __pretty_repr__(self):
                yield PrettyType(
                    type='Complex',
                    start='{\n',
                    end='\n}',
                    elem_indent=' ',
                    value_sep=' => ',
                )
                yield PrettyAttr('first', 'value1')
                yield PrettyAttr('second', {'nested': 'dict'})

        rendered = pretty_repr_object(ComplexRepr())
        for fragment in ('Complex', 'first => ', 'second => '):
            self.assertIn(fragment, rendered)
292
+
293
+
294
class TestMappingReprMixin(unittest.TestCase):
    """Tests for ``MappingReprMixin`` combined with a concrete ``dict``."""

    @staticmethod
    def _mapping_type():
        """Return a fresh ``dict`` subclass that mixes in MappingReprMixin.

        The mixin only supplies ``__pretty_repr__``; it needs a dict-like base.
        """

        class MyMapping(dict, MappingReprMixin):
            pass

        return MyMapping

    def test_basic_mapping(self):
        """The mixin yields a brace-delimited header plus one attribute per key."""
        mapping = self._mapping_type()({'a': 1, 'b': 2})
        items = list(mapping.__pretty_repr__())

        header = items[0]
        self.assertIsInstance(header, PrettyType)
        self.assertEqual(header.value_sep, ': ')
        self.assertEqual(header.start, '{')
        self.assertEqual(header.end, '}')

        # One header followed by one entry per key.
        self.assertEqual(len(items), 3)

        entries = [entry for entry in items[1:] if isinstance(entry, PrettyAttr)]
        self.assertEqual(len(entries), 2)

        # Keys go through repr(), so string keys keep their quotes.
        rendered_keys = [entry.key for entry in entries]
        for expected_key in ("'a'", "'b'"):
            self.assertIn(expected_key, rendered_keys)

    def test_empty_mapping(self):
        """An empty mapping yields only the PrettyType header."""
        mapping = self._mapping_type()()
        self.assertEqual(len(list(mapping.__pretty_repr__())), 1)
335
+
336
+
337
class TestPrettyMapping(unittest.TestCase):
    """Tests for the ``PrettyMapping`` wrapper."""

    def _assert_fragments(self, text, *fragments):
        """Assert that every one of ``fragments`` occurs in ``text``."""
        for fragment in fragments:
            self.assertIn(fragment, text)

    def test_basic_pretty_mapping(self):
        """Entries are rendered as ``repr(key): value``."""
        text = repr(PrettyMapping({'x': 10, 'y': 20}))
        self._assert_fragments(text, "'x': 10", "'y': 20")

    def test_pretty_mapping_with_type_name(self):
        """``type_name`` is shown in the representation."""
        text = repr(PrettyMapping({'a': 1}, type_name='MyDict'))
        self._assert_fragments(text, 'MyDict', "'a': 1")

    def test_empty_pretty_mapping(self):
        """An empty mapping still renders its braces."""
        text = repr(PrettyMapping({}))
        self._assert_fragments(text, '{', '}')

    def test_nested_mapping(self):
        """Nested containers are rendered alongside scalar values."""
        data = {
            'simple': 42,
            'nested': {'inner': 'value'},
            'list': [1, 2, 3]
        }
        text = repr(PrettyMapping(data))
        self._assert_fragments(text, "'simple': 42", "'nested':", "'list': [1, 2, 3]")
372
+
373
+
374
class TestPrettyReprContext(unittest.TestCase):
    """Test cases for ``PrettyReprContext``."""

    def test_initial_state(self):
        """A fresh context has not recorded any seen objects."""
        ctx = PrettyReprContext()
        self.assertIsNone(ctx.seen_modules_repr)

    def test_thread_local_behavior(self):
        """A value stored on a context in one thread is invisible to other threads.

        Comparing two separate instances says nothing about thread-locality, so
        the same instance is also read back from a worker thread.
        """
        import threading

        ctx1 = PrettyReprContext()
        ctx2 = PrettyReprContext()

        ctx1.seen_modules_repr = {'test': 1}
        # Distinct instances never share state.
        self.assertIsNone(ctx2.seen_modules_repr)

        observed = []

        def read_from_worker():
            observed.append(ctx1.seen_modules_repr)

        worker = threading.Thread(target=read_from_worker)
        worker.start()
        worker.join()

        # The worker thread sees its own, freshly initialised value ...
        self.assertEqual(observed, [None])
        # ... and the main thread's value is left untouched.
        self.assertEqual(ctx1.seen_modules_repr, {'test': 1})
389
+
390
+
391
class TestYieldUniquePrettyReprItems(unittest.TestCase):
    """Test cases for the ``yield_unique_pretty_repr_items`` function."""

    def test_basic_usage(self):
        """Default callbacks yield the object's type followed by its attributes."""

        @dataclasses.dataclass
        class SimpleObject:
            value: int

        obj = SimpleObject(42)
        items = list(yield_unique_pretty_repr_items(obj))

        # Should yield PrettyType first
        self.assertIsInstance(items[0], PrettyType)
        self.assertEqual(items[0].type, SimpleObject)

        # Should yield attribute
        attr_items = [item for item in items[1:] if isinstance(item, PrettyAttr)]
        self.assertTrue(any(item.key == 'value' for item in attr_items))

    def test_custom_repr_functions(self):
        """Custom ``repr_object``/``repr_attr`` callbacks replace the defaults."""

        def custom_repr_object(node):
            yield PrettyType(type='CustomType', start='<', end='>')

        def custom_repr_attr(node):
            yield PrettyAttr('custom_attr', 'custom_value')

        obj = object()
        items = list(yield_unique_pretty_repr_items(
            obj,
            repr_object=custom_repr_object,
            repr_attr=custom_repr_attr
        ))

        self.assertIsInstance(items[0], PrettyType)
        self.assertEqual(items[0].type, 'CustomType')

        attr_items = [item for item in items[1:] if isinstance(item, PrettyAttr)]
        self.assertTrue(any(item.key == 'custom_attr' for item in attr_items))

    def test_circular_reference_handling(self):
        """An object already recorded in the shared context is rendered as '...'."""

        class Node:
            def __init__(self, value):
                self.value = value
                self.next = None

        # Create circular reference
        node1 = Node(1)
        node2 = Node(2)
        node1.next = node2
        node2.next = node1

        from brainstate.util._pretty_repr import CONTEXT

        # Pre-seed the context so both calls below share one "seen" table. This
        # test relies on a call not resetting a context it did not create.
        CONTEXT.seen_modules_repr = {}
        try:
            # First pass - node1 is not yet seen, so it is rendered normally
            # and recorded in the context.
            first_pass = list(yield_unique_pretty_repr_items(node1))
            self.assertIsInstance(first_pass[0], PrettyType)
            self.assertEqual(first_pass[0].type, Node)

            # Second pass - node1 is already seen and must be abbreviated.
            second_pass = list(yield_unique_pretty_repr_items(node1))
            type_items = [item for item in second_pass if isinstance(item, PrettyType)]
            self.assertTrue(len(type_items) > 0)
            self.assertTrue(any(item.empty_repr == '...' for item in type_items))
        finally:
            # Always reset the module-level context, even when an assertion
            # fails, so no state leaks into other tests.
            CONTEXT.seen_modules_repr = None

    def test_context_cleanup(self):
        """A call that starts from an empty context resets it when done."""
        from brainstate.util._pretty_repr import CONTEXT

        # Clean up any previous state
        CONTEXT.seen_modules_repr = None

        # Use a class instance that has __dict__
        class TestObj:
            def __init__(self):
                self.test = 'value'

        obj = TestObj()
        list(yield_unique_pretty_repr_items(obj))

        # Context should be cleaned up after
        self.assertIsNone(CONTEXT.seen_modules_repr)

    def test_nested_calls(self):
        """A nested call made while rendering does not disturb the outer call."""
        from brainstate.util._pretty_repr import CONTEXT

        # Clean up any previous state
        CONTEXT.seen_modules_repr = None

        class Inner:
            def __init__(self):
                self.value = 42

        class Outer:
            def __init__(self):
                self.inner = Inner()

        def repr_outer_attr(node):
            # Fully drain a nested yield_unique_pretty_repr_items call while the
            # outer call's context is still active.
            for _ in yield_unique_pretty_repr_items(node.inner):
                pass
            yield PrettyAttr('inner', node.inner)

        outer = Outer()

        # This should not error and should handle nested calls properly
        items = list(yield_unique_pretty_repr_items(
            outer,
            repr_attr=repr_outer_attr
        ))

        # The outer call's own output is unaffected by the nested call.
        self.assertIsInstance(items[0], PrettyType)
        self.assertEqual(items[0].type, Outer)
        self.assertTrue(any(
            isinstance(item, PrettyAttr) and item.key == 'inner'
            for item in items[1:]
        ))

        # The context is cleared once the outermost call finishes.
        self.assertIsNone(CONTEXT.seen_modules_repr)
520
+
521
+
522
class TestDefaultReprFunctions(unittest.TestCase):
    """Test cases for the default repr callbacks."""

    def test_default_repr_object(self):
        """``_default_repr_object`` yields a single PrettyType for the node's class."""

        class MyClass:
            pass

        obj = MyClass()
        items = list(_default_repr_object(obj))

        self.assertEqual(len(items), 1)
        self.assertIsInstance(items[0], PrettyType)
        self.assertEqual(items[0].type, MyClass)

    def test_default_repr_attr(self):
        """``_default_repr_attr`` yields public attributes with repr'd values."""

        class MyClass:
            def __init__(self):
                self.public_attr = 'public'
                self._private_attr = 'private'
                # Name-mangled: stored in vars() as '_MyClass__dunder_attr'.
                self.__dunder_attr = 'dunder'
                self.number = 42
                self.list_attr = [1, 2, 3]

        obj = MyClass()
        items = list(_default_repr_attr(obj))

        # Should include public attributes
        attr_keys = {item.key for item in items}
        self.assertIn('public_attr', attr_keys)
        self.assertIn('number', attr_keys)
        self.assertIn('list_attr', attr_keys)

        # Should exclude private attributes
        self.assertNotIn('_private_attr', attr_keys)
        # Because of name mangling, the literal key '__dunder_attr' never exists,
        # so that check alone could not fail. Also check the mangled name and
        # that no underscore-prefixed key leaks through.
        self.assertNotIn('__dunder_attr', attr_keys)
        self.assertNotIn('_MyClass__dunder_attr', attr_keys)
        self.assertFalse(any(key.startswith('_') for key in attr_keys), attr_keys)

        # Check values are repr'd
        public_item = next(item for item in items if item.key == 'public_attr')
        self.assertEqual(public_item.value, "'public'")

        number_item = next(item for item in items if item.key == 'number')
        self.assertEqual(number_item.value, '42')

    def test_default_repr_attr_no_vars(self):
        """Objects without ``__dict__`` make ``_default_repr_attr`` raise TypeError."""

        class NoVars:
            __slots__ = ('x', 'y')

            def __init__(self):
                self.x = 1
                self.y = 2

        obj = NoVars()
        # vars() will raise TypeError for objects without __dict__
        with self.assertRaises(TypeError):
            list(_default_repr_attr(obj))
583
+
584
+
585
class TestIntegration(unittest.TestCase):
    """Integration tests for the pretty_repr module."""

    def test_complex_nested_structure(self):
        """Nested PrettyRepr objects inside list attributes render without error."""

        class Container(PrettyRepr):
            def __init__(self, name, children=None):
                self.name = name
                self.children = children or []

            def __pretty_repr__(self):
                yield PrettyType(type=self.__class__.__name__, start='[', end=']')
                yield PrettyAttr('name', self.name)
                if self.children:
                    yield PrettyAttr('children', self.children)

        # Create nested structure
        leaf1 = Container('leaf1')
        leaf2 = Container('leaf2')
        branch = Container('branch', [leaf1, leaf2])
        root = Container('root', [branch])

        result = repr(root)
        self.assertIn('Container', result)
        self.assertIn('name=', result)
        self.assertIn('root', result)
        self.assertIn('children=', result)

    def test_mixed_types_representation(self):
        """Attributes of many builtin types are all rendered."""

        class MixedTypes(PrettyRepr):
            def __init__(self):
                self.string = "hello"
                self.number = 42
                self.float_num = 3.14
                self.bool_val = True
                self.none_val = None
                self.list_val = [1, 2, 3]
                self.dict_val = {'key': 'value'}
                self.tuple_val = (1, 2)
                self.set_val = {1, 2, 3}

            def __pretty_repr__(self):
                yield PrettyType(type='MixedTypes')
                for key, value in vars(self).items():
                    yield PrettyAttr(key, value)

        obj = MixedTypes()
        result = repr(obj)

        # Check all types are represented correctly
        # Note: string values passed to PrettyAttr are not re-quoted
        self.assertIn("string=hello", result)
        self.assertIn("number=42", result)
        self.assertIn("float_num=3.14", result)
        self.assertIn("bool_val=True", result)
        self.assertIn("none_val=None", result)
        self.assertIn("list_val=[1, 2, 3]", result)
        self.assertIn("dict_val={'key': 'value'}", result)
        # tuple_val and set_val were built above but previously never checked.
        self.assertIn("tuple_val=(1, 2)", result)
        self.assertIn("set_val={1, 2, 3}", result)

    def test_custom_formatting_styles(self):
        """Delimiters can emulate a non-Python style such as XML tags."""

        class XMLStyle(PrettyRepr):
            def __init__(self, tag, content):
                self.tag = tag
                self.content = content

            def __pretty_repr__(self):
                yield PrettyType(
                    type='',
                    start=f'<{self.tag}>',
                    end=f'</{self.tag}>',
                    value_sep='',
                    elem_indent='',
                    empty_repr=''
                )
                yield PrettyAttr('', self.content)

        obj = XMLStyle('div', 'Hello World')
        result = repr(obj)
        self.assertIn('<div>', result)
        self.assertIn('</div>', result)
        self.assertIn('Hello World', result)

    def test_unicode_handling(self):
        """Non-ASCII string values are preserved verbatim."""

        class UnicodeObj(PrettyRepr):
            def __init__(self):
                self.emoji = "🎉"
                self.chinese = "你好"
                self.special = "café"

            def __pretty_repr__(self):
                yield PrettyType(type='UnicodeObj')
                for key, value in vars(self).items():
                    yield PrettyAttr(key, value)

        obj = UnicodeObj()
        result = repr(obj)

        # Unicode should be preserved (string values are not re-quoted)
        self.assertIn("emoji=🎉", result)
        self.assertIn("chinese=你好", result)
        self.assertIn("special=café", result)
693
+
694
+
695
if __name__ == '__main__':
    # Run this module's tests when the file is executed directly as a script.
    unittest.main()