brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115) hide show
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,340 +1,340 @@
1
- # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- """
19
- Compatibility layer for JAX version differences.
20
-
21
- This module provides a compatibility layer to handle differences between various
22
- versions of JAX, ensuring that BrainState works correctly across different JAX
23
- versions. It imports the appropriate modules and functions based on the detected
24
- JAX version and provides fallback implementations when necessary.
25
-
26
- Key Features:
27
- - Version-aware imports for JAX core functionality
28
- - Compatibility wrappers for changed APIs
29
- - Fallback implementations for deprecated functions
30
- - Type-safe utility functions
31
-
32
- Examples:
33
- Basic usage:
34
-
35
- >>> from brainstate._compatible_import import safe_map, safe_zip
36
- >>> result = safe_map(lambda x: x * 2, [1, 2, 3])
37
- >>> pairs = safe_zip([1, 2, 3], ['a', 'b', 'c'])
38
-
39
- Using JAX core types:
40
-
41
- >>> from brainstate._compatible_import import Primitive, ClosedJaxpr
42
- >>> # These imports work across different JAX versions
43
- """
44
-
45
- from contextlib import contextmanager
46
- from functools import partial
47
- from typing import Iterable, Hashable, TypeVar, Callable
48
-
49
- import jax
50
- from jax.core import get_aval, Tracer
51
- from saiunit._compatible_import import wrap_init
52
-
53
- __all__ = [
54
- 'ClosedJaxpr',
55
- 'Primitive',
56
- 'extend_axis_env_nd',
57
- 'jaxpr_as_fun',
58
- 'get_aval',
59
- 'Tracer',
60
- 'to_concrete_aval',
61
- 'safe_map',
62
- 'safe_zip',
63
- 'unzip2',
64
- 'wraps',
65
- 'Device',
66
- 'wrap_init',
67
- 'Var',
68
- 'JaxprEqn',
69
- 'Jaxpr',
70
- 'Literal',
71
-
72
- 'make_iota', 'to_elt', 'BatchTracer', 'BatchTrace',
73
- ]
74
-
75
- T = TypeVar("T")
76
- T1 = TypeVar("T1")
77
- T2 = TypeVar("T2")
78
- T3 = TypeVar("T3")
79
-
80
- if jax.__version_info__ < (0, 5, 0):
81
- from jax.lib.xla_client import Device
82
- else:
83
- from jax import Device
84
-
85
- if jax.__version_info__ < (0, 7, 1):
86
- from jax.interpreters.batching import make_iota, to_elt, BatchTracer, BatchTrace
87
- else:
88
- from jax._src.interpreters.batching import make_iota, to_elt, BatchTracer, BatchTrace
89
-
90
- if jax.__version_info__ < (0, 4, 38):
91
- from jax.core import ClosedJaxpr, extend_axis_env_nd, Primitive, jaxpr_as_fun
92
- from jax.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
93
- else:
94
- from jax.extend.core import ClosedJaxpr, Primitive, jaxpr_as_fun
95
- from jax.extend.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
96
- from jax.core import trace_ctx
97
-
98
-
99
- @contextmanager
100
- def extend_axis_env_nd(name_size_pairs: Iterable[tuple[Hashable, int]]):
101
- """
102
- Context manager to temporarily extend the JAX axis environment.
103
-
104
- Extends the current JAX axis environment with new named axes for
105
- vectorized computations, then restores the previous environment.
106
-
107
- Args:
108
- name_size_pairs: Iterable of (name, size) tuples specifying
109
- the named axes to add to the environment.
110
-
111
- Yields:
112
- None: Context with extended axis environment.
113
-
114
- Examples:
115
- >>> with extend_axis_env_nd([('batch', 32), ('seq', 128)]):
116
- ... # Code using vectorized operations with named axes
117
- ... pass
118
- """
119
- prev = trace_ctx.axis_env
120
- try:
121
- trace_ctx.set_axis_env(prev.extend_pure(name_size_pairs))
122
- yield
123
- finally:
124
- trace_ctx.set_axis_env(prev)
125
-
126
- if jax.__version_info__ < (0, 6, 0):
127
- from jax.util import safe_map, safe_zip, unzip2, wraps
128
-
129
- else:
130
- def safe_map(f, *args):
131
- """
132
- Map a function over multiple sequences with length checking.
133
-
134
- Applies a function to corresponding elements from multiple sequences,
135
- ensuring all sequences have the same length.
136
-
137
- Args:
138
- f: Function to apply to elements from each sequence.
139
- *args: Variable number of sequences to map over.
140
-
141
- Returns:
142
- list: Results of applying f to corresponding elements.
143
-
144
- Raises:
145
- AssertionError: If input sequences have different lengths.
146
-
147
- Examples:
148
- >>> safe_map(lambda x, y: x + y, [1, 2, 3], [4, 5, 6])
149
- [5, 7, 9]
150
-
151
- >>> safe_map(str.upper, ['a', 'b', 'c'])
152
- ['A', 'B', 'C']
153
- """
154
- args = list(map(list, args))
155
- n = len(args[0])
156
- for arg in args[1:]:
157
- assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
158
- return list(map(f, *args))
159
-
160
-
161
- def safe_zip(*args):
162
- """
163
- Zip multiple sequences with length checking.
164
-
165
- Combines corresponding elements from multiple sequences into tuples,
166
- ensuring all sequences have the same length.
167
-
168
- Args:
169
- *args: Variable number of sequences to zip together.
170
-
171
- Returns:
172
- list: List of tuples containing corresponding elements.
173
-
174
- Raises:
175
- AssertionError: If input sequences have different lengths.
176
-
177
- Examples:
178
- >>> safe_zip([1, 2, 3], ['a', 'b', 'c'])
179
- [(1, 'a'), (2, 'b'), (3, 'c')]
180
-
181
- >>> safe_zip([1, 2], [3, 4], [5, 6])
182
- [(1, 3, 5), (2, 4, 6)]
183
- """
184
- args = list(map(list, args))
185
- n = len(args[0])
186
- for arg in args[1:]:
187
- assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
188
- return list(zip(*args))
189
-
190
-
191
- def unzip2(xys: Iterable[tuple[T1, T2]]) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
192
- """
193
- Unzip sequence of length-2 tuples into two tuples.
194
-
195
- Takes an iterable of 2-tuples and separates them into two tuples
196
- containing the first and second elements respectively.
197
-
198
- Args:
199
- xys: Iterable of 2-tuples to unzip.
200
-
201
- Returns:
202
- tuple: A 2-tuple containing:
203
- - Tuple of all first elements
204
- - Tuple of all second elements
205
-
206
- Examples:
207
- >>> pairs = [(1, 'a'), (2, 'b'), (3, 'c')]
208
- >>> nums, letters = unzip2(pairs)
209
- >>> nums
210
- (1, 2, 3)
211
- >>> letters
212
- ('a', 'b', 'c')
213
-
214
- Notes:
215
- We deliberately don't use zip(*xys) because it is lazily evaluated,
216
- is too permissive about inputs, and does not guarantee a length-2 output.
217
- """
218
- # Note: we deliberately don't use zip(*xys) because it is lazily evaluated,
219
- # is too permissive about inputs, and does not guarantee a length-2 output.
220
- xs: list[T1] = []
221
- ys: list[T2] = []
222
- for x, y in xys:
223
- xs.append(x)
224
- ys.append(y)
225
- return tuple(xs), tuple(ys)
226
-
227
-
228
- def fun_name(fun: Callable):
229
- """
230
- Extract the name of a function, handling special cases.
231
-
232
- Attempts to get the name of a function, with special handling for
233
- partial functions and fallback for unnamed functions.
234
-
235
- Args:
236
- fun: The function to get the name from.
237
-
238
- Returns:
239
- str: The function name, or "<unnamed function>" if no name available.
240
-
241
- Examples:
242
- >>> def my_function():
243
- ... pass
244
- >>> fun_name(my_function)
245
- 'my_function'
246
-
247
- >>> from functools import partial
248
- >>> add = lambda x, y: x + y
249
- >>> add_one = partial(add, 1)
250
- >>> fun_name(add_one)
251
- '<lambda>'
252
- """
253
- name = getattr(fun, "__name__", None)
254
- if name is not None:
255
- return name
256
- if isinstance(fun, partial):
257
- return fun_name(fun.func)
258
- else:
259
- return "<unnamed function>"
260
-
261
-
262
- def wraps(
263
- wrapped: Callable,
264
- namestr: str | None = None,
265
- docstr: str | None = None,
266
- **kwargs,
267
- ) -> Callable[[T], T]:
268
- """
269
- Enhanced function wrapper with fine-grained control.
270
-
271
- Like functools.wraps, but provides more control over the name and docstring
272
- of the resulting function. Useful for creating custom decorators.
273
-
274
- Args:
275
- wrapped: The function being wrapped.
276
- namestr: Optional format string for the wrapper function name.
277
- Can use {fun} placeholder for the original function name.
278
- docstr: Optional format string for the wrapper function docstring.
279
- Can use {fun}, {doc}, and other kwargs as placeholders.
280
- **kwargs: Additional keyword arguments for format string substitution.
281
-
282
- Returns:
283
- Callable: A decorator function that applies the wrapping.
284
-
285
- Examples:
286
- >>> def my_decorator(func):
287
- ... @wraps(func, namestr="decorated_{fun}")
288
- ... def wrapper(*args, **kwargs):
289
- ... return func(*args, **kwargs)
290
- ... return wrapper
291
-
292
- >>> @my_decorator
293
- ... def example():
294
- ... pass
295
- >>> example.__name__
296
- 'decorated_example'
297
- """
298
-
299
- def wrapper(fun: T) -> T:
300
- try:
301
- name = fun_name(wrapped)
302
- doc = getattr(wrapped, "__doc__", "") or ""
303
- fun.__dict__.update(getattr(wrapped, "__dict__", {}))
304
- fun.__annotations__ = getattr(wrapped, "__annotations__", {})
305
- fun.__name__ = name if namestr is None else namestr.format(fun=name)
306
- fun.__module__ = getattr(wrapped, "__module__", "<unknown module>")
307
- fun.__doc__ = (doc if docstr is None
308
- else docstr.format(fun=name, doc=doc, **kwargs))
309
- fun.__qualname__ = getattr(wrapped, "__qualname__", fun.__name__)
310
- fun.__wrapped__ = wrapped
311
- except Exception:
312
- pass
313
- return fun
314
-
315
- return wrapper
316
-
317
-
318
- def to_concrete_aval(aval):
319
- """
320
- Convert an abstract value to its concrete representation.
321
-
322
- Takes an abstract value and attempts to convert it to a concrete value,
323
- handling JAX Tracer objects appropriately.
324
-
325
- Args:
326
- aval: The abstract value to convert.
327
-
328
- Returns:
329
- The concrete value representation, or the original aval if already concrete.
330
-
331
- Examples:
332
- >>> import jax.numpy as jnp
333
- >>> arr = jnp.array([1, 2, 3])
334
- >>> concrete = to_concrete_aval(arr)
335
- # Returns the concrete array value
336
- """
337
- aval = get_aval(aval)
338
- if isinstance(aval, Tracer):
339
- return aval.to_concrete_value()
340
- return aval
1
+ # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ """
19
+ Compatibility layer for JAX version differences.
20
+
21
+ This module provides a compatibility layer to handle differences between various
22
+ versions of JAX, ensuring that BrainState works correctly across different JAX
23
+ versions. It imports the appropriate modules and functions based on the detected
24
+ JAX version and provides fallback implementations when necessary.
25
+
26
+ Key Features:
27
+ - Version-aware imports for JAX core functionality
28
+ - Compatibility wrappers for changed APIs
29
+ - Fallback implementations for deprecated functions
30
+ - Type-safe utility functions
31
+
32
+ Examples:
33
+ Basic usage:
34
+
35
+ >>> from brainstate._compatible_import import safe_map, safe_zip
36
+ >>> result = safe_map(lambda x: x * 2, [1, 2, 3])
37
+ >>> pairs = safe_zip([1, 2, 3], ['a', 'b', 'c'])
38
+
39
+ Using JAX core types:
40
+
41
+ >>> from brainstate._compatible_import import Primitive, ClosedJaxpr
42
+ >>> # These imports work across different JAX versions
43
+ """
44
+
45
+ from contextlib import contextmanager
46
+ from functools import partial
47
+ from typing import Iterable, Hashable, TypeVar, Callable
48
+
49
+ import jax
50
+ from jax.core import get_aval, Tracer
51
+ from saiunit._compatible_import import wrap_init
52
+
53
+ __all__ = [
54
+ 'ClosedJaxpr',
55
+ 'Primitive',
56
+ 'extend_axis_env_nd',
57
+ 'jaxpr_as_fun',
58
+ 'get_aval',
59
+ 'Tracer',
60
+ 'to_concrete_aval',
61
+ 'safe_map',
62
+ 'safe_zip',
63
+ 'unzip2',
64
+ 'wraps',
65
+ 'Device',
66
+ 'wrap_init',
67
+ 'Var',
68
+ 'JaxprEqn',
69
+ 'Jaxpr',
70
+ 'Literal',
71
+
72
+ 'make_iota', 'to_elt', 'BatchTracer', 'BatchTrace',
73
+ ]
74
+
75
+ T = TypeVar("T")
76
+ T1 = TypeVar("T1")
77
+ T2 = TypeVar("T2")
78
+ T3 = TypeVar("T3")
79
+
80
+ if jax.__version_info__ < (0, 5, 0):
81
+ from jax.lib.xla_client import Device
82
+ else:
83
+ from jax import Device
84
+
85
+ if jax.__version_info__ < (0, 7, 1):
86
+ from jax.interpreters.batching import make_iota, to_elt, BatchTracer, BatchTrace
87
+ else:
88
+ from jax._src.interpreters.batching import make_iota, to_elt, BatchTracer, BatchTrace
89
+
90
+ if jax.__version_info__ < (0, 4, 38):
91
+ from jax.core import ClosedJaxpr, extend_axis_env_nd, Primitive, jaxpr_as_fun
92
+ from jax.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
93
+ else:
94
+ from jax.extend.core import ClosedJaxpr, Primitive, jaxpr_as_fun
95
+ from jax.extend.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
96
+ from jax.core import trace_ctx
97
+
98
+
99
+ @contextmanager
100
+ def extend_axis_env_nd(name_size_pairs: Iterable[tuple[Hashable, int]]):
101
+ """
102
+ Context manager to temporarily extend the JAX axis environment.
103
+
104
+ Extends the current JAX axis environment with new named axes for
105
+ vectorized computations, then restores the previous environment.
106
+
107
+ Args:
108
+ name_size_pairs: Iterable of (name, size) tuples specifying
109
+ the named axes to add to the environment.
110
+
111
+ Yields:
112
+ None: Context with extended axis environment.
113
+
114
+ Examples:
115
+ >>> with extend_axis_env_nd([('batch', 32), ('seq', 128)]):
116
+ ... # Code using vectorized operations with named axes
117
+ ... pass
118
+ """
119
+ prev = trace_ctx.axis_env
120
+ try:
121
+ trace_ctx.set_axis_env(prev.extend_pure(name_size_pairs))
122
+ yield
123
+ finally:
124
+ trace_ctx.set_axis_env(prev)
125
+
126
+ if jax.__version_info__ < (0, 6, 0):
127
+ from jax.util import safe_map, safe_zip, unzip2, wraps
128
+
129
+ else:
130
+ def safe_map(f, *args):
131
+ """
132
+ Map a function over multiple sequences with length checking.
133
+
134
+ Applies a function to corresponding elements from multiple sequences,
135
+ ensuring all sequences have the same length.
136
+
137
+ Args:
138
+ f: Function to apply to elements from each sequence.
139
+ *args: Variable number of sequences to map over.
140
+
141
+ Returns:
142
+ list: Results of applying f to corresponding elements.
143
+
144
+ Raises:
145
+ AssertionError: If input sequences have different lengths.
146
+
147
+ Examples:
148
+ >>> safe_map(lambda x, y: x + y, [1, 2, 3], [4, 5, 6])
149
+ [5, 7, 9]
150
+
151
+ >>> safe_map(str.upper, ['a', 'b', 'c'])
152
+ ['A', 'B', 'C']
153
+ """
154
+ args = list(map(list, args))
155
+ n = len(args[0])
156
+ for arg in args[1:]:
157
+ assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
158
+ return list(map(f, *args))
159
+
160
+
161
+ def safe_zip(*args):
162
+ """
163
+ Zip multiple sequences with length checking.
164
+
165
+ Combines corresponding elements from multiple sequences into tuples,
166
+ ensuring all sequences have the same length.
167
+
168
+ Args:
169
+ *args: Variable number of sequences to zip together.
170
+
171
+ Returns:
172
+ list: List of tuples containing corresponding elements.
173
+
174
+ Raises:
175
+ AssertionError: If input sequences have different lengths.
176
+
177
+ Examples:
178
+ >>> safe_zip([1, 2, 3], ['a', 'b', 'c'])
179
+ [(1, 'a'), (2, 'b'), (3, 'c')]
180
+
181
+ >>> safe_zip([1, 2], [3, 4], [5, 6])
182
+ [(1, 3, 5), (2, 4, 6)]
183
+ """
184
+ args = list(map(list, args))
185
+ n = len(args[0])
186
+ for arg in args[1:]:
187
+ assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
188
+ return list(zip(*args))
189
+
190
+
191
+ def unzip2(xys: Iterable[tuple[T1, T2]]) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
192
+ """
193
+ Unzip sequence of length-2 tuples into two tuples.
194
+
195
+ Takes an iterable of 2-tuples and separates them into two tuples
196
+ containing the first and second elements respectively.
197
+
198
+ Args:
199
+ xys: Iterable of 2-tuples to unzip.
200
+
201
+ Returns:
202
+ tuple: A 2-tuple containing:
203
+ - Tuple of all first elements
204
+ - Tuple of all second elements
205
+
206
+ Examples:
207
+ >>> pairs = [(1, 'a'), (2, 'b'), (3, 'c')]
208
+ >>> nums, letters = unzip2(pairs)
209
+ >>> nums
210
+ (1, 2, 3)
211
+ >>> letters
212
+ ('a', 'b', 'c')
213
+
214
+ Notes:
215
+ We deliberately don't use zip(*xys) because it is lazily evaluated,
216
+ is too permissive about inputs, and does not guarantee a length-2 output.
217
+ """
218
+ # Note: we deliberately don't use zip(*xys) because it is lazily evaluated,
219
+ # is too permissive about inputs, and does not guarantee a length-2 output.
220
+ xs: list[T1] = []
221
+ ys: list[T2] = []
222
+ for x, y in xys:
223
+ xs.append(x)
224
+ ys.append(y)
225
+ return tuple(xs), tuple(ys)
226
+
227
+
228
+ def fun_name(fun: Callable):
229
+ """
230
+ Extract the name of a function, handling special cases.
231
+
232
+ Attempts to get the name of a function, with special handling for
233
+ partial functions and fallback for unnamed functions.
234
+
235
+ Args:
236
+ fun: The function to get the name from.
237
+
238
+ Returns:
239
+ str: The function name, or "<unnamed function>" if no name available.
240
+
241
+ Examples:
242
+ >>> def my_function():
243
+ ... pass
244
+ >>> fun_name(my_function)
245
+ 'my_function'
246
+
247
+ >>> from functools import partial
248
+ >>> add = lambda x, y: x + y
249
+ >>> add_one = partial(add, 1)
250
+ >>> fun_name(add_one)
251
+ '<lambda>'
252
+ """
253
+ name = getattr(fun, "__name__", None)
254
+ if name is not None:
255
+ return name
256
+ if isinstance(fun, partial):
257
+ return fun_name(fun.func)
258
+ else:
259
+ return "<unnamed function>"
260
+
261
+
262
+ def wraps(
263
+ wrapped: Callable,
264
+ namestr: str | None = None,
265
+ docstr: str | None = None,
266
+ **kwargs,
267
+ ) -> Callable[[T], T]:
268
+ """
269
+ Enhanced function wrapper with fine-grained control.
270
+
271
+ Like functools.wraps, but provides more control over the name and docstring
272
+ of the resulting function. Useful for creating custom decorators.
273
+
274
+ Args:
275
+ wrapped: The function being wrapped.
276
+ namestr: Optional format string for the wrapper function name.
277
+ Can use {fun} placeholder for the original function name.
278
+ docstr: Optional format string for the wrapper function docstring.
279
+ Can use {fun}, {doc}, and other kwargs as placeholders.
280
+ **kwargs: Additional keyword arguments for format string substitution.
281
+
282
+ Returns:
283
+ Callable: A decorator function that applies the wrapping.
284
+
285
+ Examples:
286
+ >>> def my_decorator(func):
287
+ ... @wraps(func, namestr="decorated_{fun}")
288
+ ... def wrapper(*args, **kwargs):
289
+ ... return func(*args, **kwargs)
290
+ ... return wrapper
291
+
292
+ >>> @my_decorator
293
+ ... def example():
294
+ ... pass
295
+ >>> example.__name__
296
+ 'decorated_example'
297
+ """
298
+
299
+ def wrapper(fun: T) -> T:
300
+ try:
301
+ name = fun_name(wrapped)
302
+ doc = getattr(wrapped, "__doc__", "") or ""
303
+ fun.__dict__.update(getattr(wrapped, "__dict__", {}))
304
+ fun.__annotations__ = getattr(wrapped, "__annotations__", {})
305
+ fun.__name__ = name if namestr is None else namestr.format(fun=name)
306
+ fun.__module__ = getattr(wrapped, "__module__", "<unknown module>")
307
+ fun.__doc__ = (doc if docstr is None
308
+ else docstr.format(fun=name, doc=doc, **kwargs))
309
+ fun.__qualname__ = getattr(wrapped, "__qualname__", fun.__name__)
310
+ fun.__wrapped__ = wrapped
311
+ except Exception:
312
+ pass
313
+ return fun
314
+
315
+ return wrapper
316
+
317
+
318
+ def to_concrete_aval(aval):
319
+ """
320
+ Convert an abstract value to its concrete representation.
321
+
322
+ Takes an abstract value and attempts to convert it to a concrete value,
323
+ handling JAX Tracer objects appropriately.
324
+
325
+ Args:
326
+ aval: The abstract value to convert.
327
+
328
+ Returns:
329
+ The concrete value representation, or the original aval if already concrete.
330
+
331
+ Examples:
332
+ >>> import jax.numpy as jnp
333
+ >>> arr = jnp.array([1, 2, 3])
334
+ >>> concrete = to_concrete_aval(arr)
335
+ # Returns the concrete array value
336
+ """
337
+ aval = get_aval(aval)
338
+ if isinstance(aval, Tracer):
339
+ return aval.to_concrete_value()
340
+ return aval