brainstate 0.1.7__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -146
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +509 -470
- brainstate/nn/_delay_test.py +238 -0
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1361
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1120
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -208
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
- brainstate-0.1.9.dist-info/RECORD +130 -0
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.7.dist-info/RECORD +0 -131
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
brainstate/augment/_random.py
CHANGED
@@ -1,151 +1,151 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import functools
|
17
|
-
from typing import Callable, Sequence, Union
|
18
|
-
|
19
|
-
from brainstate.random import DEFAULT, RandomState
|
20
|
-
from brainstate.typing import Missing
|
21
|
-
from brainstate.util import PrettyObject
|
22
|
-
|
23
|
-
__all__ = [
|
24
|
-
'restore_rngs'
|
25
|
-
]
|
26
|
-
|
27
|
-
|
28
|
-
class RngRestore(PrettyObject):
    """
    Snapshot and roll back the keys of several RandomState objects.

    Calling ``backup()`` records the current ``value`` of every managed
    RandomState; a later ``restore()`` writes those recorded values back via
    ``restore_value`` and then empties the snapshot.

    Attributes:
        rngs (Sequence[RandomState]): The RandomState objects being managed.
        rng_keys (list): Values captured by the most recent ``backup()``;
            emptied again by ``restore()``.
    """

    def __init__(self, rngs: Sequence[RandomState]):
        """
        Create a restorer for the given random states.

        Args:
            rngs (Sequence[RandomState]): The RandomState objects whose keys
                will be saved and restored. The sequence is stored as given
                (it is not copied).
        """
        self.rngs: Sequence[RandomState] = rngs
        # Stays empty until ``backup()`` is called.
        self.rng_keys = []

    def backup(self):
        """
        Record the current ``value`` of every managed RandomState.

        Any previously recorded snapshot is replaced by a new list.
        """
        snapshot = []
        for state in self.rngs:
            snapshot.append(state.value)
        self.rng_keys = snapshot

    def restore(self):
        """
        Write the recorded values back into the managed RandomState objects.

        Each state receives its saved value through ``restore_value``; the
        snapshot list is then cleared in place.

        Raises:
            ValueError: If the number of recorded values differs from the
                number of managed RandomState objects (e.g. ``backup()`` was
                never called, or ``restore()`` was already called).
        """
        n_saved, n_states = len(self.rng_keys), len(self.rngs)
        if n_saved != n_states:
            raise ValueError('The number of random keys does not match the number of random states.')
        for state, saved_key in zip(self.rngs, self.rng_keys):
            state.restore_value(saved_key)
        # Clear in place so the same snapshot cannot be applied twice.
        self.rng_keys.clear()
|
77
|
-
|
78
|
-
|
79
|
-
def _rng_backup(
    fn: Callable,
    rngs: Union[RandomState, Sequence[RandomState]]
) -> Callable:
    """
    Wrap ``fn`` so the keys of ``rngs`` are saved before each call and
    restored afterwards.

    Args:
        fn: The function to wrap.
        rngs: A single RandomState or a sequence of RandomState instances.

    Returns:
        A wrapper carrying ``fn``'s metadata (via ``functools.wraps``) that
        backs up the random keys, calls ``fn``, and restores the keys even if
        ``fn`` raises.
    """
    # Accept a lone RandomState, as the annotation advertises.
    if isinstance(rngs, RandomState):
        rngs = [rngs]

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # A fresh restorer per call keeps the saved keys local to this
        # invocation: with a single shared restorer, a recursive/nested call
        # of the same wrapped function would clear the outer call's snapshot
        # and make the outer restore() raise ValueError.
        rng_restorer = RngRestore(rngs)
        # backup the random state
        rng_restorer.backup()
        try:
            # call the function
            return fn(*args, **kwargs)
        finally:
            # restore the random state, also when ``fn`` raises
            rng_restorer.restore()

    return wrapper
|
96
|
-
|
97
|
-
|
98
|
-
def restore_rngs(
    fn: Callable = Missing(),
    rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
) -> Callable:
    """
    Decorator to backup and restore the random state before and after a function call.

    This function can be used as a decorator or called directly. It ensures that the
    random state of the specified RandomState instances is preserved across function calls,
    which is useful for maintaining reproducibility in stochastic operations.

    Parameters
    ----------
    fn : Callable, optional
        The function to be wrapped. If not provided, the decorator can be used
        with parameters.
    rngs : Union[RandomState, Sequence[RandomState]], optional
        The random state(s) to be backed up and restored. This can be a single
        RandomState instance or a sequence of RandomState instances. If not provided,
        the default RandomState instance will be used. A sequence is copied at
        decoration time, so later changes to it do not affect the wrapper.

    Returns
    -------
    Callable
        If `fn` is provided, returns the wrapped function that will backup the
        random state before execution and restore it afterwards.
        If `fn` is not provided, returns a partial function that can be used as
        a decorator with the specified `rngs`.

    Raises
    ------
    AssertionError
        If `rngs` is not a RandomState instance or a sequence of RandomState instances.
        The check is performed even when Python runs with ``-O``.

    Examples
    --------
    >>> @restore_rngs
    ... def my_random_function():
    ...     return random.random()

    >>> rng = RandomState(42)
    >>> @restore_rngs(rngs=rng)
    ... def another_random_function():
    ...     return rng.random()
    """
    if isinstance(fn, Missing):
        # Used as ``@restore_rngs(rngs=...)``: defer until the function arrives.
        return functools.partial(restore_rngs, rngs=rngs)

    msg = 'rngs must be a RandomState or a sequence of RandomState instances.'
    if isinstance(rngs, RandomState):
        rngs = [rngs]
    # Explicit raises instead of ``assert`` so validation survives ``python -O``;
    # AssertionError is kept because it is the documented exception type.
    if not isinstance(rngs, Sequence):
        raise AssertionError(msg)
    # Snapshot the validated sequence so that mutating a caller-owned list
    # later cannot smuggle non-RandomState items past the check below.
    rngs = tuple(rngs)
    if not all(isinstance(rng, RandomState) for rng in rngs):
        raise AssertionError(msg)
    return _rng_backup(fn, rngs=rngs)
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import functools
|
17
|
+
from typing import Callable, Sequence, Union
|
18
|
+
|
19
|
+
from brainstate.random import DEFAULT, RandomState
|
20
|
+
from brainstate.typing import Missing
|
21
|
+
from brainstate.util import PrettyObject
|
22
|
+
|
23
|
+
__all__ = [
|
24
|
+
'restore_rngs'
|
25
|
+
]
|
26
|
+
|
27
|
+
|
28
|
+
class RngRestore(PrettyObject):
    """
    Snapshot and roll back the keys of several RandomState objects.

    Calling ``backup()`` records the current ``value`` of every managed
    RandomState; a later ``restore()`` writes those recorded values back via
    ``restore_value`` and then empties the snapshot.

    Attributes:
        rngs (Sequence[RandomState]): The RandomState objects being managed.
        rng_keys (list): Values captured by the most recent ``backup()``;
            emptied again by ``restore()``.
    """

    def __init__(self, rngs: Sequence[RandomState]):
        """
        Create a restorer for the given random states.

        Args:
            rngs (Sequence[RandomState]): The RandomState objects whose keys
                will be saved and restored. The sequence is stored as given
                (it is not copied).
        """
        self.rngs: Sequence[RandomState] = rngs
        # Stays empty until ``backup()`` is called.
        self.rng_keys = []

    def backup(self):
        """
        Record the current ``value`` of every managed RandomState.

        Any previously recorded snapshot is replaced by a new list.
        """
        snapshot = []
        for state in self.rngs:
            snapshot.append(state.value)
        self.rng_keys = snapshot

    def restore(self):
        """
        Write the recorded values back into the managed RandomState objects.

        Each state receives its saved value through ``restore_value``; the
        snapshot list is then cleared in place.

        Raises:
            ValueError: If the number of recorded values differs from the
                number of managed RandomState objects (e.g. ``backup()`` was
                never called, or ``restore()`` was already called).
        """
        n_saved, n_states = len(self.rng_keys), len(self.rngs)
        if n_saved != n_states:
            raise ValueError('The number of random keys does not match the number of random states.')
        for state, saved_key in zip(self.rngs, self.rng_keys):
            state.restore_value(saved_key)
        # Clear in place so the same snapshot cannot be applied twice.
        self.rng_keys.clear()
|
77
|
+
|
78
|
+
|
79
|
+
def _rng_backup(
    fn: Callable,
    rngs: Union[RandomState, Sequence[RandomState]]
) -> Callable:
    """
    Wrap ``fn`` so the keys of ``rngs`` are saved before each call and
    restored afterwards.

    Args:
        fn: The function to wrap.
        rngs: A single RandomState or a sequence of RandomState instances.

    Returns:
        A wrapper carrying ``fn``'s metadata (via ``functools.wraps``) that
        backs up the random keys, calls ``fn``, and restores the keys even if
        ``fn`` raises.
    """
    # Accept a lone RandomState, as the annotation advertises.
    if isinstance(rngs, RandomState):
        rngs = [rngs]

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # A fresh restorer per call keeps the saved keys local to this
        # invocation: with a single shared restorer, a recursive/nested call
        # of the same wrapped function would clear the outer call's snapshot
        # and make the outer restore() raise ValueError.
        rng_restorer = RngRestore(rngs)
        # backup the random state
        rng_restorer.backup()
        try:
            # call the function
            return fn(*args, **kwargs)
        finally:
            # restore the random state, also when ``fn`` raises
            rng_restorer.restore()

    return wrapper
|
96
|
+
|
97
|
+
|
98
|
+
def restore_rngs(
    fn: Callable = Missing(),
    rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
) -> Callable:
    """
    Decorator to backup and restore the random state before and after a function call.

    This function can be used as a decorator or called directly. It ensures that the
    random state of the specified RandomState instances is preserved across function calls,
    which is useful for maintaining reproducibility in stochastic operations.

    Parameters
    ----------
    fn : Callable, optional
        The function to be wrapped. If not provided, the decorator can be used
        with parameters.
    rngs : Union[RandomState, Sequence[RandomState]], optional
        The random state(s) to be backed up and restored. This can be a single
        RandomState instance or a sequence of RandomState instances. If not provided,
        the default RandomState instance will be used. A sequence is copied at
        decoration time, so later changes to it do not affect the wrapper.

    Returns
    -------
    Callable
        If `fn` is provided, returns the wrapped function that will backup the
        random state before execution and restore it afterwards.
        If `fn` is not provided, returns a partial function that can be used as
        a decorator with the specified `rngs`.

    Raises
    ------
    AssertionError
        If `rngs` is not a RandomState instance or a sequence of RandomState instances.
        The check is performed even when Python runs with ``-O``.

    Examples
    --------
    >>> @restore_rngs
    ... def my_random_function():
    ...     return random.random()

    >>> rng = RandomState(42)
    >>> @restore_rngs(rngs=rng)
    ... def another_random_function():
    ...     return rng.random()
    """
    if isinstance(fn, Missing):
        # Used as ``@restore_rngs(rngs=...)``: defer until the function arrives.
        return functools.partial(restore_rngs, rngs=rngs)

    msg = 'rngs must be a RandomState or a sequence of RandomState instances.'
    if isinstance(rngs, RandomState):
        rngs = [rngs]
    # Explicit raises instead of ``assert`` so validation survives ``python -O``;
    # AssertionError is kept because it is the documented exception type.
    if not isinstance(rngs, Sequence):
        raise AssertionError(msg)
    # Snapshot the validated sequence so that mutating a caller-owned list
    # later cannot smuggle non-RandomState items past the check below.
    rngs = tuple(rngs)
    if not all(isinstance(rng, RandomState) for rng in rngs):
        raise AssertionError(msg)
    return _rng_backup(fn, rngs=rngs)
|
brainstate/compile/__init__.py
CHANGED
@@ -1,38 +1,38 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
"""
|
17
|
-
This module contains the functions for the compilation of JAX code.
|
18
|
-
"""
|
19
|
-
|
20
|
-
from ._ad_checkpoint import checkpoint, remat
|
21
|
-
from ._conditions import cond, switch, ifelse
|
22
|
-
from ._error_if import jit_error_if
|
23
|
-
from ._jit import jit
|
24
|
-
from ._loop_collect_return import scan, checkpointed_scan, for_loop, checkpointed_for_loop
|
25
|
-
from ._loop_no_collection import while_loop, bounded_while_loop
|
26
|
-
from ._make_jaxpr import StatefulFunction, make_jaxpr
|
27
|
-
from ._progress_bar import ProgressBar
|
28
|
-
|
29
|
-
__all__ = [
|
30
|
-
'checkpoint', 'remat',
|
31
|
-
'cond', 'switch', 'ifelse',
|
32
|
-
'jit_error_if',
|
33
|
-
'jit',
|
34
|
-
'scan', 'checkpointed_scan', 'for_loop', 'checkpointed_for_loop',
|
35
|
-
'while_loop', 'bounded_while_loop',
|
36
|
-
'StatefulFunction', 'make_jaxpr',
|
37
|
-
'ProgressBar',
|
38
|
-
]
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""
|
17
|
+
This module contains the functions for the compilation of JAX code.
|
18
|
+
"""
|
19
|
+
|
20
|
+
from ._ad_checkpoint import checkpoint, remat
|
21
|
+
from ._conditions import cond, switch, ifelse
|
22
|
+
from ._error_if import jit_error_if
|
23
|
+
from ._jit import jit
|
24
|
+
from ._loop_collect_return import scan, checkpointed_scan, for_loop, checkpointed_for_loop
|
25
|
+
from ._loop_no_collection import while_loop, bounded_while_loop
|
26
|
+
from ._make_jaxpr import StatefulFunction, make_jaxpr
|
27
|
+
from ._progress_bar import ProgressBar
|
28
|
+
|
29
|
+
__all__ = [
|
30
|
+
'checkpoint', 'remat',
|
31
|
+
'cond', 'switch', 'ifelse',
|
32
|
+
'jit_error_if',
|
33
|
+
'jit',
|
34
|
+
'scan', 'checkpointed_scan', 'for_loop', 'checkpointed_for_loop',
|
35
|
+
'while_loop', 'bounded_while_loop',
|
36
|
+
'StatefulFunction', 'make_jaxpr',
|
37
|
+
'ProgressBar',
|
38
|
+
]
|