imt-ring 1.3.6__py3-none-any.whl → 1.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.3.6
3
+ Version: 1.3.7
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -56,7 +56,7 @@ ring/ml/callbacks.py,sha256=DkSy5c7IRqAAks2dx8acEBExYxUv-xiUFwZn4odPYq4,13253
56
56
  ring/ml/ml_utils.py,sha256=hQEmeZoahdJyFrz0NZXYi1Yijl7GvPBdqwzZBzlUIUM,7638
57
57
  ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
58
58
  ring/ml/ringnet.py,sha256=OWRDu2COmptzbpJWlRLbPIn_ioKZCAd_iu-eiY_aPjk,8521
59
- ring/ml/train.py,sha256=ftt2MOSSNGCdL7ZoAXcbIgeHW1Wkpgp6XYyLIBUIClI,10872
59
+ ring/ml/train.py,sha256=uDW6JMdbMcjUKr3wCL2drWzDUd0Pc3BoroUwLcYoUx4,10914
60
60
  ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
61
61
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
62
62
  ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
@@ -77,7 +77,7 @@ ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
77
77
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
78
78
  ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
79
79
  ring/utils/utils.py,sha256=mIcKNv5v2de8HrG7bAhl2bNfmwkMZyIIwFkJq2XWMOI,5357
80
- imt_ring-1.3.6.dist-info/METADATA,sha256=E5mVtL-2o6-U-Ov56yd4M0RVQs0VJoLSezHpBWGtleg,3104
81
- imt_ring-1.3.6.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
82
- imt_ring-1.3.6.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
83
- imt_ring-1.3.6.dist-info/RECORD,,
80
+ imt_ring-1.3.7.dist-info/METADATA,sha256=V6Oow_ZZwpBuHuIbyPIoKFtrhFboxMmuIPx1Rilq3-A,3104
81
+ imt_ring-1.3.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
82
+ imt_ring-1.3.7.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
83
+ imt_ring-1.3.7.dist-info/RECORD,,
ring/ml/train.py CHANGED
@@ -5,6 +5,8 @@ from typing import Callable, Optional, Tuple
5
5
  import jax
6
6
  import jax.numpy as jnp
7
7
  import optax
8
+ import tree_utils
9
+
8
10
  from ring import maths
9
11
  from ring.algorithms.generator import types
10
12
  from ring.ml import base as ml_base
@@ -15,8 +17,6 @@ from ring.utils import distribute_batchsize
15
17
  from ring.utils import expand_batchsize
16
18
  from ring.utils import parse_path
17
19
  from ring.utils import pickle_load
18
- import tree_utils
19
-
20
20
  import wandb
21
21
 
22
22
  # (T, N, F) -> Scalar
@@ -142,15 +142,17 @@ def train_fn(
142
142
  Wether or not the training run was killed by a callback.
143
143
  """
144
144
 
145
+ filter = filter.nojit()
146
+
145
147
  if checkpoint is not None:
146
148
  checkpoint = Path(checkpoint).with_suffix(".pickle")
147
149
  recv_checkpoint: dict = pickle_load(checkpoint)
148
- filter.params = recv_checkpoint["params"]
150
+ filter_params = recv_checkpoint["params"]
149
151
  opt_state = recv_checkpoint["opt_state"]
152
+ del recv_checkpoint
153
+ else:
154
+ filter_params = filter.search_attr("params")
150
155
 
151
- filter = filter.nojit()
152
-
153
- filter_params = filter.search_attr("params")
154
156
  if filter_params is None:
155
157
  X, _ = generator(jax.random.PRNGKey(1))
156
158
  filter_params, _ = filter.init(X=X, seed=seed_network)