brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -17,42 +17,153 @@
|
|
17
17
|
A ``State``-based Transformation System for Program Compilation and Augmentation
|
18
18
|
"""
|
19
19
|
|
20
|
-
__version__ = "0.
|
20
|
+
__version__ = "0.2.0"
|
21
|
+
__versio_info__ = (0, 2, 0)
|
21
22
|
|
22
|
-
from . import augment
|
23
|
-
from . import compile
|
24
23
|
from . import environ
|
25
|
-
from . import functional
|
26
24
|
from . import graph
|
27
|
-
from . import init
|
28
25
|
from . import mixin
|
29
26
|
from . import nn
|
30
|
-
from . import optim
|
31
27
|
from . import random
|
32
|
-
from . import surrogate
|
33
28
|
from . import transform
|
34
29
|
from . import typing
|
35
30
|
from . import util
|
31
|
+
from ._error import *
|
32
|
+
from ._error import __all__ as _error_all
|
36
33
|
from ._state import *
|
37
34
|
from ._state import __all__ as _state_all
|
38
35
|
|
36
|
+
# Create deprecated module proxies with scoped APIs
|
37
|
+
from ._deprecation import create_deprecated_module_proxy
|
38
|
+
|
39
|
+
# Augment module scope
|
40
|
+
_augment_apis = {
|
41
|
+
'GradientTransform': 'brainstate.transform._autograd',
|
42
|
+
'grad': 'brainstate.transform._autograd',
|
43
|
+
'vector_grad': 'brainstate.transform._autograd',
|
44
|
+
'hessian': 'brainstate.transform._autograd',
|
45
|
+
'jacobian': 'brainstate.transform._autograd',
|
46
|
+
'jacrev': 'brainstate.transform._autograd',
|
47
|
+
'jacfwd': 'brainstate.transform._autograd',
|
48
|
+
'abstract_init': 'brainstate.transform._eval_shape',
|
49
|
+
'vmap': 'brainstate.transform._mapping',
|
50
|
+
'pmap': 'brainstate.transform._mapping',
|
51
|
+
'map': 'brainstate.transform._mapping',
|
52
|
+
'vmap_new_states': 'brainstate.transform._mapping',
|
53
|
+
'restore_rngs': 'brainstate.transform._random',
|
54
|
+
}
|
55
|
+
|
56
|
+
augment = create_deprecated_module_proxy(
|
57
|
+
deprecated_name='brainstate.augment',
|
58
|
+
replacement_module=transform,
|
59
|
+
replacement_name='brainstate.transform',
|
60
|
+
scoped_apis=_augment_apis
|
61
|
+
)
|
62
|
+
|
63
|
+
# Compile module scope
|
64
|
+
_compile_apis = {
|
65
|
+
'checkpoint': 'brainstate.transform._ad_checkpoint',
|
66
|
+
'remat': 'brainstate.transform._ad_checkpoint',
|
67
|
+
'cond': 'brainstate.transform._conditions',
|
68
|
+
'switch': 'brainstate.transform._conditions',
|
69
|
+
'ifelse': 'brainstate.transform._conditions',
|
70
|
+
'jit_error_if': 'brainstate.transform._error_if',
|
71
|
+
'jit': 'brainstate.transform._jit',
|
72
|
+
'scan': 'brainstate.transform._loop_collect_return',
|
73
|
+
'checkpointed_scan': 'brainstate.transform._loop_collect_return',
|
74
|
+
'for_loop': 'brainstate.transform._loop_collect_return',
|
75
|
+
'checkpointed_for_loop': 'brainstate.transform._loop_collect_return',
|
76
|
+
'while_loop': 'brainstate.transform._loop_no_collection',
|
77
|
+
'bounded_while_loop': 'brainstate.transform._loop_no_collection',
|
78
|
+
'StatefulFunction': 'brainstate.transform._make_jaxpr',
|
79
|
+
'make_jaxpr': 'brainstate.transform._make_jaxpr',
|
80
|
+
'ProgressBar': 'brainstate.transform._progress_bar',
|
81
|
+
}
|
82
|
+
|
83
|
+
compile = create_deprecated_module_proxy(
|
84
|
+
deprecated_name='brainstate.compile',
|
85
|
+
replacement_module=transform,
|
86
|
+
replacement_name='brainstate.transform',
|
87
|
+
scoped_apis=_compile_apis
|
88
|
+
)
|
89
|
+
|
90
|
+
# Functional module scope - use direct attribute access from nn module
|
91
|
+
_functional_apis = {
|
92
|
+
'weight_standardization': 'brainstate.nn._normalizations',
|
93
|
+
'clip_grad_norm': 'brainstate.nn._others',
|
94
|
+
'tanh': 'brainstate.nn._activations',
|
95
|
+
'relu': 'brainstate.nn._activations',
|
96
|
+
'squareplus': 'brainstate.nn._activations',
|
97
|
+
'softplus': 'brainstate.nn._activations',
|
98
|
+
'soft_sign': 'brainstate.nn._activations',
|
99
|
+
'sigmoid': 'brainstate.nn._activations',
|
100
|
+
'silu': 'brainstate.nn._activations',
|
101
|
+
'swish': 'brainstate.nn._activations',
|
102
|
+
'log_sigmoid': 'brainstate.nn._activations',
|
103
|
+
'elu': 'brainstate.nn._activations',
|
104
|
+
'leaky_relu': 'brainstate.nn._activations',
|
105
|
+
'hard_tanh': 'brainstate.nn._activations',
|
106
|
+
'celu': 'brainstate.nn._activations',
|
107
|
+
'selu': 'brainstate.nn._activations',
|
108
|
+
'gelu': 'brainstate.nn._activations',
|
109
|
+
'glu': 'brainstate.nn._activations',
|
110
|
+
'logsumexp': 'brainstate.nn._activations',
|
111
|
+
'log_softmax': 'brainstate.nn._activations',
|
112
|
+
'softmax': 'brainstate.nn._activations',
|
113
|
+
'standardize': 'brainstate.nn._activations',
|
114
|
+
'relu6': 'brainstate.nn._activations',
|
115
|
+
'hard_sigmoid': 'brainstate.nn._activations',
|
116
|
+
'sparse_plus': 'brainstate.nn._activations',
|
117
|
+
'hard_silu': 'brainstate.nn._activations',
|
118
|
+
'hard_swish': 'brainstate.nn._activations',
|
119
|
+
'hard_shrink': 'brainstate.nn._activations',
|
120
|
+
'rrelu': 'brainstate.nn._activations',
|
121
|
+
'mish': 'brainstate.nn._activations',
|
122
|
+
'soft_shrink': 'brainstate.nn._activations',
|
123
|
+
'prelu': 'brainstate.nn._activations',
|
124
|
+
'softmin': 'brainstate.nn._activations',
|
125
|
+
'one_hot': 'brainstate.nn._activations',
|
126
|
+
'sparse_sigmoid': 'brainstate.nn._activations',
|
127
|
+
}
|
128
|
+
|
129
|
+
functional = create_deprecated_module_proxy(
|
130
|
+
deprecated_name='brainstate.functional',
|
131
|
+
replacement_module=nn,
|
132
|
+
replacement_name='brainstate.nn',
|
133
|
+
scoped_apis=_functional_apis
|
134
|
+
)
|
135
|
+
|
136
|
+
|
137
|
+
def __getattr__(name):
|
138
|
+
if name in ['surrogate', 'init', 'optim']:
|
139
|
+
import warnings
|
140
|
+
warnings.warn(
|
141
|
+
f"brainstate.{name} module is deprecated and will be removed in a future version. "
|
142
|
+
f"Please use braintools.{name} instead.",
|
143
|
+
DeprecationWarning,
|
144
|
+
stacklevel=2
|
145
|
+
)
|
146
|
+
import braintools
|
147
|
+
return getattr(braintools, name)
|
148
|
+
raise AttributeError(
|
149
|
+
f'module {__name__!r} has no attribute {name!r}'
|
150
|
+
)
|
151
|
+
|
152
|
+
|
39
153
|
__all__ = [
|
40
|
-
'augment',
|
41
|
-
'compile',
|
42
154
|
'environ',
|
43
|
-
'functional',
|
44
155
|
'graph',
|
45
|
-
'init',
|
46
156
|
'mixin',
|
47
157
|
'nn',
|
48
|
-
'optim',
|
49
158
|
'random',
|
50
|
-
'
|
159
|
+
'transform',
|
51
160
|
'typing',
|
52
161
|
'util',
|
53
|
-
|
162
|
+
# Deprecated modules
|
163
|
+
'augment',
|
164
|
+
'compile',
|
165
|
+
'functional',
|
54
166
|
]
|
55
|
-
__all__ = __all__ + _state_all
|
56
|
-
|
57
|
-
|
58
|
-
del _state_all
|
167
|
+
__all__ = __all__ + _state_all + _error_all
|
168
|
+
del _state_all, create_deprecated_module_proxy, _augment_apis, _compile_apis, _functional_apis
|
169
|
+
del _error_all
|
brainstate/_compatible_import.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2025
|
1
|
+
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -15,12 +15,40 @@
|
|
15
15
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
|
18
|
+
"""
|
19
|
+
Compatibility layer for JAX version differences.
|
20
|
+
|
21
|
+
This module provides a compatibility layer to handle differences between various
|
22
|
+
versions of JAX, ensuring that BrainState works correctly across different JAX
|
23
|
+
versions. It imports the appropriate modules and functions based on the detected
|
24
|
+
JAX version and provides fallback implementations when necessary.
|
25
|
+
|
26
|
+
Key Features:
|
27
|
+
- Version-aware imports for JAX core functionality
|
28
|
+
- Compatibility wrappers for changed APIs
|
29
|
+
- Fallback implementations for deprecated functions
|
30
|
+
- Type-safe utility functions
|
31
|
+
|
32
|
+
Examples:
|
33
|
+
Basic usage:
|
34
|
+
|
35
|
+
>>> from brainstate._compatible_import import safe_map, safe_zip
|
36
|
+
>>> result = safe_map(lambda x: x * 2, [1, 2, 3])
|
37
|
+
>>> pairs = safe_zip([1, 2, 3], ['a', 'b', 'c'])
|
38
|
+
|
39
|
+
Using JAX core types:
|
40
|
+
|
41
|
+
>>> from brainstate._compatible_import import Primitive, ClosedJaxpr
|
42
|
+
>>> # These imports work across different JAX versions
|
43
|
+
"""
|
18
44
|
|
19
45
|
from contextlib import contextmanager
|
20
46
|
from functools import partial
|
21
47
|
from typing import Iterable, Hashable, TypeVar, Callable
|
22
48
|
|
23
49
|
import jax
|
50
|
+
from jax.core import get_aval, Tracer
|
51
|
+
from saiunit._compatible_import import wrap_init
|
24
52
|
|
25
53
|
__all__ = [
|
26
54
|
'ClosedJaxpr',
|
@@ -36,6 +64,12 @@ __all__ = [
|
|
36
64
|
'wraps',
|
37
65
|
'Device',
|
38
66
|
'wrap_init',
|
67
|
+
'Var',
|
68
|
+
'JaxprEqn',
|
69
|
+
'Jaxpr',
|
70
|
+
'Literal',
|
71
|
+
|
72
|
+
'make_iota', 'to_elt', 'BatchTracer', 'BatchTrace',
|
39
73
|
]
|
40
74
|
|
41
75
|
T = TypeVar("T")
|
@@ -43,24 +77,45 @@ T1 = TypeVar("T1")
|
|
43
77
|
T2 = TypeVar("T2")
|
44
78
|
T3 = TypeVar("T3")
|
45
79
|
|
46
|
-
from saiunit._compatible_import import wrap_init
|
47
|
-
|
48
|
-
from jax.core import get_aval, Tracer
|
49
|
-
|
50
80
|
if jax.__version_info__ < (0, 5, 0):
|
51
81
|
from jax.lib.xla_client import Device
|
52
82
|
else:
|
53
83
|
from jax import Device
|
54
84
|
|
85
|
+
if jax.__version_info__ < (0, 7, 1):
|
86
|
+
from jax.interpreters.batching import make_iota, to_elt, BatchTracer, BatchTrace
|
87
|
+
else:
|
88
|
+
from jax._src.interpreters.batching import make_iota, to_elt, BatchTracer, BatchTrace
|
89
|
+
|
55
90
|
if jax.__version_info__ < (0, 4, 38):
|
56
91
|
from jax.core import ClosedJaxpr, extend_axis_env_nd, Primitive, jaxpr_as_fun
|
92
|
+
from jax.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
|
57
93
|
else:
|
58
94
|
from jax.extend.core import ClosedJaxpr, Primitive, jaxpr_as_fun
|
95
|
+
from jax.extend.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
|
59
96
|
from jax.core import trace_ctx
|
60
97
|
|
61
98
|
|
62
99
|
@contextmanager
|
63
100
|
def extend_axis_env_nd(name_size_pairs: Iterable[tuple[Hashable, int]]):
|
101
|
+
"""
|
102
|
+
Context manager to temporarily extend the JAX axis environment.
|
103
|
+
|
104
|
+
Extends the current JAX axis environment with new named axes for
|
105
|
+
vectorized computations, then restores the previous environment.
|
106
|
+
|
107
|
+
Args:
|
108
|
+
name_size_pairs: Iterable of (name, size) tuples specifying
|
109
|
+
the named axes to add to the environment.
|
110
|
+
|
111
|
+
Yields:
|
112
|
+
None: Context with extended axis environment.
|
113
|
+
|
114
|
+
Examples:
|
115
|
+
>>> with extend_axis_env_nd([('batch', 32), ('seq', 128)]):
|
116
|
+
... # Code using vectorized operations with named axes
|
117
|
+
... pass
|
118
|
+
"""
|
64
119
|
prev = trace_ctx.axis_env
|
65
120
|
try:
|
66
121
|
trace_ctx.set_axis_env(prev.extend_pure(name_size_pairs))
|
@@ -73,6 +128,29 @@ if jax.__version_info__ < (0, 6, 0):
|
|
73
128
|
|
74
129
|
else:
|
75
130
|
def safe_map(f, *args):
|
131
|
+
"""
|
132
|
+
Map a function over multiple sequences with length checking.
|
133
|
+
|
134
|
+
Applies a function to corresponding elements from multiple sequences,
|
135
|
+
ensuring all sequences have the same length.
|
136
|
+
|
137
|
+
Args:
|
138
|
+
f: Function to apply to elements from each sequence.
|
139
|
+
*args: Variable number of sequences to map over.
|
140
|
+
|
141
|
+
Returns:
|
142
|
+
list: Results of applying f to corresponding elements.
|
143
|
+
|
144
|
+
Raises:
|
145
|
+
AssertionError: If input sequences have different lengths.
|
146
|
+
|
147
|
+
Examples:
|
148
|
+
>>> safe_map(lambda x, y: x + y, [1, 2, 3], [4, 5, 6])
|
149
|
+
[5, 7, 9]
|
150
|
+
|
151
|
+
>>> safe_map(str.upper, ['a', 'b', 'c'])
|
152
|
+
['A', 'B', 'C']
|
153
|
+
"""
|
76
154
|
args = list(map(list, args))
|
77
155
|
n = len(args[0])
|
78
156
|
for arg in args[1:]:
|
@@ -81,6 +159,28 @@ else:
|
|
81
159
|
|
82
160
|
|
83
161
|
def safe_zip(*args):
|
162
|
+
"""
|
163
|
+
Zip multiple sequences with length checking.
|
164
|
+
|
165
|
+
Combines corresponding elements from multiple sequences into tuples,
|
166
|
+
ensuring all sequences have the same length.
|
167
|
+
|
168
|
+
Args:
|
169
|
+
*args: Variable number of sequences to zip together.
|
170
|
+
|
171
|
+
Returns:
|
172
|
+
list: List of tuples containing corresponding elements.
|
173
|
+
|
174
|
+
Raises:
|
175
|
+
AssertionError: If input sequences have different lengths.
|
176
|
+
|
177
|
+
Examples:
|
178
|
+
>>> safe_zip([1, 2, 3], ['a', 'b', 'c'])
|
179
|
+
[(1, 'a'), (2, 'b'), (3, 'c')]
|
180
|
+
|
181
|
+
>>> safe_zip([1, 2], [3, 4], [5, 6])
|
182
|
+
[(1, 3, 5), (2, 4, 6)]
|
183
|
+
"""
|
84
184
|
args = list(map(list, args))
|
85
185
|
n = len(args[0])
|
86
186
|
for arg in args[1:]:
|
@@ -89,7 +189,32 @@ else:
|
|
89
189
|
|
90
190
|
|
91
191
|
def unzip2(xys: Iterable[tuple[T1, T2]]) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
|
92
|
-
"""
|
192
|
+
"""
|
193
|
+
Unzip sequence of length-2 tuples into two tuples.
|
194
|
+
|
195
|
+
Takes an iterable of 2-tuples and separates them into two tuples
|
196
|
+
containing the first and second elements respectively.
|
197
|
+
|
198
|
+
Args:
|
199
|
+
xys: Iterable of 2-tuples to unzip.
|
200
|
+
|
201
|
+
Returns:
|
202
|
+
tuple: A 2-tuple containing:
|
203
|
+
- Tuple of all first elements
|
204
|
+
- Tuple of all second elements
|
205
|
+
|
206
|
+
Examples:
|
207
|
+
>>> pairs = [(1, 'a'), (2, 'b'), (3, 'c')]
|
208
|
+
>>> nums, letters = unzip2(pairs)
|
209
|
+
>>> nums
|
210
|
+
(1, 2, 3)
|
211
|
+
>>> letters
|
212
|
+
('a', 'b', 'c')
|
213
|
+
|
214
|
+
Notes:
|
215
|
+
We deliberately don't use zip(*xys) because it is lazily evaluated,
|
216
|
+
is too permissive about inputs, and does not guarantee a length-2 output.
|
217
|
+
"""
|
93
218
|
# Note: we deliberately don't use zip(*xys) because it is lazily evaluated,
|
94
219
|
# is too permissive about inputs, and does not guarantee a length-2 output.
|
95
220
|
xs: list[T1] = []
|
@@ -101,6 +226,30 @@ else:
|
|
101
226
|
|
102
227
|
|
103
228
|
def fun_name(fun: Callable):
|
229
|
+
"""
|
230
|
+
Extract the name of a function, handling special cases.
|
231
|
+
|
232
|
+
Attempts to get the name of a function, with special handling for
|
233
|
+
partial functions and fallback for unnamed functions.
|
234
|
+
|
235
|
+
Args:
|
236
|
+
fun: The function to get the name from.
|
237
|
+
|
238
|
+
Returns:
|
239
|
+
str: The function name, or "<unnamed function>" if no name available.
|
240
|
+
|
241
|
+
Examples:
|
242
|
+
>>> def my_function():
|
243
|
+
... pass
|
244
|
+
>>> fun_name(my_function)
|
245
|
+
'my_function'
|
246
|
+
|
247
|
+
>>> from functools import partial
|
248
|
+
>>> add = lambda x, y: x + y
|
249
|
+
>>> add_one = partial(add, 1)
|
250
|
+
>>> fun_name(add_one)
|
251
|
+
'<lambda>'
|
252
|
+
"""
|
104
253
|
name = getattr(fun, "__name__", None)
|
105
254
|
if name is not None:
|
106
255
|
return name
|
@@ -117,8 +266,34 @@ else:
|
|
117
266
|
**kwargs,
|
118
267
|
) -> Callable[[T], T]:
|
119
268
|
"""
|
120
|
-
|
121
|
-
|
269
|
+
Enhanced function wrapper with fine-grained control.
|
270
|
+
|
271
|
+
Like functools.wraps, but provides more control over the name and docstring
|
272
|
+
of the resulting function. Useful for creating custom decorators.
|
273
|
+
|
274
|
+
Args:
|
275
|
+
wrapped: The function being wrapped.
|
276
|
+
namestr: Optional format string for the wrapper function name.
|
277
|
+
Can use {fun} placeholder for the original function name.
|
278
|
+
docstr: Optional format string for the wrapper function docstring.
|
279
|
+
Can use {fun}, {doc}, and other kwargs as placeholders.
|
280
|
+
**kwargs: Additional keyword arguments for format string substitution.
|
281
|
+
|
282
|
+
Returns:
|
283
|
+
Callable: A decorator function that applies the wrapping.
|
284
|
+
|
285
|
+
Examples:
|
286
|
+
>>> def my_decorator(func):
|
287
|
+
... @wraps(func, namestr="decorated_{fun}")
|
288
|
+
... def wrapper(*args, **kwargs):
|
289
|
+
... return func(*args, **kwargs)
|
290
|
+
... return wrapper
|
291
|
+
|
292
|
+
>>> @my_decorator
|
293
|
+
... def example():
|
294
|
+
... pass
|
295
|
+
>>> example.__name__
|
296
|
+
'decorated_example'
|
122
297
|
"""
|
123
298
|
|
124
299
|
def wrapper(fun: T) -> T:
|
@@ -141,8 +316,25 @@ else:
|
|
141
316
|
|
142
317
|
|
143
318
|
def to_concrete_aval(aval):
|
319
|
+
"""
|
320
|
+
Convert an abstract value to its concrete representation.
|
321
|
+
|
322
|
+
Takes an abstract value and attempts to convert it to a concrete value,
|
323
|
+
handling JAX Tracer objects appropriately.
|
324
|
+
|
325
|
+
Args:
|
326
|
+
aval: The abstract value to convert.
|
327
|
+
|
328
|
+
Returns:
|
329
|
+
The concrete value representation, or the original aval if already concrete.
|
330
|
+
|
331
|
+
Examples:
|
332
|
+
>>> import jax.numpy as jnp
|
333
|
+
>>> arr = jnp.array([1, 2, 3])
|
334
|
+
>>> concrete = to_concrete_aval(arr)
|
335
|
+
# Returns the concrete array value
|
336
|
+
"""
|
144
337
|
aval = get_aval(aval)
|
145
338
|
if isinstance(aval, Tracer):
|
146
339
|
return aval.to_concrete_value()
|
147
340
|
return aval
|
148
|
-
|