imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
- imt_ring-1.2.1.dist-info/RECORD +83 -0
- imt_ring-1.2.1.dist-info/WHEEL +5 -0
- imt_ring-1.2.1.dist-info/top_level.txt +1 -0
- ring/__init__.py +63 -0
- ring/algebra.py +100 -0
- ring/algorithms/__init__.py +45 -0
- ring/algorithms/_random.py +403 -0
- ring/algorithms/custom_joints/__init__.py +6 -0
- ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
- ring/algorithms/custom_joints/rr_joint.py +33 -0
- ring/algorithms/custom_joints/suntay.py +424 -0
- ring/algorithms/dynamics.py +345 -0
- ring/algorithms/generator/__init__.py +25 -0
- ring/algorithms/generator/base.py +414 -0
- ring/algorithms/generator/batch.py +282 -0
- ring/algorithms/generator/motion_artifacts.py +222 -0
- ring/algorithms/generator/pd_control.py +182 -0
- ring/algorithms/generator/randomize.py +119 -0
- ring/algorithms/generator/transforms.py +410 -0
- ring/algorithms/generator/types.py +36 -0
- ring/algorithms/jcalc.py +840 -0
- ring/algorithms/kinematics.py +202 -0
- ring/algorithms/sensors.py +582 -0
- ring/base.py +1046 -0
- ring/io/__init__.py +9 -0
- ring/io/examples/branched.xml +24 -0
- ring/io/examples/exclude/knee_trans_dof.xml +26 -0
- ring/io/examples/exclude/standard_sys.xml +106 -0
- ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
- ring/io/examples/inv_pendulum.xml +14 -0
- ring/io/examples/knee_flexible_imus.xml +22 -0
- ring/io/examples/spherical_stiff.xml +11 -0
- ring/io/examples/symmetric.xml +12 -0
- ring/io/examples/test_all_1.xml +39 -0
- ring/io/examples/test_all_2.xml +39 -0
- ring/io/examples/test_ang0_pos0.xml +9 -0
- ring/io/examples/test_control.xml +16 -0
- ring/io/examples/test_double_pendulum.xml +14 -0
- ring/io/examples/test_free.xml +11 -0
- ring/io/examples/test_kinematics.xml +23 -0
- ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
- ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
- ring/io/examples/test_randomize_position.xml +26 -0
- ring/io/examples/test_sensors.xml +13 -0
- ring/io/examples/test_three_seg_seg2.xml +23 -0
- ring/io/examples.py +42 -0
- ring/io/test_examples.py +6 -0
- ring/io/xml/__init__.py +6 -0
- ring/io/xml/abstract.py +300 -0
- ring/io/xml/from_xml.py +299 -0
- ring/io/xml/test_from_xml.py +56 -0
- ring/io/xml/test_to_xml.py +31 -0
- ring/io/xml/to_xml.py +94 -0
- ring/maths.py +397 -0
- ring/ml/__init__.py +33 -0
- ring/ml/base.py +292 -0
- ring/ml/callbacks.py +434 -0
- ring/ml/ml_utils.py +272 -0
- ring/ml/optimizer.py +149 -0
- ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- ring/ml/ringnet.py +279 -0
- ring/ml/train.py +318 -0
- ring/ml/training_loop.py +131 -0
- ring/rendering/__init__.py +2 -0
- ring/rendering/base_render.py +271 -0
- ring/rendering/mujoco_render.py +222 -0
- ring/rendering/vispy_render.py +340 -0
- ring/rendering/vispy_visuals.py +290 -0
- ring/sim2real/__init__.py +7 -0
- ring/sim2real/sim2real.py +288 -0
- ring/spatial.py +126 -0
- ring/sys_composer/__init__.py +5 -0
- ring/sys_composer/delete_sys.py +114 -0
- ring/sys_composer/inject_sys.py +110 -0
- ring/sys_composer/morph_sys.py +361 -0
- ring/utils/__init__.py +21 -0
- ring/utils/batchsize.py +51 -0
- ring/utils/colab.py +48 -0
- ring/utils/hdf5.py +198 -0
- ring/utils/normalizer.py +56 -0
- ring/utils/path.py +44 -0
- ring/utils/utils.py +161 -0
ring/ml/optimizer.py
ADDED
@@ -0,0 +1,149 @@
|
|
1
|
+
from typing import Any, NamedTuple, Optional
|
2
|
+
|
3
|
+
import jax
|
4
|
+
from jax import lax
|
5
|
+
import jax.numpy as jnp
|
6
|
+
from jax.tree_util import tree_map
|
7
|
+
import optax
|
8
|
+
from optax._src import base
|
9
|
+
from optax._src import numerics
|
10
|
+
from optax._src.transform import add_noise
|
11
|
+
from optax._src.transform import AddNoiseState
|
12
|
+
|
13
|
+
|
14
|
+
def make_optimizer(
    lr: float,
    n_episodes: int,
    n_steps_per_episode: int,
    adap_clip: Optional[float] = 0.1,
    glob_clip: Optional[float] = 0.2,
    skip_large_update_max_normsq: float = 5.0,
    skip_large_update_warmup: int = 300,
    inner_opt=optax.lamb,
    cos_decay_twice: bool = False,
    scale_grads: Optional[float] = None,
    skip_large_update_max_consecutive: int = 6 * 25,
    **inner_opt_kwargs,
):
    """Build the training optimizer.

    The chain is: optional gradient scaling -> optional adaptive gradient
    clipping -> optional global-norm clipping -> ``inner_opt`` driven by a
    cosine-decay learning-rate schedule. The whole chain is wrapped in
    ``skip_large_update``, which rejects updates whose squared norm is too large
    (or non-finite).

    Args:
        lr: Peak learning rate of the cosine-decay schedule.
        n_episodes: Number of training episodes.
        n_steps_per_episode: Optimizer steps per episode. The schedule length is
            ``n_episodes * n_steps_per_episode``.
        adap_clip: Clipping factor for ``optax.adaptive_grad_clip``. ``None``
            disables adaptive clipping.
        glob_clip: Maximum global gradient norm for
            ``optax.clip_by_global_norm``. ``None`` disables global-norm
            clipping.
        skip_large_update_max_normsq: Squared-norm threshold passed to
            ``skip_large_update``.
        skip_large_update_warmup: Number of initial steps during which no
            update is skipped.
        inner_opt: Optax optimizer factory, called as
            ``inner_opt(schedule, **inner_opt_kwargs)``.
        cos_decay_twice: If True, use two back-to-back cosine decays over half
            the steps each: first from ``lr`` to ``lr * 1e-2``, then from
            ``lr * 1e-2`` towards 0. Otherwise use a single decay from ``lr``
            to ``lr * 1e-7``.
        scale_grads: If given, gradients are first multiplied by this factor.
        skip_large_update_max_consecutive: After more than this many
            consecutive too-large updates, updates are applied anyway.
        **inner_opt_kwargs: Extra keyword arguments forwarded to ``inner_opt``.

    Returns:
        The optax gradient transformation produced by ``skip_large_update``.
    """
    steps = n_steps_per_episode * n_episodes

    if cos_decay_twice:
        half_steps = int(steps / 2)
        schedule = optax.join_schedules(
            [
                optax.cosine_decay_schedule(lr, half_steps, 1e-2),
                optax.cosine_decay_schedule(lr * 1e-2, half_steps),
            ],
            [half_steps],
        )
    else:
        schedule = optax.cosine_decay_schedule(lr, steps, 1e-7)

    optimizer = optax.chain(
        (
            # With flip_sign=False this simply multiplies updates by `scale_grads`.
            optax.scale_by_learning_rate(scale_grads, flip_sign=False)
            if scale_grads is not None
            else optax.identity()
        ),
        (
            optax.adaptive_grad_clip(adap_clip)
            if adap_clip is not None
            else optax.identity()
        ),
        (
            # Use the caller-supplied threshold (this was hard-coded to 0.2).
            optax.clip_by_global_norm(glob_clip)
            if glob_clip is not None
            else optax.identity()
        ),
        inner_opt(schedule, **inner_opt_kwargs),
    )
    return skip_large_update(
        optimizer,
        skip_large_update_max_normsq,
        max_consecutive_toolarge=skip_large_update_max_consecutive,
        warmup=skip_large_update_warmup,
    )
|
61
|
+
|
62
|
+
|
63
|
+
class SkipIfLargeUpdatesState(NamedTuple):
    """Optimizer state carried by ``skip_large_update``."""

    # Number of consecutive steps whose incoming updates were too large; reset
    # to 0 whenever an incoming update is within the threshold.
    toolarge_count: jax.Array
    # Total number of `update` calls so far; compared against `warmup`.
    count: jax.Array
    # State of the wrapped (inner) gradient transformation.
    inner_state: Any
    # State of the `add_noise` transformation used to disturb skipped steps.
    add_noise_state: AddNoiseState
|
68
|
+
|
69
|
+
|
70
|
+
def _condition_not_toolarge(updates: base.Updates, max_norm_sq: float):
    """Return whether the squared global L2 norm of `updates` is below `max_norm_sq`.

    The comparison is strict and evaluates to False when the squared norm is
    NaN or Inf, so non-finite updates are reported as "too large".
    """
    leaf_sq_norms = jax.tree_util.tree_map(lambda leaf: jnp.sum(leaf**2), updates)
    total_sq_norm = jnp.sum(jnp.array(jax.tree_util.tree_leaves(leaf_sq_norms)))
    return total_sq_norm < max_norm_sq
|
76
|
+
|
77
|
+
|
78
|
+
def skip_large_update(
    inner: base.GradientTransformation,
    max_norm_sq: float,
    max_consecutive_toolarge: int,
    warmup: int = 0,
    disturb_if_skip: bool = False,
    disturb_adaptive: bool = False,
    eta: float = 0.01,
    gamma: float = 0.55,
    seed: int = 0,
) -> base.GradientTransformation:
    """Wrap `inner` so that too-large (and NaN/Inf) updates are skipped.

    An incoming update is passed through `inner` when any of these holds:
      * its squared global norm is strictly below `max_norm_sq`,
      * more than `max_consecutive_toolarge` consecutive updates (including
        this one) were too large, or
      * fewer than `warmup` updates have been processed so far.

    Otherwise the step is rejected. The inner state is left unchanged and the
    returned update is all zeros. If `disturb_if_skip` is set, the returned
    update is instead pure noise from `optax.add_noise(eta, gamma, seed)`.
    Because the norm test is False for NaN/Inf, non-finite updates are skipped
    too.

    Args:
        inner: Gradient transformation to guard.
        max_norm_sq: Squared-norm threshold.
        max_consecutive_toolarge: Number of consecutive rejections after which
            updates are applied regardless of their size.
        warmup: Number of initial steps during which nothing is skipped.
        disturb_if_skip: Return a noise update instead of zeros on rejection.
        disturb_adaptive: Not implemented; must be False.
        eta: `add_noise` noise scale.
        gamma: `add_noise` decay exponent.
        seed: `add_noise` random seed.

    Returns:
        A `GradientTransformationExtraArgs` whose state is
        `SkipIfLargeUpdatesState`.

    Raises:
        NotImplementedError: If `disturb_adaptive` is True.
    """
    inner = base.with_extra_args_support(inner)

    if disturb_adaptive:
        raise NotImplementedError

    add_noise_transform = add_noise(eta, gamma, seed)

    def init(params):
        return SkipIfLargeUpdatesState(
            toolarge_count=jnp.zeros([], jnp.int32),
            count=jnp.zeros([], jnp.int32),
            inner_state=inner.init(params),
            add_noise_state=add_noise_transform.init(params),
        )

    def update(updates, state: SkipIfLargeUpdatesState, params=None, **extra_args):
        inner_state = state.inner_state
        not_toolarge = _condition_not_toolarge(updates, max_norm_sq)
        toolarge_count = jnp.where(
            not_toolarge,
            jnp.zeros([], jnp.int32),
            numerics.safe_int32_increment(state.toolarge_count),
        )

        def do_update(updates):
            updates, new_inner_state = inner.update(
                updates, inner_state, params, **extra_args
            )
            return updates, new_inner_state, state.add_noise_state

        def reject_update(updates):
            zeros = tree_map(jnp.zeros_like, updates)
            if disturb_if_skip:
                # Perturb starting from a zero update. The rejected `updates` are
                # too large or non-finite by construction, so adding noise to
                # them would re-apply exactly what this wrapper must block.
                new_updates, new_add_noise_state = add_noise_transform.update(
                    zeros, state.add_noise_state, params
                )
            else:
                new_updates, new_add_noise_state = zeros, state.add_noise_state
            return new_updates, inner_state, new_add_noise_state

        apply_update = jnp.logical_or(
            jnp.logical_or(not_toolarge, toolarge_count > max_consecutive_toolarge),
            state.count < warmup,
        )
        updates, new_inner_state, new_add_noise_state = lax.cond(
            apply_update,
            do_update,
            reject_update,
            updates,
        )

        return updates, SkipIfLargeUpdatesState(
            toolarge_count=toolarge_count,
            count=numerics.safe_int32_increment(state.count),
            inner_state=new_inner_state,
            add_noise_state=new_add_noise_state,
        )

    return base.GradientTransformationExtraArgs(init=init, update=update)
|
Binary file
|
ring/ml/ringnet.py
ADDED
@@ -0,0 +1,279 @@
|
|
1
|
+
from functools import partial
|
2
|
+
from pathlib import Path
|
3
|
+
from types import SimpleNamespace
|
4
|
+
from typing import Callable, Optional
|
5
|
+
|
6
|
+
import haiku as hk
|
7
|
+
import jax
|
8
|
+
import jax.numpy as jnp
|
9
|
+
import tree_utils
|
10
|
+
|
11
|
+
from ring.maths import safe_normalize
|
12
|
+
from ring.ml import base as ml_base
|
13
|
+
from ring.utils import pickle_load
|
14
|
+
|
15
|
+
|
16
|
+
def _scan_sys(lam: list[int], f):
    """Evaluate ``f(link, parent)`` for every link and batch the results.

    ``parent`` is ``lam[link]``. The per-link results are combined with
    ``tree_utils.tree_batch(..., backend="jax")``, which presumably stacks them
    along a new leading axis (TODO confirm).
    """
    per_link = [f(link, parent) for link, parent in enumerate(lam)]
    return tree_utils.tree_batch(per_link, backend="jax")
|
21
|
+
|
22
|
+
|
23
|
+
def _make_rnno_cell_apply_fn(
    lam: list[int],
    inner_cell,
    send_msg,
    send_output,
    hidden_state_dim,
    message_dim,
    output_transform: Callable,
):
    """Build the per-timestep core that `make_ring` unrolls with `hk.dynamic_unroll`.

    At each timestep, every link sends a message computed from its last hidden
    state (taken from the previous state). Every link then receives its parent's
    message and the sum of its children's messages. These are concatenated with
    the link's own input and fed through `inner_cell`, with all links processed
    in parallel via `jax.vmap`.

    Args:
        lam: Parent index of every link (`lam[i]` is the parent of link `i`).
            A parent index of -1 selects the trailing all-zero message in `msg`.
            Presumably this is how a root link is encoded (TODO confirm against
            callers).
        inner_cell: Recurrent cell applied per link as
            `inner_cell(cell_input, state) -> (cell_output, state)`.
        send_msg: Maps one link's last hidden state to its outgoing message.
        send_output: Maps a cell output to the raw per-link output.
        hidden_state_dim: Size of the hidden state `h` at the front of each
            per-cell state.
        message_dim: Size of each message vector.
        output_transform: Applied to `send_output`'s result.

    Returns:
        A function `(inputs, prev_state) -> (y, state)`. `inputs[i]` is link
        `i`'s input for this timestep, and `prev_state` is indexed as
        `[link, stacked_cell, state_feature]`.
    """
    N = len(lam)
    parent_array = jnp.array(lam, dtype=jnp.int32)

    def _rnno_cell_apply_fn(inputs, prev_state):
        # A single all-zero message: repeated to initialise the mailbox, and
        # appended to `msg` so that a parent index of -1 receives zeros.
        empty_message = jnp.zeros((1, message_dim))
        mailbox = jnp.repeat(empty_message, N, axis=0)

        # message is sent using the hidden state of the last cell
        # for LSTM `prev_state` is of shape (2 * hidden_state_dim) due to cell state
        prev_last_hidden_state = prev_state[:, -1, :hidden_state_dim]

        # Shape (N + 1, message_dim): one message per link plus a trailing
        # zero message, so `msg[-1]` is all zeros.
        msg = jnp.concatenate(
            (jax.vmap(send_msg)(prev_last_hidden_state), empty_message)
        )

        def accumulate_message(link):
            # Sum of the messages of all links whose parent is `link`, i.e. the
            # children of `link`. Non-children contribute the zero rows of `mailbox`.
            return jnp.sum(
                jnp.where(
                    jnp.repeat((parent_array == link)[:, None], message_dim, axis=-1),
                    msg[:-1],
                    mailbox,
                ),
                axis=0,
            )

        mailbox = jax.vmap(accumulate_message)(jnp.arange(N))

        def cell_input(i: int, p: int):
            # Link i's input, its parent's message, and its children's summed messages.
            local_input = inputs[i]
            local_cell_input = tree_utils.batch_concat_acme(
                (local_input, msg[p], mailbox[i]), num_batch_dims=0
            )
            return local_cell_input

        stacked_cell_input = _scan_sys(lam, cell_input)

        def update_state(cell_input, state):
            cell_output, state = inner_cell(cell_input, state)
            output = output_transform(send_output(cell_output))
            return output, state

        # All links are processed in parallel, sharing the same cell parameters.
        y, state = jax.vmap(update_state)(stacked_cell_input, prev_state)
        return y, state

    return _rnno_cell_apply_fn
|
77
|
+
|
78
|
+
|
79
|
+
def make_ring(
    lam: list[int],
    hidden_state_dim: int = 400,
    message_dim: int = 200,
    celltype: str = "gru",
    stack_rnn_cells: int = 2,
    send_message_n_layers: int = 1,
    link_output_dim: int = 4,
    link_output_normalize: bool = True,
    link_output_transform: Optional[Callable] = None,
    layernorm: bool = True,
) -> hk.TransformedWithState:
    """Build the Haiku-transformed RING network for the parent array `lam`.

    Args:
        lam: Parent index of every link (`lam[i]` is the parent of link `i`).
        hidden_state_dim: Hidden size of the recurrent cells and of the MLPs.
        message_dim: Size of the messages exchanged between links.
        celltype: Recurrent cell type, `"gru"` or `"lstm"`.
        stack_rnn_cells: Number of stacked recurrent cells per link.
        send_message_n_layers: Number of hidden layers of the message MLP.
        link_output_dim: Size of each link's output.
        link_output_normalize: If True, outputs are passed through
            `safe_normalize`. In that case `link_output_transform` must be None.
        link_output_transform: Optional transform of each link's output, used
            only when `link_output_normalize` is False (identity if None).
        layernorm: Whether `StackedRNNCell` applies layer normalisation.

    Returns:
        A transform without apply RNG. It provides
        `init(key, X) -> (params, state)` and
        `apply(params, state, X) -> (y, state)`. `X` is unrolled over its
        leading axis by `hk.dynamic_unroll`; presumably it is
        (timesteps, links, features) (TODO confirm).
    """

    if link_output_normalize:
        # Normalisation and a custom output transform are mutually exclusive.
        assert link_output_transform is None
        link_output_transform = safe_normalize
    else:
        if link_output_transform is None:
            link_output_transform = lambda x: x

    @hk.without_apply_rng
    @hk.transform_with_state
    def forward(X):
        # NOTE(review): Haiku derives parameter names from the order in which
        # modules are constructed. Reordering these constructors would likely
        # break loading of previously saved params, so keep the order stable.
        send_msg = hk.nets.MLP(
            [hidden_state_dim] * send_message_n_layers + [message_dim]
        )

        inner_cell = StackedRNNCell(
            celltype, hidden_state_dim, stack_rnn_cells, layernorm=layernorm
        )
        send_output = hk.nets.MLP([hidden_state_dim, link_output_dim])
        # Recurrent state of all links: (links, stacked cells, state size).
        # LSTM states pack both h and c, hence twice the hidden size.
        state = hk.get_state(
            "inner_cell_state",
            [
                len(lam),
                stack_rnn_cells,
                (hidden_state_dim * 2 if celltype == "lstm" else hidden_state_dim),
            ],
            init=jnp.zeros,
        )

        y, state = hk.dynamic_unroll(
            _make_rnno_cell_apply_fn(
                lam=lam,
                inner_cell=inner_cell,
                send_msg=send_msg,
                send_output=send_output,
                hidden_state_dim=hidden_state_dim,
                message_dim=message_dim,
                output_transform=link_output_transform,
            ),
            X,
            state,
        )
        # Persist the final recurrent state so the next `apply` continues from it.
        hk.set_state("inner_cell_state", state)
        return y

    return forward
|
137
|
+
|
138
|
+
|
139
|
+
class StackedRNNCell(hk.Module):
    """A stack of recurrent cells applied one after another.

    The output of cell ``k`` is the input of cell ``k + 1``. The per-cell
    states are indexed along the leading axis of ``state``, and the new states
    are returned stacked the same way. If ``layernorm`` is set, a new
    ``hk.LayerNorm`` is applied to the output after every cell.
    """

    def __init__(
        self,
        celltype: str,
        hidden_state_dim,
        stacks: int,
        layernorm: bool = False,
        name: str | None = None,
    ):
        super().__init__(name)
        cell_factories = {"gru": hk.GRU, "lstm": LSTM}
        # An unknown `celltype` raises KeyError here.
        make_cell = cell_factories[celltype]

        self.cells = [make_cell(hidden_state_dim) for _ in range(stacks)]
        self.layernorm = layernorm

    def __call__(self, x, state):
        hidden = x
        new_states = []
        for idx, rnn in enumerate(self.cells):
            hidden, cell_state = rnn(hidden, state[idx])
            new_states.append(cell_state)

            if self.layernorm:
                hidden = hk.LayerNorm(-1, True, True)(hidden)

        return hidden, jnp.stack(new_states)
|
165
|
+
|
166
|
+
|
167
|
+
class LSTM(hk.RNNCore):
    """LSTM cell whose state is a single array ``concat([h, c], axis=-1)``.

    Unlike ``hk.LSTM``, which uses a named-tuple state, the hidden state ``h``
    and the cell state ``c`` are packed into one array of size
    ``2 * hidden_size`` along the last axis.
    """

    def __init__(self, hidden_size: int, name=None):
        super().__init__(name=name)
        self.hidden_size = hidden_size

    def __call__(
        self,
        inputs: jax.Array,
        prev_state: jax.Array,
    ):
        """Advance the cell by one step.

        Args:
            inputs: Rank-1 ``(features,)`` or rank-2 ``(batch, features)`` input.
            prev_state: Packed ``[h, c]`` state with the same leading shape as
                ``inputs`` and last axis of size ``2 * hidden_size``.

        Returns:
            ``(h, new_state)``, where ``new_state`` is ``concat([h, c], axis=-1)``.

        Raises:
            ValueError: If ``inputs`` is a scalar or has rank greater than 2.
        """
        if len(inputs.shape) > 2 or not inputs.shape:
            raise ValueError("LSTM input must be rank-1 or rank-2.")
        # Slice and concatenate on the last axis so that rank-2 (batched)
        # states are handled correctly, not only rank-1 states.
        prev_state_h = prev_state[..., : self.hidden_size]
        prev_state_c = prev_state[..., self.hidden_size :]
        x_and_h = jnp.concatenate([inputs, prev_state_h], axis=-1)
        gated = hk.Linear(4 * self.hidden_size)(x_and_h)
        i, g, f, o = jnp.split(gated, indices_or_sections=4, axis=-1)
        f = jax.nn.sigmoid(f + 1)  # Forget bias, as in sonnet.
        c = f * prev_state_c + jax.nn.sigmoid(i) * jnp.tanh(g)
        h = jax.nn.sigmoid(o) * jnp.tanh(c)
        return h, jnp.concatenate((h, c), axis=-1)

    def initial_state(self, batch_size: int | None):
        """Return an all-zero packed ``[h, c]`` state, optionally batched."""
        shape = (2 * self.hidden_size,)
        if batch_size is not None:
            shape = (batch_size,) + shape
        return jnp.zeros(shape)
|
191
|
+
|
192
|
+
|
193
|
+
class RING(ml_base.AbstractFilter):
    """Filter wrapper around the Haiku network built by `make_ring`.

    Args:
        params: Network parameters. Either a dict, a path (`str`/`Path`) to a
            pickle loaded with `pickle_load`, or None.
        lam: Default parent array, used when `apply`/`init` receive no `lam`.
        jit: Whether to wrap `apply` in `jax.jit` (with `lam` static).
        name: Stored as `self._name`.
        **kwargs: Forwarded to `make_ring`.
    """

    def __init__(self, params=None, lam=None, jit: bool = True, name=None, **kwargs):
        self.forward_lam_factory = partial(make_ring, **kwargs)
        self.params = self._load_params(params)
        self.lam = lam
        self._name = name

        if jit:
            self.apply = jax.jit(self.apply, static_argnames="lam")

    def apply(self, X, params=None, state=None, y=None, lam=None):
        if lam is None:
            assert self.lam is not None
            lam = self.lam

        # `lam` is a static jit argument and must therefore be hashable.
        return super().apply(X, params, state, y, tuple(lam))

    def init(self, bs: Optional[int] = None, X=None, lam=None, seed: int = 1):
        """Initialise params and state from an example input `X`.

        A 4-D `X` is treated as batched, with the batch size on axis 0. Only
        the first sample and its first timestep are used for Haiku's `init`.
        When `bs` is known, the state is repeated `bs` times along a new
        leading axis.

        Returns:
            `(params, state)`.
        """
        assert X is not None, "Providing `X` via in `ringnet.init(X=X)` is required"
        if bs is not None:
            assert X.ndim == 4

        if X.ndim == 4:
            if bs is not None:
                assert bs == X.shape[0]
            else:
                bs = X.shape[0]
            X = X[0]

        # (T, N, F) -> (1, N, F) for faster .init call
        X = X[0:1]

        if lam is None:
            assert self.lam is not None
            lam = self.lam

        key = jax.random.PRNGKey(seed)
        params, state = self.forward_lam_factory(lam=lam).init(key, X)

        if bs is not None:
            # `jax.tree_map` is deprecated (and removed in newer JAX releases).
            state = jax.tree_util.tree_map(
                lambda arr: jnp.repeat(arr[None], bs, axis=0), state
            )

        return params, state

    def _apply_batched(self, X, params, state, y, lam):
        """Run the network vmapped over the leading (batch) axis of `X` and `state`.

        Missing params fall back to `self.params`, and then to freshly
        initialised params. A missing state is freshly initialised.
        """
        if (params is None and self.params is None) or state is None:
            _params, _state = self.init(bs=X.shape[0], X=X, lam=lam)

        if params is None and self.params is None:
            params = _params
        elif params is None:
            params = self.params

        if state is None:
            state = _state

        # Params are shared across the batch; state and inputs are per-sample.
        yhat, next_state = jax.vmap(
            self.forward_lam_factory(lam=lam).apply, in_axes=(None, 0, 0)
        )(params, state, X)

        return yhat, next_state

    @staticmethod
    def _load_params(params: str | dict | None | Path):
        """Return `params` unchanged, or load it with `pickle_load` if it is a path."""
        assert isinstance(params, (str, dict, type(None), Path))
        if isinstance(params, (Path, str)):
            return pickle_load(params)
        return params

    def nojit(self) -> "RING":
        """Return a non-jitted copy that shares params, `lam`, name and network factory."""
        ringnet = RING(params=self.params, lam=self.lam, jit=False, name=self._name)
        ringnet.forward_lam_factory = self.forward_lam_factory
        return ringnet

    def _pre_save(self, params=None, lam=None) -> None:
        """Store `params` and/or `lam` on the instance before it is saved."""
        if params is not None:
            self.params = params
        if lam is not None:
            self.lam = lam

    @staticmethod
    def _post_load(ringnet: "RING", jit: bool = True) -> "RING":
        """Re-apply `jax.jit` to `apply` after loading, if requested."""
        if jit:
            ringnet.apply = jax.jit(ringnet.apply, static_argnames="lam")
        return ringnet
|