imt-ring 1.6.33__py3-none-any.whl → 1.6.34__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {imt_ring-1.6.33.dist-info → imt_ring-1.6.34.dist-info}/METADATA +1 -1
- {imt_ring-1.6.33.dist-info → imt_ring-1.6.34.dist-info}/RECORD +5 -5
- ring/ml/train.py +5 -4
- {imt_ring-1.6.33.dist-info → imt_ring-1.6.34.dist-info}/WHEEL +0 -0
- {imt_ring-1.6.33.dist-info → imt_ring-1.6.34.dist-info}/top_level.txt +0 -0
@@ -58,7 +58,7 @@ ring/ml/ml_utils.py,sha256=M--qkXRnhU7tHvgfTHfT9gyY0nhj3zMGEaK0X0drFLs,10915
|
|
58
58
|
ring/ml/optimizer.py,sha256=TZF0_LmnewzmGVso-zIQJtpWguUW0fW3HeRpIdG_qoI,4763
|
59
59
|
ring/ml/ringnet.py,sha256=mef7jyN2QcApJmQGH3HYZyTV-00q8YpsYOKhW0-ku1k,8973
|
60
60
|
ring/ml/rnno_v1.py,sha256=2qE08OIvTJ5PvSxKpYGzGSrvEImWrdAT_qslZ7jP5tA,1372
|
61
|
-
ring/ml/train.py,sha256
|
61
|
+
ring/ml/train.py,sha256=Da89HxiqXC7xuX2ldpTrJStqKWN-6Vcpml4PPQuihN4,10989
|
62
62
|
ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
|
63
63
|
ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
|
64
64
|
ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
|
@@ -86,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
|
|
86
86
|
ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
|
87
87
|
ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
|
88
88
|
ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
|
89
|
-
imt_ring-1.6.
|
90
|
-
imt_ring-1.6.33.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
91
|
-
imt_ring-1.6.33.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
92
|
-
imt_ring-1.6.33.dist-info/RECORD,,
|
89
|
+
imt_ring-1.6.34.dist-info/METADATA,sha256=D7FXQFI8b4iXaJiWkM4MvHNqzncvFh_wn4UpVK8iqMs,4251
|
90
|
+
imt_ring-1.6.34.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
91
|
+
imt_ring-1.6.34.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
92
|
+
imt_ring-1.6.34.dist-info/RECORD,,
|
ring/ml/train.py
CHANGED
@@ -39,6 +39,7 @@ def _build_step_fn(
|
|
39
39
|
filter: ml_base.AbstractFilter,
|
40
40
|
optimizer,
|
41
41
|
tbp,
|
42
|
+
skip_first_tbp_batch,
|
42
43
|
):
|
43
44
|
"""Build step function that optimizes filter parameters based on `metric_fn`.
|
44
45
|
`initial_state` has shape (pmap, vmap, state_dim)"""
|
@@ -89,6 +90,8 @@ def _build_step_fn(
|
|
89
90
|
):
|
90
91
|
(loss, state), grads = pmapped_loss_fn(params, state, X_tbp, y_tbp)
|
91
92
|
debug_grads.append(grads)
|
93
|
+
if skip_first_tbp_batch and i == 0:
|
94
|
+
continue
|
92
95
|
state = jax.lax.stop_gradient(state)
|
93
96
|
params, opt_state = apply_grads(grads, params, opt_state)
|
94
97
|
|
@@ -119,6 +122,7 @@ def train_fn(
|
|
119
122
|
loss_fn: LOSS_FN = _default_loss_fn,
|
120
123
|
metrices: Optional[METRICES] = _default_metrices,
|
121
124
|
link_names: Optional[list[str]] = None,
|
125
|
+
skip_first_tbp_batch: bool = False,
|
122
126
|
) -> bool:
|
123
127
|
"""Trains RNNO
|
124
128
|
|
@@ -161,10 +165,7 @@ def train_fn(
|
|
161
165
|
opt_state = optimizer.init(filter_params)
|
162
166
|
|
163
167
|
step_fn = _build_step_fn(
|
164
|
-
loss_fn,
|
165
|
-
filter,
|
166
|
-
optimizer,
|
167
|
-
tbp=tbp,
|
168
|
+
loss_fn, filter, optimizer, tbp=tbp, skip_first_tbp_batch=skip_first_tbp_batch
|
168
169
|
)
|
169
170
|
|
170
171
|
# always log, because we also want `i_epsiode` to be logged in wandb
|
File without changes
|
File without changes
|