imt-ring 1.3.7__py3-none-any.whl → 1.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.3.7
3
+ Version: 1.3.9
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,11 +1,11 @@
1
1
  ring/__init__.py,sha256=iNvbAZi7Qfa69IbL1z4lB7zHL8WusV5fBrKah2la-Gc,1566
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=99GuspRH4QtRRJTAgyvS02FFxoaBptSsz_GPczX8vw0,33947
3
+ ring/base.py,sha256=YFPrUWelWswEhq8x8Byv-5pK64mipiGW6x5IlMr4we4,33803
4
4
  ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=t3YXcgqMJxadUjFiILVD0HlQRPLdrQyc8aKiB36w0vE,1701
7
7
  ring/algorithms/_random.py,sha256=M9JQSMXSUARWuzlRLP3Wmkuntrk9LZpP30p4_IPgDB4,13805
8
- ring/algorithms/dynamics.py,sha256=nqq5I0RYSbHNlGiLMlohz08IfL9Njsrid4upDnwkGbI,10629
8
+ ring/algorithms/dynamics.py,sha256=_TwclBXe6vi5C5iJWAIeUIJEIMHQ_1QTmnHvCEpVO0M,10867
9
9
  ring/algorithms/jcalc.py,sha256=6olMYQtgKZE5KBEAHF0Rqxe__1wcZQVEiLgm1vO7_Gw,28260
10
10
  ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
11
  ring/algorithms/sensors.py,sha256=MICO9Sn0AfoqRx_9KWR3hufsIID-K6SOIg3oPDgsYMU,17869
@@ -14,10 +14,10 @@ ring/algorithms/custom_joints/rr_imp_joint.py,sha256=a3JT0w7pB94kZ95eBR8ZO853eSe
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
15
  ring/algorithms/custom_joints/suntay.py,sha256=7-kym1kMDwqYD_2um1roGcBeB8BlTCPe1wljuNGNARA,16676
16
16
  ring/algorithms/generator/__init__.py,sha256=p4ucl0zQtp5NwNoXIRjmTzGGRu2WOAWFfNmYRPwQles,912
17
- ring/algorithms/generator/base.py,sha256=AKH7GXEmRGV1kK8okiqa12uq0Ah9VYlqgdLw-99oFoQ,14840
18
- ring/algorithms/generator/batch.py,sha256=EOCX0vOxDwVOweArGsUneeeYysdSY2mFB55W052Wd9o,9161
19
- ring/algorithms/generator/motion_artifacts.py,sha256=aKdkZU5OF4_aKyL4Yo-ftZRwrDCve1LuuREGAUlTqtI,8551
20
- ring/algorithms/generator/pd_control.py,sha256=3pOaYig26vmp8gippDfy2KNJRZO3kr0rGd_PBIuEROM,5759
17
+ ring/algorithms/generator/base.py,sha256=sr-YZkjd8pZJAI5vFG_IqOO4AEeiEYtXr8uUsPMS6Q4,14779
18
+ ring/algorithms/generator/batch.py,sha256=kNlq78W-nAtbp6Xe82UjbPY-rXX2alGLxTokTITSbAc,9226
19
+ ring/algorithms/generator/motion_artifacts.py,sha256=_kiAl1VHoX1fW5AUlXOtPBWyHIIFof_M78AP-m9f1ME,8790
20
+ ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
21
21
  ring/algorithms/generator/randomize.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
22
22
  ring/algorithms/generator/transforms.py,sha256=nvNDvM20tEw9Zd0ra0TxA25uf01L40Y2UKvtQmOrlGo,12782
23
23
  ring/algorithms/generator/types.py,sha256=CAhvDK5qiHnrGtkCVccB07doiz_D6lHJ35B7sW0pyZA,1110
@@ -50,13 +50,14 @@ ring/io/xml/from_xml.py,sha256=8b44sPVWgoY8JGJZLpJ8M_eLfcfu3IsMtBzSytPTPmw,9234
50
50
  ring/io/xml/test_from_xml.py,sha256=bckVrVVmEhCwujd_OF9FGYnX3zU3BgztpqGxxmd0htM,1562
51
51
  ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,946
52
52
  ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
53
- ring/ml/__init__.py,sha256=-bryExVoKJYSF_G_KYc5hI_GciIhj2xZ8WGi6TdRghw,1836
53
+ ring/ml/__init__.py,sha256=52LpEjni5lG-ov5-3ocodH-vKZxNcFMU7W9XfjDicp0,2113
54
54
  ring/ml/base.py,sha256=PQ72VasEqlecBZgWP5HE5rWYyLiLq7nCVLymXo9f0dw,8959
55
- ring/ml/callbacks.py,sha256=DkSy5c7IRqAAks2dx8acEBExYxUv-xiUFwZn4odPYq4,13253
55
+ ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
56
56
  ring/ml/ml_utils.py,sha256=hQEmeZoahdJyFrz0NZXYi1Yijl7GvPBdqwzZBzlUIUM,7638
57
57
  ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
58
- ring/ml/ringnet.py,sha256=OWRDu2COmptzbpJWlRLbPIn_ioKZCAd_iu-eiY_aPjk,8521
59
- ring/ml/train.py,sha256=uDW6JMdbMcjUKr3wCL2drWzDUd0Pc3BoroUwLcYoUx4,10914
58
+ ring/ml/ringnet.py,sha256=rgje5AKUKpT8K-vbE9_SgZ3IijR8TJEHnaqxsE57Mhc,8617
59
+ ring/ml/rnno_v1.py,sha256=T4SKG7iypqn2HBQLKhDmJ2Slj2Z5jtUBHvX_6aL8pyM,1103
60
+ ring/ml/train.py,sha256=huUfMK6eotS6BRrQKoZ-AUG0um3jlqpfQFZNJT8LKiE,10854
60
61
  ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
61
62
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
62
63
  ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
@@ -77,7 +78,7 @@ ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
77
78
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
78
79
  ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
79
80
  ring/utils/utils.py,sha256=mIcKNv5v2de8HrG7bAhl2bNfmwkMZyIIwFkJq2XWMOI,5357
80
- imt_ring-1.3.7.dist-info/METADATA,sha256=V6Oow_ZZwpBuHuIbyPIoKFtrhFboxMmuIPx1Rilq3-A,3104
81
- imt_ring-1.3.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
82
- imt_ring-1.3.7.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
83
- imt_ring-1.3.7.dist-info/RECORD,,
81
+ imt_ring-1.3.9.dist-info/METADATA,sha256=H65-QICwM4mtRPumYJbrenN74nmiMBGbeV3pecKEeOg,3104
82
+ imt_ring-1.3.9.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
83
+ imt_ring-1.3.9.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
84
+ imt_ring-1.3.9.dist-info/RECORD,,
@@ -1,7 +1,9 @@
1
1
  from typing import Optional, Tuple
2
+ import warnings
2
3
 
3
4
  import jax
4
5
  import jax.numpy as jnp
6
+
5
7
  from ring import algebra
6
8
  from ring import base
7
9
  from ring import maths
@@ -213,7 +215,7 @@ def forward_dynamics(
213
215
  q: jax.Array,
214
216
  qd: jax.Array,
215
217
  tau: jax.Array,
216
- mass_mat_inv: jax.Array,
218
+ # mass_mat_inv: jax.Array,
217
219
  ) -> Tuple[jax.Array, jax.Array]:
218
220
  C = inverse_dynamics(sys, qd, jnp.zeros_like(qd))
219
221
  mass_matrix = compute_mass_matrix(sys)
@@ -235,6 +237,11 @@ def forward_dynamics(
235
237
 
236
238
  mass_mat_inv = jax.scipy.linalg.solve(mass_matrix, eye, assume_a="pos")
237
239
  else:
240
+ warnings.warn(
241
+ f"You are using `sys.mass_mat_iters`={sys.mass_mat_iters} which is >0. "
242
+ "This feature is currently not fully supported. See the local TODO."
243
+ )
244
+ mass_mat_inv = jnp.diag(jnp.ones((sys.qd_size(),)))
238
245
  mass_mat_inv = _inv_approximate(mass_matrix, mass_mat_inv, sys.mass_mat_iters)
239
246
 
240
247
  return mass_mat_inv @ qf_smooth, mass_mat_inv
@@ -254,9 +261,8 @@ def _strapdown_integration(
254
261
  def _semi_implicit_euler_integration(
255
262
  sys: base.System, state: base.State, taus: jax.Array
256
263
  ) -> base.State:
257
- qdd, mass_mat_inv = forward_dynamics(
258
- sys, state.q, state.qd, taus, state.mass_mat_inv
259
- )
264
+ qdd, mass_mat_inv = forward_dynamics(sys, state.q, state.qd, taus)
265
+ del mass_mat_inv
260
266
  qd_next = state.qd + sys.dt * qdd
261
267
 
262
268
  q_next = []
@@ -277,7 +283,7 @@ def _semi_implicit_euler_integration(
277
283
  sys.scan(q_integrate, "qdl", state.q, qd_next, sys.link_types)
278
284
  q_next = jnp.concatenate(q_next)
279
285
 
280
- state = state.replace(q=q_next, qd=qd_next, mass_mat_inv=mass_mat_inv)
286
+ state = state.replace(q=q_next, qd=qd_next)
281
287
  return state
282
288
 
283
289
 
@@ -4,6 +4,7 @@ import warnings
4
4
 
5
5
  import jax
6
6
  import jax.numpy as jnp
7
+ import tqdm
7
8
  import tree_utils
8
9
 
9
10
  from ring import base
@@ -83,10 +84,14 @@ class RCMG:
83
84
  ), "If `randomize_anchors`, then only one system is expected"
84
85
  sys = randomize.randomize_anchors(sys[0], **randomize_anchors_kwargs)
85
86
 
86
- zip_sys_config = False
87
87
  if randomize_hz:
88
- zip_sys_config = True
89
88
  sys, config = randomize.randomize_hz(sys, config, **randomize_hz_kwargs)
89
+ else:
90
+ # create zip
91
+ N_sys = len(sys)
92
+ sys = sum([len(config) * [s] for s in sys], start=[])
93
+ config = N_sys * config
94
+ assert len(sys) == len(config)
90
95
 
91
96
  if sys_ml is None:
92
97
  # TODO
@@ -97,17 +102,10 @@ class RCMG:
97
102
  sys_ml = sys[0]
98
103
 
99
104
  self.gens = []
100
- if zip_sys_config:
101
- for _sys, _config in zip(sys, config):
102
- self.gens.append(
103
- partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml)
104
- )
105
- else:
106
- for _sys in sys:
107
- for _config in config:
108
- self.gens.append(
109
- partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml)
110
- )
105
+ for _sys, _config in tqdm.tqdm(
106
+ zip(sys, config), desc="building generators", total=len(sys)
107
+ ):
108
+ self.gens.append(partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml))
111
109
 
112
110
  def _to_data(self, sizes, seed):
113
111
  return batch.batch_generators_eager_to_list(self.gens, sizes, seed=seed)
@@ -86,7 +86,11 @@ def batch_generators_eager_to_list(
86
86
 
87
87
  key = jax.random.PRNGKey(seed)
88
88
  data = []
89
- for gen, size in tqdm(zip(generators, sizes), desc="eager data generation"):
89
+ for gen, size in tqdm(
90
+ zip(generators, sizes),
91
+ desc="executing generators",
92
+ total=len(sizes),
93
+ ):
90
94
 
91
95
  n_calls = _number_of_executions_required(size)
92
96
  # decrease size by n_calls times
@@ -147,7 +151,7 @@ def _data_fn_from_paths(
147
151
  paths = [utils.parse_path(p, mkdir=False) for p in paths]
148
152
 
149
153
  extensions = list(set([Path(p).suffix for p in paths]))
150
- assert len(extensions) == 1
154
+ assert len(extensions) == 1, f"{extensions}"
151
155
 
152
156
  if extensions[0] == ".h5":
153
157
  N = sum([utils.hdf5_load_length(p) for p in paths])
@@ -49,6 +49,7 @@ def inject_subsystems(
49
49
  rotational_damp: float = 0.1,
50
50
  translational_stif: float = 50.0,
51
51
  translational_damp: float = 0.1,
52
+ disable_warning: bool = False,
52
53
  **kwargs,
53
54
  ) -> base.System:
54
55
  imu_idx_to_name_map = {sys.name_to_idx(imu): imu for imu in sys.findall_imus()}
@@ -92,10 +93,11 @@ def inject_subsystems(
92
93
  # TODO set all joint_params to zeros; they can not be preserved anyways and
93
94
  # otherwise many warnings will be rose
94
95
  # instead warn explicitly once now and move on
95
- warnings.warn(
96
- "`sys.links.joint_params` has been set to zero, this might lead to "
97
- "unexpected behaviour unless you use `randomize_joint_params`"
98
- )
96
+ if not disable_warning:
97
+ warnings.warn(
98
+ "`sys.links.joint_params` has been set to zero, this might lead to "
99
+ "unexpected behaviour unless you use `randomize_joint_params`"
100
+ )
99
101
  joint_params_zeros = tree_utils.tree_zeros_like(sys.links.joint_params)
100
102
  sys = sys.replace(links=sys.links.replace(joint_params=joint_params_zeros))
101
103
 
@@ -180,9 +182,13 @@ def setup_fn_randomize_damping_stiffness_factory(
180
182
  link_spring_stiffness = link_spring_stiffness.at[slice].set(stif)
181
183
  link_damping = link_damping.at[slice].set(damp)
182
184
 
183
- assert len(imus_surely_rigid) == len(triggered_surely_rigid)
185
+ assert len(imus_surely_rigid) == len(
186
+ triggered_surely_rigid
187
+ ), f"{imus_surely_rigid}, {triggered_surely_rigid}"
184
188
  for imu_surely_rigid in imus_surely_rigid:
185
- assert imu_surely_rigid in triggered_surely_rigid
189
+ assert (
190
+ imu_surely_rigid in triggered_surely_rigid
191
+ ), f"{imus_surely_rigid} not in {triggered_surely_rigid}"
186
192
 
187
193
  return sys.replace(
188
194
  link_damping=link_damping, link_spring_stiffness=link_spring_stiffness
@@ -4,6 +4,7 @@ from typing import Optional
4
4
  from flax import struct
5
5
  import jax
6
6
  import jax.numpy as jnp
7
+
7
8
  from ring import base
8
9
  from ring.algorithms import dynamics
9
10
  from ring.algorithms import jcalc
@@ -49,7 +50,7 @@ def _pd_control(P: jax.Array, D: Optional[jax.Array] = None):
49
50
  assert sys.q_size() == q_ref.shape[1], f"q_ref.shape = {q_ref.shape}"
50
51
  assert sys.qd_size() == P.size
51
52
  if D is not None:
52
- sys.qd_size() == D.size
53
+ assert sys.qd_size() == D.size
53
54
 
54
55
  q_ref_as_dict = {}
55
56
  qd_ref_as_dict = {}
ring/base.py CHANGED
@@ -997,13 +997,11 @@ class State(_Base):
997
997
  q (jax.Array): System state in minimal coordinates (equals `sys.q_size()`)
998
998
  qd (jax.Array): System velocity in minimal coordinates (equals `sys.qd_size()`)
999
999
  x: (Transform): Maximal coordinates of all links. From epsilon-to-link.
1000
- mass_mat_inv (jax.Array): Inverse of the mass matrix. Internal usage.
1001
1000
  """
1002
1001
 
1003
1002
  q: jax.Array
1004
1003
  qd: jax.Array
1005
1004
  x: Transform
1006
- mass_mat_inv: jax.Array
1007
1005
 
1008
1006
  @classmethod
1009
1007
  def create(
@@ -1057,4 +1055,4 @@ class State(_Base):
1057
1055
  if x is None:
1058
1056
  x = Transform.zero((sys.num_links(),))
1059
1057
 
1060
- return cls(q, qd, x, jnp.diag(jnp.ones((sys.qd_size(),))))
1058
+ return cls(q, qd, x)
ring/ml/__init__.py CHANGED
@@ -3,6 +3,7 @@ from . import callbacks
3
3
  from . import ml_utils
4
4
  from . import optimizer
5
5
  from . import ringnet
6
+ from . import rnno_v1
6
7
  from . import train
7
8
  from . import training_loop
8
9
  from .base import AbstractFilter
@@ -42,17 +43,28 @@ def RNNO(
42
43
  params=None,
43
44
  eval: bool = True,
44
45
  samp_freq: float | None = None,
46
+ v1: bool = False,
45
47
  **kwargs,
46
48
  ):
47
49
  assert "message_dim" not in kwargs
48
50
  assert "link_output_normalize" not in kwargs
49
51
  assert "link_output_dim" not in kwargs
50
52
 
53
+ if v1:
54
+ kwargs.update(
55
+ dict(forward_factory=rnno_v1.rnno_v1_forward_factory, output_dim=output_dim)
56
+ )
57
+ else:
58
+ kwargs.update(
59
+ dict(
60
+ message_dim=0,
61
+ link_output_normalize=False,
62
+ link_output_dim=output_dim,
63
+ )
64
+ )
65
+
51
66
  ringnet = RING( # noqa: F811
52
67
  params=params,
53
- message_dim=0,
54
- link_output_normalize=False,
55
- link_output_dim=output_dim,
56
68
  **kwargs,
57
69
  )
58
70
  ringnet = base.NoGraph_FilterWrapper(ringnet, quat_normalize=return_quats)
ring/ml/callbacks.py CHANGED
@@ -245,7 +245,8 @@ class SaveParamsTrainingLoopCallback(training_loop.TrainingLoopCallback):
245
245
  else:
246
246
  value = "{:.2f}".format(ele.value).replace(".", ",")
247
247
  filename = parse_path(
248
- self.path_to_file + f"_episode={ele.episode}_value={value}",
248
+ str(Path(self.path_to_file).with_suffix(""))
249
+ + f"_episode={ele.episode}_value={value}",
249
250
  extension="pickle",
250
251
  )
251
252
 
@@ -404,7 +405,7 @@ class CheckpointCallback(training_loop.TrainingLoopCallback):
404
405
  # only checkpoint if run has been killed
405
406
  if training_loop.recv_kill_run_signal():
406
407
  path = parse_path(
407
- "~/.xxy_checkpoints", ml_utils.unique_id(), extension="pickle"
408
+ "~/.ring_checkpoints", ml_utils.unique_id(), extension="pickle"
408
409
  )
409
410
  data = {"params": self.params, "opt_state": self.opt_state}
410
411
  pickle_save(
ring/ml/ringnet.py CHANGED
@@ -191,8 +191,16 @@ class LSTM(hk.RNNCore):
191
191
 
192
192
 
193
193
  class RING(ml_base.AbstractFilter):
194
- def __init__(self, params=None, lam=None, jit: bool = True, name=None, **kwargs):
195
- self.forward_lam_factory = partial(make_ring, **kwargs)
194
+ def __init__(
195
+ self,
196
+ params=None,
197
+ lam=None,
198
+ jit: bool = True,
199
+ name=None,
200
+ forward_factory=make_ring,
201
+ **kwargs,
202
+ ):
203
+ self.forward_lam_factory = partial(forward_factory, **kwargs)
196
204
  self.params = self._load_params(params)
197
205
  self.lam = lam
198
206
  self._name = name
ring/ml/rnno_v1.py ADDED
@@ -0,0 +1,41 @@
1
+ from typing import Optional, Sequence
2
+
3
+ import haiku as hk
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+
8
+ def rnno_v1_forward_factory(
9
+ output_dim: int,
10
+ rnn_layers: Sequence[int] = (400, 300),
11
+ linear_layers: Sequence[int] = (200, 100, 50, 50, 25, 25),
12
+ layernorm: bool = True,
13
+ act_fn_linear=jax.nn.relu,
14
+ act_fn_rnn=jax.nn.elu,
15
+ lam: Optional[tuple[int]] = None,
16
+ ):
17
+ # unused
18
+ del lam
19
+
20
+ @hk.without_apply_rng
21
+ @hk.transform_with_state
22
+ def forward_fn(X):
23
+ assert X.shape[-2] == 1
24
+
25
+ for i, n_units in enumerate(rnn_layers):
26
+ state = hk.get_state(f"rnn_{i}", shape=[1, n_units], init=jnp.zeros)
27
+ X, state = hk.dynamic_unroll(hk.GRU(n_units), X, state)
28
+ hk.set_state(f"rnn_{i}", state)
29
+
30
+ if layernorm:
31
+ X = hk.LayerNorm(axis=-1, create_scale=False, create_offset=False)(X)
32
+ X = act_fn_rnn(X)
33
+
34
+ for n_units in linear_layers:
35
+ X = hk.Linear(n_units)(X)
36
+ X = act_fn_linear(X)
37
+
38
+ y = hk.Linear(output_dim)(X)
39
+ return y[..., None, :]
40
+
41
+ return forward_fn
ring/ml/train.py CHANGED
@@ -15,7 +15,6 @@ from ring.ml import ml_utils
15
15
  from ring.ml import training_loop
16
16
  from ring.utils import distribute_batchsize
17
17
  from ring.utils import expand_batchsize
18
- from ring.utils import parse_path
19
18
  from ring.utils import pickle_load
20
19
  import wandb
21
20
 
@@ -217,7 +216,7 @@ def train_fn(
217
216
 
218
217
  callbacks_all.append(
219
218
  ml_callbacks.SaveParamsTrainingLoopCallback(
220
- path_to_file=parse_path(callback_save_params, extension=""),
219
+ path_to_file=callback_save_params,
221
220
  last_n_params=3,
222
221
  track_metrices=callback_save_params_track_metrices,
223
222
  cleanup=False,