imt-ring 1.6.33__py3-none-any.whl → 1.6.34__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.33
3
+ Version: 1.6.34
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -58,7 +58,7 @@ ring/ml/ml_utils.py,sha256=M--qkXRnhU7tHvgfTHfT9gyY0nhj3zMGEaK0X0drFLs,10915
58
58
  ring/ml/optimizer.py,sha256=TZF0_LmnewzmGVso-zIQJtpWguUW0fW3HeRpIdG_qoI,4763
59
59
  ring/ml/ringnet.py,sha256=mef7jyN2QcApJmQGH3HYZyTV-00q8YpsYOKhW0-ku1k,8973
60
60
  ring/ml/rnno_v1.py,sha256=2qE08OIvTJ5PvSxKpYGzGSrvEImWrdAT_qslZ7jP5tA,1372
61
- ring/ml/train.py,sha256=-6SzQKjIgktgRjaXKVg_1dqcBmAJggZSVwDnau1FnxI,10832
61
+ ring/ml/train.py,sha256=Da89HxiqXC7xuX2ldpTrJStqKWN-6Vcpml4PPQuihN4,10989
62
62
  ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
63
63
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
64
64
  ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
@@ -86,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
86
86
  ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
87
87
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
88
88
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
89
- imt_ring-1.6.33.dist-info/METADATA,sha256=FYe4G7jx8u4IblPmrFrpnqCxL3Nv-ITa6LG9fVGaOng,4251
90
- imt_ring-1.6.33.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
91
- imt_ring-1.6.33.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
- imt_ring-1.6.33.dist-info/RECORD,,
89
+ imt_ring-1.6.34.dist-info/METADATA,sha256=D7FXQFI8b4iXaJiWkM4MvHNqzncvFh_wn4UpVK8iqMs,4251
90
+ imt_ring-1.6.34.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
91
+ imt_ring-1.6.34.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
+ imt_ring-1.6.34.dist-info/RECORD,,
ring/ml/train.py CHANGED
@@ -39,6 +39,7 @@ def _build_step_fn(
39
39
  filter: ml_base.AbstractFilter,
40
40
  optimizer,
41
41
  tbp,
42
+ skip_first_tbp_batch,
42
43
  ):
43
44
  """Build step function that optimizes filter parameters based on `metric_fn`.
44
45
  `initial_state` has shape (pmap, vmap, state_dim)"""
@@ -89,6 +90,8 @@ def _build_step_fn(
89
90
  ):
90
91
  (loss, state), grads = pmapped_loss_fn(params, state, X_tbp, y_tbp)
91
92
  debug_grads.append(grads)
93
+ if skip_first_tbp_batch and i == 0:
94
+ continue
92
95
  state = jax.lax.stop_gradient(state)
93
96
  params, opt_state = apply_grads(grads, params, opt_state)
94
97
 
@@ -119,6 +122,7 @@ def train_fn(
119
122
  loss_fn: LOSS_FN = _default_loss_fn,
120
123
  metrices: Optional[METRICES] = _default_metrices,
121
124
  link_names: Optional[list[str]] = None,
125
+ skip_first_tbp_batch: bool = False,
122
126
  ) -> bool:
123
127
  """Trains RNNO
124
128
 
@@ -161,10 +165,7 @@ def train_fn(
161
165
  opt_state = optimizer.init(filter_params)
162
166
 
163
167
  step_fn = _build_step_fn(
164
- loss_fn,
165
- filter,
166
- optimizer,
167
- tbp=tbp,
168
+ loss_fn, filter, optimizer, tbp=tbp, skip_first_tbp_batch=skip_first_tbp_batch
168
169
  )
169
170
 
170
171
  # always log, because we also want `i_epsiode` to be logged in wandb