brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,148 +1,340 @@
1
- # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
-
19
- from contextlib import contextmanager
20
- from functools import partial
21
- from typing import Iterable, Hashable, TypeVar, Callable
22
-
23
- import jax
24
-
25
__all__ = [
    # Names whose import location depends on the installed JAX version.
    'ClosedJaxpr',
    'Primitive',
    'extend_axis_env_nd',
    'jaxpr_as_fun',
    'get_aval',
    'Tracer',
    'Device',
    # Helpers taken from ``jax.util`` on older JAX, otherwise defined here.
    'safe_map',
    'safe_zip',
    'unzip2',
    'wraps',
    # Module-local helpers and re-exports from other packages.
    'to_concrete_aval',
    'wrap_init',
]

# Generic type variables for the annotations of the fallback helpers.
T = TypeVar("T")
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
45
-
46
- from saiunit._compatible_import import wrap_init
47
-
48
- from jax.core import get_aval, Tracer
49
-
50
- if jax.__version_info__ < (0, 5, 0):
51
- from jax.lib.xla_client import Device
52
- else:
53
- from jax import Device
54
-
55
- if jax.__version_info__ < (0, 4, 38):
56
- from jax.core import ClosedJaxpr, extend_axis_env_nd, Primitive, jaxpr_as_fun
57
- else:
58
- from jax.extend.core import ClosedJaxpr, Primitive, jaxpr_as_fun
59
- from jax.core import trace_ctx
60
-
61
-
62
- @contextmanager
63
- def extend_axis_env_nd(name_size_pairs: Iterable[tuple[Hashable, int]]):
64
- prev = trace_ctx.axis_env
65
- try:
66
- trace_ctx.set_axis_env(prev.extend_pure(name_size_pairs))
67
- yield
68
- finally:
69
- trace_ctx.set_axis_env(prev)
70
-
71
- if jax.__version_info__ < (0, 6, 0):
72
- from jax.util import safe_map, safe_zip, unzip2, wraps
73
-
74
- else:
75
def safe_map(f, *args):
    """
    Map a function over multiple sequences with length checking.

    Every argument is first materialized into a list, so one-shot iterables
    such as generators are accepted. ``f`` is then applied to corresponding
    elements after verifying that all sequences have the same length.

    Args:
        f: Function to apply; it receives one positional argument per sequence.
        *args: Sequences (or iterables) to map over.

    Returns:
        list: Results of applying ``f`` to corresponding elements. An empty
        list is returned when no sequences are given.

    Raises:
        AssertionError: If the sequences have different lengths. The error is
            raised explicitly rather than via ``assert``, so the check is still
            performed under ``python -O``.
    """
    seqs = [list(arg) for arg in args]
    if not seqs:
        # Nothing to map over; previously this crashed with IndexError.
        return []
    lengths = [len(seq) for seq in seqs]
    if any(length != lengths[0] for length in lengths[1:]):
        raise AssertionError(f'length mismatch: {lengths}')
    return list(map(f, *seqs))
81
-
82
-
83
def safe_zip(*args):
    """
    Zip multiple sequences with length checking.

    Every argument is first materialized into a list. After verifying that all
    of them have the same length, corresponding elements are combined into
    tuples.

    Args:
        *args: Sequences (or iterables) to zip together.

    Returns:
        list: List of tuples of corresponding elements. An empty list is
        returned when no sequences are given, matching ``zip()``.

    Raises:
        AssertionError: If the sequences have different lengths. The error is
            raised explicitly rather than via ``assert``, so the check is still
            performed under ``python -O``.
    """
    seqs = [list(arg) for arg in args]
    if not seqs:
        # Mirrors ``list(zip())``; previously this crashed with IndexError.
        return []
    lengths = [len(seq) for seq in seqs]
    if any(length != lengths[0] for length in lengths[1:]):
        raise AssertionError(f'length mismatch: {lengths}')
    return list(zip(*seqs))
89
-
90
-
91
def unzip2(xys: Iterable[tuple[T1, T2]]) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
    """
    Split an iterable of pairs into a tuple of firsts and a tuple of seconds.

    ``zip(*xys)`` is deliberately avoided: it is lazy, too permissive about
    its inputs, and does not guarantee a length-2 result. Here every item is
    unpacked into exactly two values, so malformed items raise immediately.

    Args:
        xys: Iterable of 2-tuples.

    Returns:
        A pair ``(firsts, seconds)`` of tuples.
    """
    firsts: list[T1] = []
    seconds: list[T2] = []
    for pair in xys:
        first, second = pair
        firsts.append(first)
        seconds.append(second)
    return tuple(firsts), tuple(seconds)
101
-
102
-
103
def fun_name(fun: Callable):
    """
    Return a printable name for ``fun``.

    The object's ``__name__`` is used when present. Otherwise
    ``functools.partial`` wrappers are unwrapped via ``.func`` until a named
    callable is found. If neither applies, ``"<unnamed function>"`` is
    returned.

    Args:
        fun: Callable (or any object) to name.

    Returns:
        str: The resolved name.
    """
    target = fun
    while True:
        name = getattr(target, "__name__", None)
        if name is not None:
            return name
        if not isinstance(target, partial):
            return "<unnamed function>"
        target = target.func
111
-
112
-
113
def wraps(
    wrapped: Callable,
    namestr: str | None = None,
    docstr: str | None = None,
    **kwargs,
) -> Callable[[T], T]:
    """
    Build a decorator that copies metadata from ``wrapped`` onto a function.

    This is similar to ``functools.wraps``, but the name and docstring can be
    reformatted. ``namestr`` may use ``{fun}``; ``docstr`` may use ``{fun}``,
    ``{doc}`` and any extra ``kwargs``.

    Args:
        wrapped: The function whose metadata is copied.
        namestr: Optional format string for the new ``__name__``.
        docstr: Optional format string for the new ``__doc__``.
        **kwargs: Extra substitutions for ``docstr``.

    Returns:
        Callable: A decorator that updates and returns its argument.
    """

    def decorate(fun: T) -> T:
        # Best-effort copy: on any exception the remaining attributes are
        # skipped and ``fun`` is returned as-is (partially updated).
        try:
            base_name = fun_name(wrapped)
            base_doc = getattr(wrapped, "__doc__", "") or ""
            fun.__dict__.update(getattr(wrapped, "__dict__", {}))
            fun.__annotations__ = getattr(wrapped, "__annotations__", {})
            if namestr is None:
                fun.__name__ = base_name
            else:
                fun.__name__ = namestr.format(fun=base_name)
            fun.__module__ = getattr(wrapped, "__module__", "<unknown module>")
            if docstr is None:
                fun.__doc__ = base_doc
            else:
                fun.__doc__ = docstr.format(fun=base_name, doc=base_doc, **kwargs)
            fun.__qualname__ = getattr(wrapped, "__qualname__", fun.__name__)
            fun.__wrapped__ = wrapped
        except Exception:
            pass
        return fun

    return decorate
141
-
142
-
143
def to_concrete_aval(aval):
    """
    Return ``get_aval(aval)``, or its concrete value if that is a Tracer.

    Args:
        aval: A value accepted by ``get_aval`` (e.g. a JAX array).

    Returns:
        ``result.to_concrete_value()`` when ``result = get_aval(aval)`` is a
        ``Tracer``; otherwise ``result`` itself.

    Notes:
        NOTE(review): ``get_aval`` is expected to return an abstract value
        rather than a ``Tracer``. If so, the Tracer branch is never taken.
        Confirm whether the intent was to test the *input* before calling
        ``get_aval``.
    """
    result = get_aval(aval)
    return result.to_concrete_value() if isinstance(result, Tracer) else result
148
-
1
+ # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ """
19
+ Compatibility layer for JAX version differences.
20
+
21
+ This module provides a compatibility layer to handle differences between various
22
+ versions of JAX, ensuring that BrainState works correctly across different JAX
23
+ versions. It imports the appropriate modules and functions based on the detected
24
+ JAX version and provides fallback implementations when necessary.
25
+
26
+ Key Features:
27
+ - Version-aware imports for JAX core functionality
28
+ - Compatibility wrappers for changed APIs
29
+ - Fallback implementations for deprecated functions
30
+ - Type-safe utility functions
31
+
32
+ Examples:
33
+ Basic usage:
34
+
35
+ >>> from brainstate._compatible_import import safe_map, safe_zip
36
+ >>> result = safe_map(lambda x: x * 2, [1, 2, 3])
37
+ >>> pairs = safe_zip([1, 2, 3], ['a', 'b', 'c'])
38
+
39
+ Using JAX core types:
40
+
41
+ >>> from brainstate._compatible_import import Primitive, ClosedJaxpr
42
+ >>> # These imports work across different JAX versions
43
+ """
44
+
45
+ from contextlib import contextmanager
46
+ from functools import partial
47
+ from typing import Iterable, Hashable, TypeVar, Callable
48
+
49
+ import jax
50
+ from jax.core import get_aval, Tracer
51
+ from saiunit._compatible_import import wrap_init
52
+
53
__all__ = [
    # Names whose import location depends on the installed JAX version.
    'ClosedJaxpr',
    'Primitive',
    'extend_axis_env_nd',
    'jaxpr_as_fun',
    'get_aval',
    'Tracer',
    'Device',
    # Helpers taken from ``jax.util`` on older JAX, otherwise defined here.
    'safe_map',
    'safe_zip',
    'unzip2',
    'wraps',
    # Module-local helpers and re-exports from other packages.
    'to_concrete_aval',
    'wrap_init',
    # Jaxpr IR classes (``jax.core`` or ``jax.extend.core`` by JAX version).
    'Var',
    'JaxprEqn',
    'Jaxpr',
    'Literal',
    # Batching internals (``jax.interpreters.batching`` or its ``_src`` copy).
    'make_iota',
    'to_elt',
    'BatchTracer',
    'BatchTrace',
]

# Generic type variables for the annotations of the fallback helpers.
T = TypeVar("T")
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
79
+
80
+ if jax.__version_info__ < (0, 5, 0):
81
+ from jax.lib.xla_client import Device
82
+ else:
83
+ from jax import Device
84
+
85
+ if jax.__version_info__ < (0, 7, 1):
86
+ from jax.interpreters.batching import make_iota, to_elt, BatchTracer, BatchTrace
87
+ else:
88
+ from jax._src.interpreters.batching import make_iota, to_elt, BatchTracer, BatchTrace
89
+
90
+ if jax.__version_info__ < (0, 4, 38):
91
+ from jax.core import ClosedJaxpr, extend_axis_env_nd, Primitive, jaxpr_as_fun
92
+ from jax.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
93
+ else:
94
+ from jax.extend.core import ClosedJaxpr, Primitive, jaxpr_as_fun
95
+ from jax.extend.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
96
+ from jax.core import trace_ctx
97
+
98
+
99
+ @contextmanager
100
+ def extend_axis_env_nd(name_size_pairs: Iterable[tuple[Hashable, int]]):
101
+ """
102
+ Context manager to temporarily extend the JAX axis environment.
103
+
104
+ Extends the current JAX axis environment with new named axes for
105
+ vectorized computations, then restores the previous environment.
106
+
107
+ Args:
108
+ name_size_pairs: Iterable of (name, size) tuples specifying
109
+ the named axes to add to the environment.
110
+
111
+ Yields:
112
+ None: Context with extended axis environment.
113
+
114
+ Examples:
115
+ >>> with extend_axis_env_nd([('batch', 32), ('seq', 128)]):
116
+ ... # Code using vectorized operations with named axes
117
+ ... pass
118
+ """
119
+ prev = trace_ctx.axis_env
120
+ try:
121
+ trace_ctx.set_axis_env(prev.extend_pure(name_size_pairs))
122
+ yield
123
+ finally:
124
+ trace_ctx.set_axis_env(prev)
125
+
126
+ if jax.__version_info__ < (0, 6, 0):
127
+ from jax.util import safe_map, safe_zip, unzip2, wraps
128
+
129
+ else:
130
def safe_map(f, *args):
    """
    Map a function over multiple sequences with length checking.

    Every argument is first materialized into a list, so one-shot iterables
    such as generators are accepted. ``f`` is then applied to corresponding
    elements after verifying that all sequences have the same length.

    Args:
        f: Function to apply; it receives one positional argument per sequence.
        *args: Sequences (or iterables) to map over.

    Returns:
        list: Results of applying ``f`` to corresponding elements. An empty
        list is returned when no sequences are given.

    Raises:
        AssertionError: If the sequences have different lengths. The error is
            raised explicitly rather than via ``assert``, so the check is still
            performed under ``python -O``.

    Examples:
        >>> safe_map(lambda x, y: x + y, [1, 2, 3], [4, 5, 6])
        [5, 7, 9]

        >>> safe_map(str.upper, ['a', 'b', 'c'])
        ['A', 'B', 'C']
    """
    seqs = [list(arg) for arg in args]
    if not seqs:
        # Nothing to map over; previously this crashed with IndexError.
        return []
    lengths = [len(seq) for seq in seqs]
    if any(length != lengths[0] for length in lengths[1:]):
        raise AssertionError(f'length mismatch: {lengths}')
    return list(map(f, *seqs))
159
+
160
+
161
def safe_zip(*args):
    """
    Zip multiple sequences with length checking.

    Every argument is first materialized into a list. After verifying that all
    of them have the same length, corresponding elements are combined into
    tuples.

    Args:
        *args: Sequences (or iterables) to zip together.

    Returns:
        list: List of tuples of corresponding elements. An empty list is
        returned when no sequences are given, matching ``zip()``.

    Raises:
        AssertionError: If the sequences have different lengths. The error is
            raised explicitly rather than via ``assert``, so the check is still
            performed under ``python -O``.

    Examples:
        >>> safe_zip([1, 2, 3], ['a', 'b', 'c'])
        [(1, 'a'), (2, 'b'), (3, 'c')]

        >>> safe_zip([1, 2], [3, 4], [5, 6])
        [(1, 3, 5), (2, 4, 6)]
    """
    seqs = [list(arg) for arg in args]
    if not seqs:
        # Mirrors ``list(zip())``; previously this crashed with IndexError.
        return []
    lengths = [len(seq) for seq in seqs]
    if any(length != lengths[0] for length in lengths[1:]):
        raise AssertionError(f'length mismatch: {lengths}')
    return list(zip(*seqs))
189
+
190
+
191
def unzip2(xys: Iterable[tuple[T1, T2]]) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
    """
    Split an iterable of pairs into a tuple of firsts and a tuple of seconds.

    ``zip(*xys)`` is deliberately avoided: it is lazy, too permissive about
    its inputs, and does not guarantee a length-2 result. Here every item is
    unpacked into exactly two values, so malformed items raise immediately.

    Args:
        xys: Iterable of 2-tuples to unzip.

    Returns:
        A pair ``(firsts, seconds)`` of tuples.

    Examples:
        >>> nums, letters = unzip2([(1, 'a'), (2, 'b'), (3, 'c')])
        >>> nums
        (1, 2, 3)
        >>> letters
        ('a', 'b', 'c')
    """
    firsts: list[T1] = []
    seconds: list[T2] = []
    for pair in xys:
        first, second = pair
        firsts.append(first)
        seconds.append(second)
    return tuple(firsts), tuple(seconds)
226
+
227
+
228
def fun_name(fun: Callable):
    """
    Return a printable name for ``fun``.

    The object's ``__name__`` is used when present. Otherwise
    ``functools.partial`` wrappers are unwrapped via ``.func`` until a named
    callable is found. If neither applies, ``"<unnamed function>"`` is
    returned.

    Args:
        fun: Callable (or any object) to name.

    Returns:
        str: The resolved name.

    Examples:
        >>> def my_function():
        ...     pass
        >>> fun_name(my_function)
        'my_function'

        >>> from functools import partial
        >>> fun_name(partial(lambda x, y: x + y, 1))
        '<lambda>'
    """
    target = fun
    while True:
        name = getattr(target, "__name__", None)
        if name is not None:
            return name
        if not isinstance(target, partial):
            return "<unnamed function>"
        target = target.func
260
+
261
+
262
def wraps(
    wrapped: Callable,
    namestr: str | None = None,
    docstr: str | None = None,
    **kwargs,
) -> Callable[[T], T]:
    """
    Build a decorator that copies metadata from ``wrapped`` onto a function.

    This is similar to ``functools.wraps``, but the resulting name and
    docstring can be reformatted.

    Args:
        wrapped: The function whose metadata is copied.
        namestr: Optional format string for the new ``__name__``; may use the
            ``{fun}`` placeholder for the original name.
        docstr: Optional format string for the new ``__doc__``; may use
            ``{fun}``, ``{doc}`` and any extra ``kwargs``.
        **kwargs: Extra substitutions for ``docstr``.

    Returns:
        Callable: A decorator that updates and returns its argument.

    Examples:
        >>> def my_decorator(func):
        ...     @wraps(func, namestr="decorated_{fun}")
        ...     def wrapper(*args, **kwargs):
        ...         return func(*args, **kwargs)
        ...     return wrapper
        >>> @my_decorator
        ... def example():
        ...     pass
        >>> example.__name__
        'decorated_example'
    """

    def decorate(fun: T) -> T:
        # Best-effort copy: on any exception the remaining attributes are
        # skipped and ``fun`` is returned as-is (partially updated).
        try:
            base_name = fun_name(wrapped)
            base_doc = getattr(wrapped, "__doc__", "") or ""
            fun.__dict__.update(getattr(wrapped, "__dict__", {}))
            fun.__annotations__ = getattr(wrapped, "__annotations__", {})
            if namestr is None:
                fun.__name__ = base_name
            else:
                fun.__name__ = namestr.format(fun=base_name)
            fun.__module__ = getattr(wrapped, "__module__", "<unknown module>")
            if docstr is None:
                fun.__doc__ = base_doc
            else:
                fun.__doc__ = docstr.format(fun=base_name, doc=base_doc, **kwargs)
            fun.__qualname__ = getattr(wrapped, "__qualname__", fun.__name__)
            fun.__wrapped__ = wrapped
        except Exception:
            pass
        return fun

    return decorate
316
+
317
+
318
def to_concrete_aval(aval):
    """
    Return ``get_aval(aval)``, or its concrete value if that is a Tracer.

    ``aval`` is passed through ``get_aval``. If the result is a ``Tracer``,
    its ``to_concrete_value()`` is returned; otherwise the result of
    ``get_aval`` is returned unchanged. For an ordinary JAX array this is the
    array's abstract value, not the array data.

    Args:
        aval: A value accepted by ``get_aval`` (e.g. a JAX array or scalar).

    Returns:
        ``result.to_concrete_value()`` when ``result = get_aval(aval)`` is a
        ``Tracer``; otherwise ``result`` itself.

    Notes:
        NOTE(review): ``get_aval`` is expected to return an abstract value
        rather than a ``Tracer``. If so, the Tracer branch below is never
        taken. Confirm whether the intent was to test the *input* for
        ``Tracer`` before calling ``get_aval``.

    Examples:
        >>> import jax.numpy as jnp
        >>> to_concrete_aval(jnp.array([1, 2, 3])).shape
        (3,)
    """
    result = get_aval(aval)
    if isinstance(result, Tracer):
        return result.to_concrete_value()
    return result