brainstate 0.1.7__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -146
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +509 -470
- brainstate/nn/_delay_test.py +238 -0
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1361
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1120
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -208
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
- brainstate-0.1.9.dist-info/RECORD +130 -0
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.7.dist-info/RECORD +0 -131
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
brainstate/__init__.py
CHANGED
@@ -1,51 +1,58 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
"""
|
17
|
-
A ``State``-based Transformation System for Program Compilation and Augmentation
|
18
|
-
"""
|
19
|
-
|
20
|
-
__version__ = "0.1.7"
|
21
|
-
|
22
|
-
from . import augment
|
23
|
-
from . import compile
|
24
|
-
from . import environ
|
25
|
-
from . import functional
|
26
|
-
from . import graph
|
27
|
-
from . import init
|
28
|
-
from . import mixin
|
29
|
-
from . import nn
|
30
|
-
from . import optim
|
31
|
-
from . import random
|
32
|
-
from . import surrogate
|
33
|
-
from . import transform
|
34
|
-
from . import typing
|
35
|
-
from . import util
|
36
|
-
from ._state import *
|
37
|
-
from ._state import __all__ as _state_all
|
38
|
-
|
39
|
-
__all__ =
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""
|
17
|
+
A ``State``-based Transformation System for Program Compilation and Augmentation
|
18
|
+
"""
|
19
|
+
|
20
|
+
__version__ = "0.1.9"
|
21
|
+
|
22
|
+
from . import augment
|
23
|
+
from . import compile
|
24
|
+
from . import environ
|
25
|
+
from . import functional
|
26
|
+
from . import graph
|
27
|
+
from . import init
|
28
|
+
from . import mixin
|
29
|
+
from . import nn
|
30
|
+
from . import optim
|
31
|
+
from . import random
|
32
|
+
from . import surrogate
|
33
|
+
from . import transform
|
34
|
+
from . import typing
|
35
|
+
from . import util
|
36
|
+
from ._state import *
|
37
|
+
from ._state import __all__ as _state_all
|
38
|
+
|
39
|
+
__all__ = [
|
40
|
+
'augment',
|
41
|
+
'compile',
|
42
|
+
'environ',
|
43
|
+
'functional',
|
44
|
+
'graph',
|
45
|
+
'init',
|
46
|
+
'mixin',
|
47
|
+
'nn',
|
48
|
+
'optim',
|
49
|
+
'random',
|
50
|
+
'surrogate',
|
51
|
+
'typing',
|
52
|
+
'util',
|
53
|
+
'transform',
|
54
|
+
]
|
55
|
+
__all__ = __all__ + _state_all
|
56
|
+
|
57
|
+
# ----------------------- #
|
58
|
+
del _state_all
|
brainstate/_compatible_import.py
CHANGED
@@ -1,148 +1,148 @@
|
|
1
|
-
# Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
|
19
|
-
from contextlib import contextmanager
|
20
|
-
from functools import partial
|
21
|
-
from typing import Iterable, Hashable, TypeVar, Callable
|
22
|
-
|
23
|
-
import jax
|
24
|
-
|
25
|
-
__all__ = [
|
26
|
-
'ClosedJaxpr',
|
27
|
-
'Primitive',
|
28
|
-
'extend_axis_env_nd',
|
29
|
-
'jaxpr_as_fun',
|
30
|
-
'get_aval',
|
31
|
-
'Tracer',
|
32
|
-
'to_concrete_aval',
|
33
|
-
'safe_map',
|
34
|
-
'safe_zip',
|
35
|
-
'unzip2',
|
36
|
-
'wraps',
|
37
|
-
'Device',
|
38
|
-
'wrap_init',
|
39
|
-
]
|
40
|
-
|
41
|
-
T = TypeVar("T")
|
42
|
-
T1 = TypeVar("T1")
|
43
|
-
T2 = TypeVar("T2")
|
44
|
-
T3 = TypeVar("T3")
|
45
|
-
|
46
|
-
from saiunit._compatible_import import wrap_init
|
47
|
-
|
48
|
-
from jax.core import get_aval, Tracer
|
49
|
-
|
50
|
-
if jax.__version_info__ < (0, 5, 0):
|
51
|
-
from jax.lib.xla_client import Device
|
52
|
-
else:
|
53
|
-
from jax import Device
|
54
|
-
|
55
|
-
if jax.__version_info__ < (0, 4, 38):
|
56
|
-
from jax.core import ClosedJaxpr, extend_axis_env_nd, Primitive, jaxpr_as_fun
|
57
|
-
else:
|
58
|
-
from jax.extend.core import ClosedJaxpr, Primitive, jaxpr_as_fun
|
59
|
-
from jax.core import trace_ctx
|
60
|
-
|
61
|
-
|
62
|
-
@contextmanager
|
63
|
-
def extend_axis_env_nd(name_size_pairs: Iterable[tuple[Hashable, int]]):
|
64
|
-
prev = trace_ctx.axis_env
|
65
|
-
try:
|
66
|
-
trace_ctx.set_axis_env(prev.extend_pure(name_size_pairs))
|
67
|
-
yield
|
68
|
-
finally:
|
69
|
-
trace_ctx.set_axis_env(prev)
|
70
|
-
|
71
|
-
if jax.__version_info__ < (0, 6, 0):
|
72
|
-
from jax.util import safe_map, safe_zip, unzip2, wraps
|
73
|
-
|
74
|
-
else:
|
75
|
-
def safe_map(f, *args):
|
76
|
-
args = list(map(list, args))
|
77
|
-
n = len(args[0])
|
78
|
-
for arg in args[1:]:
|
79
|
-
assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
|
80
|
-
return list(map(f, *args))
|
81
|
-
|
82
|
-
|
83
|
-
def safe_zip(*args):
|
84
|
-
args = list(map(list, args))
|
85
|
-
n = len(args[0])
|
86
|
-
for arg in args[1:]:
|
87
|
-
assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
|
88
|
-
return list(zip(*args))
|
89
|
-
|
90
|
-
|
91
|
-
def unzip2(xys: Iterable[tuple[T1, T2]]) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
|
92
|
-
"""Unzip sequence of length-2 tuples into two tuples."""
|
93
|
-
# Note: we deliberately don't use zip(*xys) because it is lazily evaluated,
|
94
|
-
# is too permissive about inputs, and does not guarantee a length-2 output.
|
95
|
-
xs: list[T1] = []
|
96
|
-
ys: list[T2] = []
|
97
|
-
for x, y in xys:
|
98
|
-
xs.append(x)
|
99
|
-
ys.append(y)
|
100
|
-
return tuple(xs), tuple(ys)
|
101
|
-
|
102
|
-
|
103
|
-
def fun_name(fun: Callable):
|
104
|
-
name = getattr(fun, "__name__", None)
|
105
|
-
if name is not None:
|
106
|
-
return name
|
107
|
-
if isinstance(fun, partial):
|
108
|
-
return fun_name(fun.func)
|
109
|
-
else:
|
110
|
-
return "<unnamed function>"
|
111
|
-
|
112
|
-
|
113
|
-
def wraps(
|
114
|
-
wrapped: Callable,
|
115
|
-
namestr: str | None = None,
|
116
|
-
docstr: str | None = None,
|
117
|
-
**kwargs,
|
118
|
-
) -> Callable[[T], T]:
|
119
|
-
"""
|
120
|
-
Like functools.wraps, but with finer-grained control over the name and docstring
|
121
|
-
of the resulting function.
|
122
|
-
"""
|
123
|
-
|
124
|
-
def wrapper(fun: T) -> T:
|
125
|
-
try:
|
126
|
-
name = fun_name(wrapped)
|
127
|
-
doc = getattr(wrapped, "__doc__", "") or ""
|
128
|
-
fun.__dict__.update(getattr(wrapped, "__dict__", {}))
|
129
|
-
fun.__annotations__ = getattr(wrapped, "__annotations__", {})
|
130
|
-
fun.__name__ = name if namestr is None else namestr.format(fun=name)
|
131
|
-
fun.__module__ = getattr(wrapped, "__module__", "<unknown module>")
|
132
|
-
fun.__doc__ = (doc if docstr is None
|
133
|
-
else docstr.format(fun=name, doc=doc, **kwargs))
|
134
|
-
fun.__qualname__ = getattr(wrapped, "__qualname__", fun.__name__)
|
135
|
-
fun.__wrapped__ = wrapped
|
136
|
-
except Exception:
|
137
|
-
pass
|
138
|
-
return fun
|
139
|
-
|
140
|
-
return wrapper
|
141
|
-
|
142
|
-
|
143
|
-
def to_concrete_aval(aval):
|
144
|
-
aval = get_aval(aval)
|
145
|
-
if isinstance(aval, Tracer):
|
146
|
-
return aval.to_concrete_value()
|
147
|
-
return aval
|
148
|
-
|
1
|
+
# Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
|
19
|
+
from contextlib import contextmanager
|
20
|
+
from functools import partial
|
21
|
+
from typing import Iterable, Hashable, TypeVar, Callable
|
22
|
+
|
23
|
+
import jax
|
24
|
+
|
25
|
+
__all__ = [
|
26
|
+
'ClosedJaxpr',
|
27
|
+
'Primitive',
|
28
|
+
'extend_axis_env_nd',
|
29
|
+
'jaxpr_as_fun',
|
30
|
+
'get_aval',
|
31
|
+
'Tracer',
|
32
|
+
'to_concrete_aval',
|
33
|
+
'safe_map',
|
34
|
+
'safe_zip',
|
35
|
+
'unzip2',
|
36
|
+
'wraps',
|
37
|
+
'Device',
|
38
|
+
'wrap_init',
|
39
|
+
]
|
40
|
+
|
41
|
+
T = TypeVar("T")
|
42
|
+
T1 = TypeVar("T1")
|
43
|
+
T2 = TypeVar("T2")
|
44
|
+
T3 = TypeVar("T3")
|
45
|
+
|
46
|
+
from saiunit._compatible_import import wrap_init
|
47
|
+
|
48
|
+
from jax.core import get_aval, Tracer
|
49
|
+
|
50
|
+
if jax.__version_info__ < (0, 5, 0):
|
51
|
+
from jax.lib.xla_client import Device
|
52
|
+
else:
|
53
|
+
from jax import Device
|
54
|
+
|
55
|
+
if jax.__version_info__ < (0, 4, 38):
|
56
|
+
from jax.core import ClosedJaxpr, extend_axis_env_nd, Primitive, jaxpr_as_fun
|
57
|
+
else:
|
58
|
+
from jax.extend.core import ClosedJaxpr, Primitive, jaxpr_as_fun
|
59
|
+
from jax.core import trace_ctx
|
60
|
+
|
61
|
+
|
62
|
+
@contextmanager
|
63
|
+
def extend_axis_env_nd(name_size_pairs: Iterable[tuple[Hashable, int]]):
|
64
|
+
prev = trace_ctx.axis_env
|
65
|
+
try:
|
66
|
+
trace_ctx.set_axis_env(prev.extend_pure(name_size_pairs))
|
67
|
+
yield
|
68
|
+
finally:
|
69
|
+
trace_ctx.set_axis_env(prev)
|
70
|
+
|
71
|
+
if jax.__version_info__ < (0, 6, 0):
|
72
|
+
from jax.util import safe_map, safe_zip, unzip2, wraps
|
73
|
+
|
74
|
+
else:
|
75
|
+
def safe_map(f, *args):
|
76
|
+
args = list(map(list, args))
|
77
|
+
n = len(args[0])
|
78
|
+
for arg in args[1:]:
|
79
|
+
assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
|
80
|
+
return list(map(f, *args))
|
81
|
+
|
82
|
+
|
83
|
+
def safe_zip(*args):
|
84
|
+
args = list(map(list, args))
|
85
|
+
n = len(args[0])
|
86
|
+
for arg in args[1:]:
|
87
|
+
assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
|
88
|
+
return list(zip(*args))
|
89
|
+
|
90
|
+
|
91
|
+
def unzip2(xys: Iterable[tuple[T1, T2]]) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
|
92
|
+
"""Unzip sequence of length-2 tuples into two tuples."""
|
93
|
+
# Note: we deliberately don't use zip(*xys) because it is lazily evaluated,
|
94
|
+
# is too permissive about inputs, and does not guarantee a length-2 output.
|
95
|
+
xs: list[T1] = []
|
96
|
+
ys: list[T2] = []
|
97
|
+
for x, y in xys:
|
98
|
+
xs.append(x)
|
99
|
+
ys.append(y)
|
100
|
+
return tuple(xs), tuple(ys)
|
101
|
+
|
102
|
+
|
103
|
+
def fun_name(fun: Callable):
|
104
|
+
name = getattr(fun, "__name__", None)
|
105
|
+
if name is not None:
|
106
|
+
return name
|
107
|
+
if isinstance(fun, partial):
|
108
|
+
return fun_name(fun.func)
|
109
|
+
else:
|
110
|
+
return "<unnamed function>"
|
111
|
+
|
112
|
+
|
113
|
+
def wraps(
|
114
|
+
wrapped: Callable,
|
115
|
+
namestr: str | None = None,
|
116
|
+
docstr: str | None = None,
|
117
|
+
**kwargs,
|
118
|
+
) -> Callable[[T], T]:
|
119
|
+
"""
|
120
|
+
Like functools.wraps, but with finer-grained control over the name and docstring
|
121
|
+
of the resulting function.
|
122
|
+
"""
|
123
|
+
|
124
|
+
def wrapper(fun: T) -> T:
|
125
|
+
try:
|
126
|
+
name = fun_name(wrapped)
|
127
|
+
doc = getattr(wrapped, "__doc__", "") or ""
|
128
|
+
fun.__dict__.update(getattr(wrapped, "__dict__", {}))
|
129
|
+
fun.__annotations__ = getattr(wrapped, "__annotations__", {})
|
130
|
+
fun.__name__ = name if namestr is None else namestr.format(fun=name)
|
131
|
+
fun.__module__ = getattr(wrapped, "__module__", "<unknown module>")
|
132
|
+
fun.__doc__ = (doc if docstr is None
|
133
|
+
else docstr.format(fun=name, doc=doc, **kwargs))
|
134
|
+
fun.__qualname__ = getattr(wrapped, "__qualname__", fun.__name__)
|
135
|
+
fun.__wrapped__ = wrapped
|
136
|
+
except Exception:
|
137
|
+
pass
|
138
|
+
return fun
|
139
|
+
|
140
|
+
return wrapper
|
141
|
+
|
142
|
+
|
143
|
+
def to_concrete_aval(aval):
|
144
|
+
aval = get_aval(aval)
|
145
|
+
if isinstance(aval, Tracer):
|
146
|
+
return aval.to_concrete_value()
|
147
|
+
return aval
|
148
|
+
|