imt-ring 1.6.17__py3-none-any.whl → 1.6.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.17
3
+ Version: 1.6.19
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -54,7 +54,7 @@ ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
54
54
  ring/ml/__init__.py,sha256=nbh48gaswWeY4S4vT1sply_3ROj2DQ7agjoLR4Ho3T8,1517
55
55
  ring/ml/base.py,sha256=lfwEZLBDglOSRWChUHoH1kezefhttPV9TMEpNIqsMNw,9972
56
56
  ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
57
- ring/ml/ml_utils.py,sha256=1GXJfeoXbwCbRdYA2np3CbJpSupaw4eyf3quh9y4BO0,6462
57
+ ring/ml/ml_utils.py,sha256=xqy9BnLy8IKVqkFS9mlZsGJXSbThI9zZxZ5rhl8LSI8,7144
58
58
  ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
59
59
  ring/ml/ringnet.py,sha256=mef7jyN2QcApJmQGH3HYZyTV-00q8YpsYOKhW0-ku1k,8973
60
60
  ring/ml/rnno_v1.py,sha256=2qE08OIvTJ5PvSxKpYGzGSrvEImWrdAT_qslZ7jP5tA,1372
@@ -78,6 +78,7 @@ ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
78
78
  ring/utils/batchsize.py,sha256=FbOii7MDP4oPZd9GJOKehFatfnb6WZ0b9z349iZYs1A,1786
79
79
  ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
80
80
  ring/utils/dataloader.py,sha256=2CcsbUY2AZs8LraS5HTJXlEseuF-1gKmfyBkSsib-tE,3748
81
+ ring/utils/dataloader_torch.py,sha256=DR2uUiA9x49_6EBjnbVLfWu7GBX7wtKjgHSIlF80HO0,1502
81
82
  ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
82
83
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
83
84
  ring/utils/path.py,sha256=zRPfxYNesvgefkddd26oar6f9433LkMGkhp9dF3rPUs,1926
@@ -85,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
85
86
  ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
86
87
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
87
88
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
88
- imt_ring-1.6.17.dist-info/METADATA,sha256=j0IKIyc6qAgz9059Z4-b46hi7qTn_ffI7oLw_OrD_Tk,3833
89
- imt_ring-1.6.17.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
90
- imt_ring-1.6.17.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
91
- imt_ring-1.6.17.dist-info/RECORD,,
89
+ imt_ring-1.6.19.dist-info/METADATA,sha256=BZiSbypG96pTKHUSc1NDZYcyiRf1-ynsnu6rAsVdiRU,3833
90
+ imt_ring-1.6.19.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
91
+ imt_ring-1.6.19.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
+ imt_ring-1.6.19.dist-info/RECORD,,
ring/ml/ml_utils.py CHANGED
@@ -3,6 +3,7 @@ from functools import partial
3
3
  import os
4
4
  from pathlib import Path
5
5
  import pickle
6
+ import shutil
6
7
  import time
7
8
  from typing import Optional, Protocol
8
9
  import warnings
@@ -107,7 +108,7 @@ class WandbLogger(MixinLogger):
107
108
  wandb.log(data)
108
109
 
109
110
  def log_params(self, path: str):
110
- wandb.save(path, policy="now")
111
+ self.wandb_save(path)
111
112
 
112
113
  def log_video(
113
114
  self,
@@ -117,7 +118,7 @@ class WandbLogger(MixinLogger):
117
118
  step: Optional[int] = None,
118
119
  ):
119
120
  # TODO >>>
120
- wandb.save(path, policy="now")
121
+ self.wandb_save(path)
121
122
  return
122
123
  # <<<
123
124
  data = {"video": wandb.Video(path, caption=caption, fps=fps)}
@@ -127,10 +128,10 @@ class WandbLogger(MixinLogger):
127
128
 
128
129
  def log_image(self, path: str, caption: Optional[str] = None):
129
130
  # wandb.log({"image": wandb.Image(path, caption=caption)})
130
- wandb.save(path, policy="now")
131
+ self.wandb_save(path)
131
132
 
132
133
  def log_txt(self, path: str, wait: bool = True):
133
- wandb.save(path, policy="now")
134
+ self.wandb_save(path)
134
135
  # TODO: `wandb` is not async at all?
135
136
  if wait:
136
137
  time.sleep(3)
@@ -138,6 +139,23 @@ class WandbLogger(MixinLogger):
138
139
  def close(self):
139
140
  wandb.run.finish()
140
141
 
142
@staticmethod
def wandb_save(path):
    """Hand the file at `path` to `wandb.save`.

    When a wandb run exists and its settings report offline mode
    (`settings._offline`), the file is first copied (via `shutil.copy2`)
    into a `copied_files` sub-folder of the run directory, and that copy is
    what gets passed to `wandb.save` (with its default policy). In every
    other case the original path is saved directly with `policy="now"`.
    """
    run = wandb.run
    offline = run is not None and run.settings._offline

    if not offline:
        wandb.save(path, policy="now")
        return

    # Offline: save a snapshot copy that lives inside the run directory.
    copies_dir = os.path.join(run.dir, "copied_files")
    os.makedirs(copies_dir, exist_ok=True)
    snapshot = os.path.join(copies_dir, os.path.basename(path))
    shutil.copy2(path, snapshot)
    wandb.save(snapshot)
158
+
141
159
 
142
160
  def _flatten_convert_filter_nested_dict(
143
161
  metrices: NestedDict, filter_nan_inf: bool = True
@@ -0,0 +1,62 @@
1
+ import os
2
+
3
+ import jax
4
+ import torch
5
+ from torch.utils.data import DataLoader
6
+ from torch.utils.data import Dataset
7
+ from tree_utils import PyTree
8
+
9
+ from ring.utils import parse_path
10
+ from ring.utils import pickle_load
11
+
12
+
13
class FolderOfPickleFilesDataset(Dataset):
    """Map-style torch dataset with one sample per file in a folder.

    Each sample is loaded with `pickle_load` from the corresponding file and,
    if a `transform` callable was given, passed through it before being
    returned.

    Args:
        path: Folder whose (non-directory) entries are the samples.
        transform: Optional callable applied to every loaded element.

    Attributes:
        files: Sample file paths (built with `parse_path`), sorted by name.
        transform: The callable given at construction, or None.
        N: Number of samples, i.e. `len(files)`.
    """

    def __init__(self, path, transform=None):
        self.files = self.listdir(path)
        self.transform = transform
        self.N = len(self.files)

    def __len__(self):
        return self.N

    def __getitem__(self, idx: int):
        # NOTE(review): `pickle_load` presumably unpickles the file, which can
        # execute arbitrary code -- only point this dataset at trusted folders.
        element = pickle_load(self.files[idx])
        if self.transform is not None:
            element = self.transform(element)
        return element

    @staticmethod
    def listdir(path: str) -> list:
        """Return the non-directory entries of `path`, sorted by file name.

        `os.listdir` returns entries in arbitrary, filesystem-dependent order;
        sorting makes the index -> file mapping (and thus any seeded shuffle
        over indices) reproducible. Sub-directories are skipped because they
        cannot be loaded as samples.
        """
        return [
            parse_path(path, name)
            for name in sorted(os.listdir(path))
            if not os.path.isdir(os.path.join(path, name))
        ]
31
+
32
+
33
def dataset_to_generator(
    dataset: Dataset,
    batch_size: int,
    shuffle=True,
    seed: int = 1,
    **kwargs,
):
    """Wrap a torch `Dataset` in an endless, numpy-returning batch function.

    Args:
        dataset: The torch dataset to draw samples from.
        batch_size: Number of samples per batch.
        shuffle: Whether the `DataLoader` shuffles samples (a fresh order is
            drawn every time a new epoch is started).
        seed: Passed to `torch.manual_seed` before the `DataLoader` is built.
            NOTE: this reseeds torch's *global* RNG as a side effect.
        **kwargs: Forwarded to `torch.utils.data.DataLoader`. If
            `multiprocessing_context` is not supplied it defaults to
            `"spawn"` when `num_workers > 0` and to `None` otherwise.

    Returns:
        A function `generator(_)` that ignores its argument and returns the
        next batch with every leaf converted via `Tensor.numpy()`. When the
        loader is exhausted, a new epoch is started transparently.

    Raises:
        ValueError: from `generator` if the loader yields no batches at all
            (e.g. an empty dataset, or `drop_last=True` with fewer samples
            than `batch_size`).
    """
    torch.manual_seed(seed)

    # Honour a caller-supplied context instead of passing the keyword twice,
    # which would make `DataLoader(...)` raise "got multiple values".
    kwargs.setdefault(
        "multiprocessing_context",
        "spawn" if kwargs.get("num_workers", 0) > 0 else None,
    )
    dl = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)
    dl_iter = iter(dl)

    def to_numpy(tree: "PyTree[torch.Tensor]"):
        # `jax.tree_util.tree_map` is the stable spelling; the `jax.tree_map`
        # alias is deprecated and removed in recent JAX releases.
        return jax.tree_util.tree_map(lambda tensor: tensor.numpy(), tree)

    def generator(_):
        nonlocal dl_iter
        try:
            batch = next(dl_iter)
        except StopIteration:
            # Epoch finished: start a new pass over the loader.
            dl_iter = iter(dl)
            try:
                batch = next(dl_iter)
            except StopIteration:
                raise ValueError(
                    "DataLoader produced no batches; the dataset is empty or "
                    "smaller than `batch_size` with `drop_last=True`."
                ) from None
        return to_numpy(batch)

    return generator