imt_ring-1.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
- imt_ring-1.2.1.dist-info/RECORD +83 -0
- imt_ring-1.2.1.dist-info/WHEEL +5 -0
- imt_ring-1.2.1.dist-info/top_level.txt +1 -0
- ring/__init__.py +63 -0
- ring/algebra.py +100 -0
- ring/algorithms/__init__.py +45 -0
- ring/algorithms/_random.py +403 -0
- ring/algorithms/custom_joints/__init__.py +6 -0
- ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
- ring/algorithms/custom_joints/rr_joint.py +33 -0
- ring/algorithms/custom_joints/suntay.py +424 -0
- ring/algorithms/dynamics.py +345 -0
- ring/algorithms/generator/__init__.py +25 -0
- ring/algorithms/generator/base.py +414 -0
- ring/algorithms/generator/batch.py +282 -0
- ring/algorithms/generator/motion_artifacts.py +222 -0
- ring/algorithms/generator/pd_control.py +182 -0
- ring/algorithms/generator/randomize.py +119 -0
- ring/algorithms/generator/transforms.py +410 -0
- ring/algorithms/generator/types.py +36 -0
- ring/algorithms/jcalc.py +840 -0
- ring/algorithms/kinematics.py +202 -0
- ring/algorithms/sensors.py +582 -0
- ring/base.py +1046 -0
- ring/io/__init__.py +9 -0
- ring/io/examples/branched.xml +24 -0
- ring/io/examples/exclude/knee_trans_dof.xml +26 -0
- ring/io/examples/exclude/standard_sys.xml +106 -0
- ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
- ring/io/examples/inv_pendulum.xml +14 -0
- ring/io/examples/knee_flexible_imus.xml +22 -0
- ring/io/examples/spherical_stiff.xml +11 -0
- ring/io/examples/symmetric.xml +12 -0
- ring/io/examples/test_all_1.xml +39 -0
- ring/io/examples/test_all_2.xml +39 -0
- ring/io/examples/test_ang0_pos0.xml +9 -0
- ring/io/examples/test_control.xml +16 -0
- ring/io/examples/test_double_pendulum.xml +14 -0
- ring/io/examples/test_free.xml +11 -0
- ring/io/examples/test_kinematics.xml +23 -0
- ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
- ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
- ring/io/examples/test_randomize_position.xml +26 -0
- ring/io/examples/test_sensors.xml +13 -0
- ring/io/examples/test_three_seg_seg2.xml +23 -0
- ring/io/examples.py +42 -0
- ring/io/test_examples.py +6 -0
- ring/io/xml/__init__.py +6 -0
- ring/io/xml/abstract.py +300 -0
- ring/io/xml/from_xml.py +299 -0
- ring/io/xml/test_from_xml.py +56 -0
- ring/io/xml/test_to_xml.py +31 -0
- ring/io/xml/to_xml.py +94 -0
- ring/maths.py +397 -0
- ring/ml/__init__.py +33 -0
- ring/ml/base.py +292 -0
- ring/ml/callbacks.py +434 -0
- ring/ml/ml_utils.py +272 -0
- ring/ml/optimizer.py +149 -0
- ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- ring/ml/ringnet.py +279 -0
- ring/ml/train.py +318 -0
- ring/ml/training_loop.py +131 -0
- ring/rendering/__init__.py +2 -0
- ring/rendering/base_render.py +271 -0
- ring/rendering/mujoco_render.py +222 -0
- ring/rendering/vispy_render.py +340 -0
- ring/rendering/vispy_visuals.py +290 -0
- ring/sim2real/__init__.py +7 -0
- ring/sim2real/sim2real.py +288 -0
- ring/spatial.py +126 -0
- ring/sys_composer/__init__.py +5 -0
- ring/sys_composer/delete_sys.py +114 -0
- ring/sys_composer/inject_sys.py +110 -0
- ring/sys_composer/morph_sys.py +361 -0
- ring/utils/__init__.py +21 -0
- ring/utils/batchsize.py +51 -0
- ring/utils/colab.py +48 -0
- ring/utils/hdf5.py +198 -0
- ring/utils/normalizer.py +56 -0
- ring/utils/path.py +44 -0
- ring/utils/utils.py +161 -0
ring/ml/train.py
ADDED
@@ -0,0 +1,318 @@
from functools import partial
from pathlib import Path
from typing import Callable, Optional, Tuple

import jax
import jax.numpy as jnp
import optax
from ring import maths
from ring.algorithms.generator import types
from ring.ml import base as ml_base
from ring.ml import callbacks as ml_callbacks
from ring.ml import ml_utils
from ring.ml import training_loop
from ring.utils import distribute_batchsize
from ring.utils import expand_batchsize
from ring.utils import parse_path
from ring.utils import pickle_load
import tree_utils

import wandb

# (T, N, F) -> Scalar
LOSS_FN = Callable[[jax.Array, jax.Array], float]
_default_loss_fn = lambda q, qhat: maths.angle_error(q, qhat) ** 2

# reduces (batch_axis, time_axis) -> Scalar
ACCUMULATOR_FN = Callable[[jax.Array], float]
# Loss_fn here is: (F,) -> Scalar
METRICES = dict[str, Tuple[LOSS_FN, ACCUMULATOR_FN]]
_default_metrices = {
    "mae_deg": (
        lambda q, qhat: maths.angle_error(q, qhat),
        lambda arr: jnp.rad2deg(jnp.mean(arr, axis=(0, 1))),
    ),
}
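
# A hypothetical additional entry with the same (metric_fn, accumulator_fn)
# structure -- the name and reduction below are illustrative, not part of
# this package:
#
#     "rmse_deg": (
#         lambda q, qhat: maths.angle_error(q, qhat) ** 2,
#         lambda arr: jnp.rad2deg(jnp.sqrt(jnp.mean(arr, axis=(0, 1)))),
#     ),
#
# Such a dict is passed to `train_fn(..., metrices=...)` and consumed by
# `_build_eval_fn` below.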


def _build_step_fn(
    metric_fn: LOSS_FN,
    filter: ml_base.AbstractFilter,
    optimizer,
    tbp,
):
    """Build step function that optimizes filter parameters based on `metric_fn`.
    `initial_state` has shape (pmap, vmap, state_dim)"""

    @partial(jax.value_and_grad, has_aux=True)
    def loss_fn(params, state, X, y):
        yhat, state = filter.apply(params=params, state=state, X=X)
        # this vmap maps along batch-axis, not time-axis
        # time-axis is handled by `metric_fn`
        pipe = lambda q, qhat: jnp.mean(jax.vmap(metric_fn)(q, qhat))
        error_tree = jax.tree_map(pipe, y, yhat)
        return jnp.mean(tree_utils.batch_concat(error_tree, 0)), state

    @partial(
        jax.pmap,
        in_axes=(None, 0, 0, 0),
        out_axes=((None, 0), None),
        axis_name="devices",
    )
    def pmapped_loss_fn(params, state, X, y):
        pmean = lambda arr: jax.lax.pmean(arr, axis_name="devices")
        (loss, state), grads = loss_fn(params, state, X, y)
        return (pmean(loss), state), pmean(grads)

    @jax.jit
    def apply_grads(grads, params, opt_state):
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state

    initial_state = None

    def step_fn(params, opt_state, X, y):
        assert X.ndim == y.ndim == 4
        B, T, N, F = X.shape
        pmap_size, vmap_size = distribute_batchsize(B)

        nonlocal initial_state
        if initial_state is None:
            initial_state = expand_batchsize(filter.init(B, X)[1], pmap_size, vmap_size)

        X, y = expand_batchsize((X, y), pmap_size, vmap_size)

        state = initial_state
        debug_grads = []
        # truncated backpropagation through time: iterate over chunks of
        # length `tbp` along the time axis
        for i, (X_tbp, y_tbp) in enumerate(
            tree_utils.tree_split((X, y), int(T / tbp), axis=-3)
        ):
            (loss, state), grads = pmapped_loss_fn(params, state, X_tbp, y_tbp)
            debug_grads.append(grads)
            # carry the filter state across chunks but block gradients from
            # flowing between them
            state = jax.lax.stop_gradient(state)
            params, opt_state = apply_grads(grads, params, opt_state)

        return params, opt_state, {"loss": loss}, debug_grads

    return step_fn
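
# How the batch axis is laid out for `jax.pmap` (a sketch inferred from the
# usage of `distribute_batchsize` / `expand_batchsize` above, not from the
# `ring.utils` source):
#
#     pmap_size, vmap_size = distribute_batchsize(32)  # e.g. (2, 16) on 2 devices
#     X = expand_batchsize(X, pmap_size, vmap_size)    # (32, T, N, F) -> (2, 16, T, N, F)
#
# `pmapped_loss_fn` then maps over the leading device axis and averages loss
# and gradients across devices via `jax.lax.pmean`.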


def train_fn(
    generator: types.BatchedGenerator,
    n_episodes: int,
    filter: ml_base.AbstractFilter | ml_base.AbstractFilterWrapper,
    optimizer: Optional[optax.GradientTransformation] = optax.adam(1e-3),
    tbp: int = 1000,
    loggers: list[ml_utils.Logger] = [],
    callbacks: list[training_loop.TrainingLoopCallback] = [],
    checkpoint: Optional[str] = None,
    seed_network: int = 1,
    seed_generator: int = 2,
    callback_save_params: bool | str = False,
    callback_save_params_track_metrices: Optional[list[list[str]]] = None,
    callback_kill_if_grads_larger: Optional[float] = None,
    callback_kill_if_nan: bool = False,
    callback_kill_after_episode: Optional[int] = None,
    callback_kill_after_seconds: Optional[float] = None,
    callback_kill_tag: Optional[str] = None,
    callback_create_checkpoint: bool = True,
    loss_fn: LOSS_FN = _default_loss_fn,
    metrices: Optional[METRICES] = _default_metrices,
    link_names: Optional[list[str]] = None,
) -> bool:
    """Trains RNNO.

    Args:
        generator: batched generator, i.e. the output of `build_generator`
        n_episodes: number of episodes to train for
        filter: the filter (network) whose parameters get optimized
        optimizer: optax optimizer, see the `optimizer.py` module
        tbp: truncated backpropagation through time chunk size
        loggers: list of Loggers used to log the training progress
        callbacks: callbacks of the TrainingLoop
        checkpoint: if given, training resumes from this pickled checkpoint
        seed_network: PRNG seed that inits the network state and parameters
        seed_generator: PRNG seed that inits the data stream of the generator
        loss_fn: loss applied at every time step; defaults to the squared angle error
        metrices: evaluation metrics, mapping a name to a (metric_fn, accumulator_fn) pair
        link_names: names used to report metrics per link

    Returns:
        Whether or not the training run was killed by a callback.
    """

    if checkpoint is not None:
        checkpoint = Path(checkpoint).with_suffix(".pickle")
        recv_checkpoint: dict = pickle_load(checkpoint)
        filter.params = recv_checkpoint["params"]
        opt_state = recv_checkpoint["opt_state"]

    filter = filter.nojit()

    filter_params = filter.search_attr("params")
    if filter_params is None:
        X, _ = generator(jax.random.PRNGKey(1))
        filter_params, _ = filter.init(X=X, seed=seed_network)
        del X

    if checkpoint is None:
        opt_state = optimizer.init(filter_params)

    step_fn = _build_step_fn(
        loss_fn,
        filter,
        optimizer,
        tbp=tbp,
    )

    default_callbacks = []
    if metrices is not None:
        eval_fn = _build_eval_fn(metrices, filter, link_names)
        default_callbacks.append(_DefaultEvalFnCallback(eval_fn))

    if callback_kill_tag is not None:
        default_callbacks.append(ml_callbacks.WandbKillRun(stop_tag=callback_kill_tag))

    if not (callback_save_params is False):
        if callback_save_params is True:
            callback_save_params = f"~/params/{ml_utils.unique_id()}.pickle"
        default_callbacks.append(
            ml_callbacks.SaveParamsTrainingLoopCallback(callback_save_params)
        )

    if callback_kill_if_grads_larger is not None:
        default_callbacks.append(
            ml_callbacks.LogGradsTrainingLoopCallBack(
                callback_kill_if_grads_larger, consecutive_larger=18
            )
        )

    if callback_kill_if_nan:
        default_callbacks.append(ml_callbacks.NanKillRunCallback())

    # always log, because we also want `i_episode` to be logged in wandb
    default_callbacks.append(
        ml_callbacks.LogEpisodeTrainingLoopCallback(callback_kill_after_episode)
    )

    if callback_kill_after_seconds is not None:
        default_callbacks.append(
            ml_callbacks.TimingKillRunCallback(callback_kill_after_seconds)
        )

    if callback_create_checkpoint:
        default_callbacks.append(ml_callbacks.CheckpointCallback())

    callbacks_all = default_callbacks + callbacks

    # we add this callback afterwards because it might require the metrices
    # calculated from one of the user-provided callbacks
    if callback_save_params_track_metrices is not None:
        assert (
            callback_save_params is not None
        ), "Required field if `callback_save_params_track_metrices` is set. Used below."

        callbacks_all.append(
            ml_callbacks.SaveParamsTrainingLoopCallback(
                path_to_file=parse_path(callback_save_params, extension=""),
                last_n_params=3,
                track_metrices=callback_save_params_track_metrices,
                cleanup=False,
            )
        )

    # if wandb is initialized, then add the appropriate logger
    if wandb.run is not None:
        wandb_logger_found = False
        for logger in loggers:
            if isinstance(logger, ml_utils.WandbLogger):
                wandb_logger_found = True
        if not wandb_logger_found:
            loggers.append(ml_utils.WandbLogger())

    loop = training_loop.TrainingLoop(
        jax.random.PRNGKey(seed_generator),
        generator,
        filter_params,
        opt_state,
        step_fn,
        loggers=loggers,
        callbacks=callbacks_all,
    )

    return loop.run(n_episodes)


def _arr_to_dict(y: jax.Array, link_names: list[str] | None):
    assert y.ndim == 4
    B, T, N, F = y.shape

    if link_names is None:
        link_names = ml_utils._unknown_link_names(N)

    return {name: y[..., i, :] for i, name in enumerate(link_names)}


def _build_eval_fn(
    eval_metrices: dict[str, Tuple[Callable, Callable]],
    filter: ml_base.AbstractFilter,
    link_names: Optional[list[str]] = None,
):
    """Build function that evaluates the filter performance."""

    def eval_fn(params, state, X, y):
        yhat, _ = filter.apply(params=params, state=state, X=X)

        y = _arr_to_dict(y, link_names)
        yhat = _arr_to_dict(yhat, link_names)

        values = {}
        for metric_name, (metric_fn, reduce_fn) in eval_metrices.items():
            assert (
                metric_name not in values
            ), f"The metric identifier {metric_name} is not unique"

            pipe = lambda q, qhat: reduce_fn(jax.vmap(jax.vmap(metric_fn))(q, qhat))
            values.update({metric_name: jax.tree_map(pipe, y, yhat)})

        return values

    @partial(jax.pmap, in_axes=(None, 0, 0, 0), out_axes=None, axis_name="devices")
    def pmapped_eval_fn(params, state, X, y):
        pmean = lambda arr: jax.lax.pmean(arr, axis_name="devices")
        values = eval_fn(params, state, X, y)
        return pmean(values)

    initial_state = None

    def expand_then_pmap_eval_fn(params, X, y):
        assert X.ndim == y.ndim == 4
        B, T, N, F = X.shape
        pmap_size, vmap_size = distribute_batchsize(B)

        nonlocal initial_state
        if initial_state is None:
            initial_state = expand_batchsize(filter.init(B, X)[1], pmap_size, vmap_size)

        X, y = expand_batchsize((X, y), pmap_size, vmap_size)
        return pmapped_eval_fn(params, initial_state, X, y)

    return expand_then_pmap_eval_fn


class _DefaultEvalFnCallback(training_loop.TrainingLoopCallback):
    def __init__(self, eval_fn):
        self.eval_fn = eval_fn

    def after_training_step(
        self,
        i_episode: int,
        metrices: dict,
        params: dict,
        grads: list[dict],
        sample_eval: dict,
        loggers: list[ml_utils.Logger],
        opt_state,
    ):
        metrices.update(self.eval_fn(params, sample_eval[0], sample_eval[1]))
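
For orientation, a minimal usage sketch of `train_fn` follows; `make_generator` and `make_filter` are placeholders for a `types.BatchedGenerator` (e.g. from `build_generator`) and an `ml_base.AbstractFilter` implementation, not functions of this package:

import optax
from ring.ml import train

gen = make_generator()   # placeholder: yields (X, y) with shapes (B, T, N, F)
filt = make_filter()     # placeholder: any AbstractFilter implementation

# returns True if the run was killed by a callback
killed = train.train_fn(
    generator=gen,
    n_episodes=1500,
    filter=filt,
    optimizer=optax.adam(1e-3),
    tbp=1000,
    callback_save_params=True,  # saves to ~/params/<unique_id>.pickle
)
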
ring/ml/training_loop.py
ADDED
@@ -0,0 +1,131 @@
import random
from typing import Optional

import jax
from ring.algorithms import Generator
from ring.ml import ml_utils
import tqdm
import tree_utils

_KILL_RUN = False


def send_kill_run_signal(value: bool = True) -> None:
    global _KILL_RUN
    _KILL_RUN = value


def recv_kill_run_signal() -> bool:
    global _KILL_RUN
    return _KILL_RUN


class TrainingLoopCallback:
    def after_training_step(
        self,
        i_episode: int,
        metrices: dict,
        params: dict,
        grads: list[dict],
        sample_eval: dict,
        loggers: list[ml_utils.Logger],
        opt_state: tree_utils.PyTree,
    ) -> None:
        pass

    def close(self):
        pass


class TrainingLoop:
    def __init__(
        self,
        key,
        generator: Generator,
        params,
        opt_state,
        step_fn,
        loggers: list[ml_utils.Logger],
        callbacks: list[TrainingLoopCallback] = [],
        cycle_seed: Optional[int] = None,
    ):
        self._key = key
        self.i_episode = -1
        self._generator = generator
        self._params = params
        self._opt_state = opt_state
        self._step_fn = step_fn
        self._loggers = loggers
        self._callbacks = callbacks
        self._seeds = list(range(cycle_seed)) if cycle_seed else None
        if cycle_seed is not None:
            random.seed(1)

        self._sample_eval = generator(jax.random.PRNGKey(0))
        batchsize = tree_utils.tree_shape(self._sample_eval, 0)
        T = tree_utils.tree_shape(self._sample_eval, 1)

        for logger in loggers:
            logger.log(dict(n_params=logger.n_params(params), batchsize=batchsize, T=T))

    @property
    def key(self):
        if self._seeds is not None:
            seed_idx = self.i_episode % len(self._seeds)
            if seed_idx == 0:
                random.shuffle(self._seeds)
            return jax.random.PRNGKey(self._seeds[seed_idx])
        else:
            self._key, consume = jax.random.split(self._key)
            return consume

    def run(self, n_episodes: int = 1, close_afterwards: bool = True) -> bool:
        # reset the kill_run flag from previous runs
        send_kill_run_signal(value=False)

        for _ in tqdm.tqdm(range(n_episodes)):
            self.step()

            if recv_kill_run_signal():
                break

        if close_afterwards:
            self.close()

        return recv_kill_run_signal()

    def step(self):
        self.i_episode += 1

        # the batch generated last episode (and passed to the callbacks for
        # evaluation) becomes this episode's training data
        sample_train = self._sample_eval
        self._sample_eval = self._generator(self.key)

        self._params, self._opt_state, loss, debug_grads = self._step_fn(
            self._params, self._opt_state, sample_train[0], sample_train[1]
        )

        metrices = {}
        metrices.update(loss)

        for callback in self._callbacks:
            callback.after_training_step(
                self.i_episode,
                metrices,
                self._params,
                debug_grads,
                self._sample_eval,
                self._loggers,
                self._opt_state,
            )

        for logger in self._loggers:
            logger.log(metrices)

        return metrices

    def close(self):
        for callback in self._callbacks:
            callback.close()

        for logger in self._loggers:
            logger.close()
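
The kill-signal mechanism above also lets user-provided callbacks stop a run early. A sketch of such a callback (the loss-threshold logic is illustrative; only the `TrainingLoopCallback` interface and `send_kill_run_signal` come from this module):

from ring.ml import ml_utils, training_loop


class KillBelowLoss(training_loop.TrainingLoopCallback):
    """Stops the run once the training loss drops below a threshold."""

    def __init__(self, threshold: float):
        self.threshold = threshold

    def after_training_step(
        self,
        i_episode: int,
        metrices: dict,
        params: dict,
        grads: list[dict],
        sample_eval: dict,
        loggers: list[ml_utils.Logger],
        opt_state,
    ) -> None:
        if float(metrices["loss"]) < self.threshold:
            training_loop.send_kill_run_signal()

Passed via `train_fn(..., callbacks=[KillBelowLoss(0.01)])`, this makes `train_fn` return True when triggered, since `TrainingLoop.run` returns `recv_kill_run_signal()`.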