brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,696 @@
1
+ # Copyright 2024 BrainState Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Comprehensive tests for pretty_repr module.
17
+ """
18
+
19
+ import dataclasses
20
+ import unittest
21
+ from typing import Iterator, Union
22
+
23
+ from brainstate.util._pretty_repr import (
24
+ PrettyType,
25
+ PrettyAttr,
26
+ PrettyRepr,
27
+ pretty_repr_elem,
28
+ pretty_repr_object,
29
+ MappingReprMixin,
30
+ PrettyMapping,
31
+ PrettyReprContext,
32
+ yield_unique_pretty_repr_items,
33
+ _default_repr_object,
34
+ _default_repr_attr,
35
+ )
36
+
37
+
38
class TestPrettyType(unittest.TestCase):
    """Unit tests covering construction of the ``PrettyType`` dataclass."""

    def test_default_values(self):
        """Only ``type`` is supplied; every other field keeps its default."""
        config = PrettyType(type='MyClass')
        expected = {
            'type': 'MyClass',
            'start': '(',
            'end': ')',
            'value_sep': '=',
            'elem_indent': ' ',
            'empty_repr': '',
        }
        for field_name, field_value in expected.items():
            self.assertEqual(getattr(config, field_name), field_value)

    def test_custom_values(self):
        """Every field can be overridden and is read back unchanged."""
        overrides = dict(
            type=dict,
            start='{',
            end='}',
            value_sep=': ',
            elem_indent=' ',
            empty_repr='<empty>',
        )
        config = PrettyType(**overrides)
        for field_name, field_value in overrides.items():
            self.assertEqual(getattr(config, field_name), field_value)

    def test_type_can_be_string_or_class(self):
        """``type`` accepts either a plain string or a class object."""
        by_name = PrettyType(type='StringType')
        self.assertIsInstance(by_name.type, str)

        by_class = PrettyType(type=list)
        self.assertEqual(by_class.type, list)
75
+
76
+
77
class TestPrettyAttr(unittest.TestCase):
    """Unit tests covering construction of the ``PrettyAttr`` dataclass."""

    def test_default_values(self):
        """Omitted ``start``/``end`` markers default to empty strings."""
        attr = PrettyAttr(key='name', value='test')
        self.assertEqual(
            (attr.key, attr.value, attr.start, attr.end),
            ('name', 'test', '', ''),
        )

    def test_custom_values(self):
        """Explicit ``start``/``end`` markers are stored as given."""
        attr = PrettyAttr(key='count', value=42, start='[', end=']')
        self.assertEqual(
            (attr.key, attr.value, attr.start, attr.end),
            ('count', 42, '[', ']'),
        )

    def test_value_types(self):
        """``value`` is stored verbatim regardless of its type."""
        samples = [
            ('str_value', 'string'),
            ('int_value', 123),
            ('list_value', [1, 2, 3]),
            ('dict_value', {'a': 1}),
        ]
        for key, value in samples:
            attr = PrettyAttr(key, value)
            self.assertEqual(attr.value, value)
109
+
110
+
111
class SimplePrettyRepr(PrettyRepr):
    """Minimal ``PrettyRepr`` subclass used as a fixture by several tests.

    It reports a single ``value`` attribute under a configurable type name.
    """

    def __init__(self, value, name='SimpleObject'):
        self.value, self.name = value, name

    def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
        # The PrettyType header is yielded first, then the attribute entries.
        header = PrettyType(type=self.name)
        yield header
        yield PrettyAttr('value', self.value)
121
+
122
+
123
class TestPrettyRepr(unittest.TestCase):
    """Test cases for ``repr()`` of ``PrettyRepr`` subclasses."""

    def test_simple_repr(self):
        """``repr`` shows the configured type name and the attribute."""
        obj = SimplePrettyRepr(42)
        result = repr(obj)
        self.assertIn('SimpleObject', result)
        self.assertIn('value=42', result)

    def test_custom_type_config(self):
        """Custom ``start``/``end``/``value_sep`` settings are honoured."""

        class CustomRepr(PrettyRepr):
            def __init__(self, data):
                self.data = data

            def __pretty_repr__(self):
                yield PrettyType(type='CustomObject', start='<', end='>', value_sep=' -> ')
                yield PrettyAttr('data', self.data)

        obj = CustomRepr({'key': 'value'})
        result = repr(obj)
        self.assertIn('CustomObject<', result)
        self.assertIn('data -> ', result)
        # A bare ``assertIn('>', result)`` is always satisfied by the ' -> '
        # separator itself, so check that the closing marker really ends
        # the rendered output.
        self.assertTrue(result.endswith('>'), result)

    def test_multiple_attributes(self):
        """Every yielded attribute appears in the output."""

        class MultiAttrRepr(PrettyRepr):
            def __init__(self, a, b, c):
                self.a = a
                self.b = b
                self.c = c

            def __pretty_repr__(self):
                yield PrettyType(type=self.__class__)
                yield PrettyAttr('a', self.a)
                yield PrettyAttr('b', self.b)
                yield PrettyAttr('c', self.c)

        obj = MultiAttrRepr(1, 'two', [3])
        result = repr(obj)
        self.assertIn('MultiAttrRepr', result)
        self.assertIn('a=1', result)
        self.assertIn("b=two", result)  # String value is not re-quoted
        self.assertIn('c=[3]', result)

    def test_empty_object(self):
        """An object without attributes renders its ``empty_repr``."""

        class EmptyRepr(PrettyRepr):
            def __pretty_repr__(self):
                yield PrettyType(type='EmptyObject', empty_repr='<no data>')

        obj = EmptyRepr()
        result = repr(obj)
        self.assertIn('EmptyObject', result)
        self.assertIn('<no data>', result)
183
+
184
+
185
class TestPrettyReprElem(unittest.TestCase):
    """Tests for ``pretty_repr_elem``, which renders a single attribute entry."""

    def _render(self, attr, **type_options):
        """Render ``attr`` under a ``PrettyType`` named 'Test' built from ``type_options``."""
        return pretty_repr_elem(PrettyType(type='Test', **type_options), attr)

    def test_basic_elem(self):
        """A string value is emitted as-is, without being quoted again."""
        self.assertEqual(self._render(PrettyAttr('key', 'value')), " key=value")

    def test_elem_with_custom_indent(self):
        """The configured ``elem_indent`` prefixes the entry."""
        rendered = self._render(PrettyAttr('key', 123), elem_indent=' ')
        self.assertEqual(rendered, " key=123")

    def test_elem_with_custom_separator(self):
        """The configured ``value_sep`` joins key and value."""
        rendered = self._render(PrettyAttr('key', 'value'), value_sep=': ')
        self.assertEqual(rendered, " key: value")

    def test_elem_with_start_end(self):
        """Attribute-level ``start``/``end`` markers wrap the entry."""
        rendered = self._render(PrettyAttr('key', 'value', start='[', end=']'))
        self.assertEqual(rendered, " [key=value]")

    def test_elem_with_multiline_value(self):
        """Continuation lines of a multi-line value are indented too."""
        rendered = self._render(PrettyAttr('key', 'line1\nline2\nline3'))
        self.assertEqual(rendered, " key=line1\n line2\n line3")

    def test_elem_invalid_type(self):
        """Anything other than a ``PrettyAttr`` is rejected with TypeError."""
        config = PrettyType(type='Test')
        with self.assertRaises(TypeError) as raised:
            pretty_repr_elem(config, "not a PrettyAttr")
        self.assertIn("Item must be Elem", str(raised.exception))
231
+
232
+
233
class TestPrettyReprObject(unittest.TestCase):
    """Tests for ``pretty_repr_object``, which renders a whole ``PrettyRepr``."""

    def test_valid_object(self):
        """A well-formed object renders its type name and attributes."""
        rendered = pretty_repr_object(SimplePrettyRepr(42, 'TestObject'))
        for fragment in ('TestObject', 'value=42'):
            self.assertIn(fragment, rendered)

    def test_invalid_object(self):
        """Objects that are not ``PrettyRepr`` instances are rejected."""
        with self.assertRaises(TypeError) as raised:
            pretty_repr_object("not a PrettyRepr")
        self.assertIn("is not representable", str(raised.exception))

    def test_invalid_first_item(self):
        """The first yielded item must be the ``PrettyType`` header."""

        class InvalidRepr(PrettyRepr):
            def __pretty_repr__(self):
                # Wrong on purpose: an attribute arrives before any header.
                yield PrettyAttr('key', 'value')

        broken = InvalidRepr()
        with self.assertRaises(TypeError) as raised:
            pretty_repr_object(broken)
        self.assertIn("First item must be PrettyType", str(raised.exception))

    def test_empty_representation(self):
        """With no attributes, ``empty_repr`` appears between the markers."""

        class EmptyRepr(PrettyRepr):
            def __pretty_repr__(self):
                yield PrettyType(type='Empty', empty_repr='∅')

        self.assertEqual(pretty_repr_object(EmptyRepr()), 'Empty(∅)')

    def test_complex_nested_formatting(self):
        """Multi-character markers, indentation and separators are applied."""

        class ComplexRepr(PrettyRepr):
            def __pretty_repr__(self):
                yield PrettyType(
                    type='Complex',
                    start='{\n',
                    end='\n}',
                    elem_indent=' ',
                    value_sep=' => ',
                )
                yield PrettyAttr('first', 'value1')
                yield PrettyAttr('second', {'nested': 'dict'})

        rendered = pretty_repr_object(ComplexRepr())
        for fragment in ('Complex', 'first => ', 'second => '):
            self.assertIn(fragment, rendered)
292
+
293
+
294
class TestMappingReprMixin(unittest.TestCase):
    """Tests for the items produced by ``MappingReprMixin.__pretty_repr__``."""

    def test_basic_mapping(self):
        """A two-key mapping yields a brace-style header plus one entry per key."""

        class MyMapping(dict, MappingReprMixin):
            pass

        mapping = MyMapping({'a': 1, 'b': 2})
        # The mixin only contributes __pretty_repr__; the dict base class
        # provides the key/value storage that it walks.
        header, *entries = list(mapping.__pretty_repr__())

        self.assertIsInstance(header, PrettyType)
        self.assertEqual(header.value_sep, ': ')
        self.assertEqual(header.start, '{')
        self.assertEqual(header.end, '}')

        # Two keys -> two entries following the header.
        self.assertEqual(len(entries), 2)

        attrs = list(filter(lambda entry: isinstance(entry, PrettyAttr), entries))
        self.assertEqual(len(attrs), 2)

        # Keys are rendered with repr(), so string keys keep their quotes.
        rendered_keys = [attr.key for attr in attrs]
        self.assertIn("'a'", rendered_keys)
        self.assertIn("'b'", rendered_keys)

    def test_empty_mapping(self):
        """An empty mapping yields only the header."""

        class MyMapping(dict, MappingReprMixin):
            pass

        produced = list(MyMapping().__pretty_repr__())
        self.assertEqual(len(produced), 1)
335
+
336
+
337
class TestPrettyMapping(unittest.TestCase):
    """Tests for ``repr()`` of the ``PrettyMapping`` wrapper."""

    def _assert_contains_all(self, text, *fragments):
        """Assert that every fragment occurs somewhere in ``text``."""
        for fragment in fragments:
            self.assertIn(fragment, text)

    def test_basic_pretty_mapping(self):
        """Each key/value pair appears as ``'key': value``."""
        rendered = repr(PrettyMapping({'x': 10, 'y': 20}))
        self._assert_contains_all(rendered, "'x': 10", "'y': 20")

    def test_pretty_mapping_with_type_name(self):
        """A custom ``type_name`` is shown in the output."""
        rendered = repr(PrettyMapping({'a': 1}, type_name='MyDict'))
        self._assert_contains_all(rendered, 'MyDict', "'a': 1")

    def test_empty_pretty_mapping(self):
        """An empty mapping still renders its braces."""
        rendered = repr(PrettyMapping({}))
        self._assert_contains_all(rendered, '{', '}')

    def test_nested_mapping(self):
        """Nested containers are rendered alongside scalar values."""
        payload = {
            'simple': 42,
            'nested': {'inner': 'value'},
            'list': [1, 2, 3],
        }
        rendered = repr(PrettyMapping(payload))
        self._assert_contains_all(
            rendered, "'simple': 42", "'nested':", "'list': [1, 2, 3]"
        )
372
+
373
+
374
class TestPrettyReprContext(unittest.TestCase):
    """Test cases for ``PrettyReprContext``."""

    def test_initial_state(self):
        """A fresh context starts without a seen-object registry."""
        ctx = PrettyReprContext()
        self.assertIsNone(ctx.seen_modules_repr)

    def test_thread_local_behavior(self):
        """State is independent per instance and not visible from other threads."""
        import threading

        ctx1 = PrettyReprContext()
        ctx2 = PrettyReprContext()

        ctx1.seen_modules_repr = {'test': 1}
        self.assertIsNone(ctx2.seen_modules_repr)

        # Two instances in one thread only demonstrate instance independence.
        # To exercise thread-locality itself, read ``ctx1`` from a second
        # thread, where the assignment above should not be visible.
        observed = []
        worker = threading.Thread(
            target=lambda: observed.append(ctx1.seen_modules_repr)
        )
        worker.start()
        worker.join()
        self.assertEqual(observed, [None])

        # The thread that made the assignment still sees its own value.
        self.assertEqual(ctx1.seen_modules_repr, {'test': 1})
389
+
390
+
391
class TestYieldUniquePrettyReprItems(unittest.TestCase):
    """Test cases for ``yield_unique_pretty_repr_items``."""

    def setUp(self):
        # ``CONTEXT`` is module-global state shared by every test. Start each
        # test from a clean slate and always reset it afterwards, even when an
        # assertion fails midway, so one failure cannot leak into other tests.
        from brainstate.util._pretty_repr import CONTEXT
        self.context = CONTEXT
        self.context.seen_modules_repr = None
        self.addCleanup(setattr, self.context, 'seen_modules_repr', None)

    def test_basic_usage(self):
        """A dataclass instance yields its type header and its field."""

        @dataclasses.dataclass
        class SimpleObject:
            value: int

        obj = SimpleObject(42)
        items = list(yield_unique_pretty_repr_items(obj))

        # The PrettyType header is yielded first.
        self.assertIsInstance(items[0], PrettyType)
        self.assertEqual(items[0].type, SimpleObject)

        # The dataclass field is yielded as an attribute.
        attr_items = [item for item in items[1:] if isinstance(item, PrettyAttr)]
        self.assertTrue(any(item.key == 'value' for item in attr_items))

    def test_custom_repr_functions(self):
        """Custom ``repr_object``/``repr_attr`` callables drive the output."""

        def custom_repr_object(node):
            yield PrettyType(type='CustomType', start='<', end='>')

        def custom_repr_attr(node):
            yield PrettyAttr('custom_attr', 'custom_value')

        obj = object()
        items = list(yield_unique_pretty_repr_items(
            obj,
            repr_object=custom_repr_object,
            repr_attr=custom_repr_attr,
        ))

        self.assertIsInstance(items[0], PrettyType)
        self.assertEqual(items[0].type, 'CustomType')

        attr_items = [item for item in items[1:] if isinstance(item, PrettyAttr)]
        self.assertTrue(any(item.key == 'custom_attr' for item in attr_items))

    def test_circular_reference_handling(self):
        """An object already recorded in the active context is elided as '...'."""

        class Node:
            def __init__(self, value):
                self.value = value
                self.next = None

        # Build a two-node cycle.
        node1 = Node(1)
        node2 = Node(2)
        node1.next = node2
        node2.next = node1

        # Install a registry up front so both calls below share the same
        # seen-object tracking.
        self.context.seen_modules_repr = {}

        # First pass records node1 as seen.
        list(yield_unique_pretty_repr_items(node1))

        # Second pass over the same object should be detected as a repeat.
        repeat_items = list(yield_unique_pretty_repr_items(node1))

        type_items = [item for item in repeat_items if isinstance(item, PrettyType)]
        self.assertGreater(len(type_items), 0)
        self.assertTrue(any(item.empty_repr == '...' for item in type_items))

    def test_context_cleanup(self):
        """A call that starts without a context leaves none behind."""

        # A plain class instance, so the object has a ``__dict__``.
        class TestObj:
            def __init__(self):
                self.test = 'value'

        obj = TestObj()
        list(yield_unique_pretty_repr_items(obj))

        self.assertIsNone(self.context.seen_modules_repr)

    def test_nested_calls(self):
        """Nested calls made from ``repr_attr`` still leave the context cleared."""

        class Inner:
            def __init__(self):
                self.value = 42

        class Outer:
            def __init__(self):
                self.inner = Inner()

        def repr_outer_attr(node):
            # Run a nested yield_unique_pretty_repr_items while the outer call
            # is still in progress.
            list(yield_unique_pretty_repr_items(node.inner))
            yield PrettyAttr('inner', node.inner)

        outer = Outer()

        # Must not raise, and must handle the nested call cleanly.
        list(yield_unique_pretty_repr_items(outer, repr_attr=repr_outer_attr))

        self.assertIsNone(self.context.seen_modules_repr)
520
+
521
+
522
class TestDefaultReprFunctions(unittest.TestCase):
    """Test cases for ``_default_repr_object`` and ``_default_repr_attr``."""

    def test_default_repr_object(self):
        """The default object repr yields a single header typed by the class."""

        class MyClass:
            pass

        obj = MyClass()
        items = list(_default_repr_object(obj))

        self.assertEqual(len(items), 1)
        self.assertIsInstance(items[0], PrettyType)
        self.assertEqual(items[0].type, MyClass)

    def test_default_repr_attr(self):
        """Public attributes are reported with repr'd values; private ones are skipped."""

        class MyClass:
            def __init__(self):
                self.public_attr = 'public'
                self._private_attr = 'private'
                # Name mangling stores this as ``_MyClass__dunder_attr``.
                self.__dunder_attr = 'dunder'
                self.number = 42
                self.list_attr = [1, 2, 3]

        obj = MyClass()
        items = list(_default_repr_attr(obj))

        # Public attributes are included.
        attr_keys = {item.key for item in items}
        self.assertIn('public_attr', attr_keys)
        self.assertIn('number', attr_keys)
        self.assertIn('list_attr', attr_keys)

        # Private attributes are excluded.
        self.assertNotIn('_private_attr', attr_keys)
        # Because of name mangling, the literal key '__dunder_attr' never
        # exists in vars(obj), so asserting its absence proves nothing.
        # Check the mangled name instead, and that no key starts with '_'.
        self.assertIn('_MyClass__dunder_attr', vars(obj))
        self.assertNotIn('_MyClass__dunder_attr', attr_keys)
        self.assertFalse(
            any(key.startswith('_') for key in attr_keys), attr_keys
        )

        # Values are passed through repr().
        public_item = next(item for item in items if item.key == 'public_attr')
        self.assertEqual(public_item.value, "'public'")

        number_item = next(item for item in items if item.key == 'number')
        self.assertEqual(number_item.value, '42')

    def test_default_repr_attr_no_vars(self):
        """Objects without ``__dict__`` make ``_default_repr_attr`` raise TypeError."""

        class NoVars:
            __slots__ = ('x', 'y')

            def __init__(self):
                self.x = 1
                self.y = 2

        obj = NoVars()
        # vars() raises TypeError for objects without a __dict__.
        with self.assertRaises(TypeError):
            list(_default_repr_attr(obj))
583
+
584
+
585
class TestIntegration(unittest.TestCase):
    """End-to-end checks that combine several pretty-repr features."""

    def _assert_contains_all(self, text, *fragments):
        """Assert that every fragment occurs somewhere in ``text``."""
        for fragment in fragments:
            self.assertIn(fragment, text)

    def test_complex_nested_structure(self):
        """A tree of containers renders names and child lists."""

        class Container(PrettyRepr):
            def __init__(self, name, children=None):
                self.name = name
                self.children = children or []

            def __pretty_repr__(self):
                yield PrettyType(type=type(self).__name__, start='[', end=']')
                yield PrettyAttr('name', self.name)
                # Leaves omit the ``children`` entry entirely.
                if not self.children:
                    return
                yield PrettyAttr('children', self.children)

        leaves = [Container('leaf1'), Container('leaf2')]
        root = Container('root', [Container('branch', leaves)])

        rendered = repr(root)
        self._assert_contains_all(rendered, 'Container', 'name=', 'root', 'children=')

    def test_mixed_types_representation(self):
        """Scalars and builtin containers are rendered alongside each other."""

        class MixedTypes(PrettyRepr):
            def __init__(self):
                self.string = "hello"
                self.number = 42
                self.float_num = 3.14
                self.bool_val = True
                self.none_val = None
                self.list_val = [1, 2, 3]
                self.dict_val = {'key': 'value'}
                self.tuple_val = (1, 2)
                self.set_val = {1, 2, 3}

            def __pretty_repr__(self):
                yield PrettyType(type='MixedTypes')
                yield from (PrettyAttr(name, val) for name, val in vars(self).items())

        rendered = repr(MixedTypes())

        # String values handed to PrettyAttr are emitted without extra quotes.
        self._assert_contains_all(
            rendered,
            "string=hello",
            "number=42",
            "float_num=3.14",
            "bool_val=True",
            "none_val=None",
            "list_val=[1, 2, 3]",
            "dict_val={'key': 'value'}",
        )

    def test_custom_formatting_styles(self):
        """Markers can emulate an XML-like layout."""

        class XMLStyle(PrettyRepr):
            def __init__(self, tag, content):
                self.tag = tag
                self.content = content

            def __pretty_repr__(self):
                yield PrettyType(
                    type='',
                    start=f'<{self.tag}>',
                    end=f'</{self.tag}>',
                    value_sep='',
                    elem_indent='',
                    empty_repr='',
                )
                yield PrettyAttr('', self.content)

        rendered = repr(XMLStyle('div', 'Hello World'))
        self._assert_contains_all(rendered, '<div>', '</div>', 'Hello World')

    def test_unicode_handling(self):
        """Non-ASCII text is preserved verbatim."""

        class UnicodeObj(PrettyRepr):
            def __init__(self):
                self.emoji = "🎉"
                self.chinese = "你好"
                self.special = "café"

            def __pretty_repr__(self):
                yield PrettyType(type='UnicodeObj')
                yield from (PrettyAttr(name, val) for name, val in vars(self).items())

        rendered = repr(UnicodeObj())

        # String values are emitted unquoted, with the characters intact.
        self._assert_contains_all(
            rendered, "emoji=🎉", "chinese=你好", "special=café"
        )
693
+
694
+
695
if __name__ == "__main__":
    # Run the whole suite when this module is executed directly as a script.
    unittest.main()