imt-ring 1.3.6__py3-none-any.whl → 1.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -56,7 +56,7 @@ ring/ml/callbacks.py,sha256=DkSy5c7IRqAAks2dx8acEBExYxUv-xiUFwZn4odPYq4,13253
|
|
56
56
|
ring/ml/ml_utils.py,sha256=hQEmeZoahdJyFrz0NZXYi1Yijl7GvPBdqwzZBzlUIUM,7638
|
57
57
|
ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
|
58
58
|
ring/ml/ringnet.py,sha256=OWRDu2COmptzbpJWlRLbPIn_ioKZCAd_iu-eiY_aPjk,8521
|
59
|
-
ring/ml/train.py,sha256=
|
59
|
+
ring/ml/train.py,sha256=uDW6JMdbMcjUKr3wCL2drWzDUd0Pc3BoroUwLcYoUx4,10914
|
60
60
|
ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
|
61
61
|
ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
|
62
62
|
ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
|
@@ -77,7 +77,7 @@ ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
|
|
77
77
|
ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
|
78
78
|
ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
|
79
79
|
ring/utils/utils.py,sha256=mIcKNv5v2de8HrG7bAhl2bNfmwkMZyIIwFkJq2XWMOI,5357
|
80
|
-
imt_ring-1.3.
|
81
|
-
imt_ring-1.3.
|
82
|
-
imt_ring-1.3.
|
83
|
-
imt_ring-1.3.
|
80
|
+
imt_ring-1.3.7.dist-info/METADATA,sha256=V6Oow_ZZwpBuHuIbyPIoKFtrhFboxMmuIPx1Rilq3-A,3104
|
81
|
+
imt_ring-1.3.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
82
|
+
imt_ring-1.3.7.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
83
|
+
imt_ring-1.3.7.dist-info/RECORD,,
|
ring/ml/train.py
CHANGED
@@ -5,6 +5,8 @@ from typing import Callable, Optional, Tuple
|
|
5
5
|
import jax
|
6
6
|
import jax.numpy as jnp
|
7
7
|
import optax
|
8
|
+
import tree_utils
|
9
|
+
|
8
10
|
from ring import maths
|
9
11
|
from ring.algorithms.generator import types
|
10
12
|
from ring.ml import base as ml_base
|
@@ -15,8 +17,6 @@ from ring.utils import distribute_batchsize
|
|
15
17
|
from ring.utils import expand_batchsize
|
16
18
|
from ring.utils import parse_path
|
17
19
|
from ring.utils import pickle_load
|
18
|
-
import tree_utils
|
19
|
-
|
20
20
|
import wandb
|
21
21
|
|
22
22
|
# (T, N, F) -> Scalar
|
@@ -142,15 +142,17 @@ def train_fn(
|
|
142
142
|
Wether or not the training run was killed by a callback.
|
143
143
|
"""
|
144
144
|
|
145
|
+
filter = filter.nojit()
|
146
|
+
|
145
147
|
if checkpoint is not None:
|
146
148
|
checkpoint = Path(checkpoint).with_suffix(".pickle")
|
147
149
|
recv_checkpoint: dict = pickle_load(checkpoint)
|
148
|
-
|
150
|
+
filter_params = recv_checkpoint["params"]
|
149
151
|
opt_state = recv_checkpoint["opt_state"]
|
152
|
+
del recv_checkpoint
|
153
|
+
else:
|
154
|
+
filter_params = filter.search_attr("params")
|
150
155
|
|
151
|
-
filter = filter.nojit()
|
152
|
-
|
153
|
-
filter_params = filter.search_attr("params")
|
154
156
|
if filter_params is None:
|
155
157
|
X, _ = generator(jax.random.PRNGKey(1))
|
156
158
|
filter_params, _ = filter.init(X=X, seed=seed_network)
|
File without changes
|
File without changes
|