imt-ring 1.6.17__py3-none-any.whl → 1.6.19__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.17
3
+ Version: 1.6.19
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -54,7 +54,7 @@ ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
54
54
  ring/ml/__init__.py,sha256=nbh48gaswWeY4S4vT1sply_3ROj2DQ7agjoLR4Ho3T8,1517
55
55
  ring/ml/base.py,sha256=lfwEZLBDglOSRWChUHoH1kezefhttPV9TMEpNIqsMNw,9972
56
56
  ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
57
- ring/ml/ml_utils.py,sha256=1GXJfeoXbwCbRdYA2np3CbJpSupaw4eyf3quh9y4BO0,6462
57
+ ring/ml/ml_utils.py,sha256=xqy9BnLy8IKVqkFS9mlZsGJXSbThI9zZxZ5rhl8LSI8,7144
58
58
  ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
59
59
  ring/ml/ringnet.py,sha256=mef7jyN2QcApJmQGH3HYZyTV-00q8YpsYOKhW0-ku1k,8973
60
60
  ring/ml/rnno_v1.py,sha256=2qE08OIvTJ5PvSxKpYGzGSrvEImWrdAT_qslZ7jP5tA,1372
@@ -78,6 +78,7 @@ ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
78
78
  ring/utils/batchsize.py,sha256=FbOii7MDP4oPZd9GJOKehFatfnb6WZ0b9z349iZYs1A,1786
79
79
  ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
80
80
  ring/utils/dataloader.py,sha256=2CcsbUY2AZs8LraS5HTJXlEseuF-1gKmfyBkSsib-tE,3748
81
+ ring/utils/dataloader_torch.py,sha256=DR2uUiA9x49_6EBjnbVLfWu7GBX7wtKjgHSIlF80HO0,1502
81
82
  ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
82
83
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
83
84
  ring/utils/path.py,sha256=zRPfxYNesvgefkddd26oar6f9433LkMGkhp9dF3rPUs,1926
@@ -85,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
85
86
  ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
86
87
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
87
88
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
88
- imt_ring-1.6.17.dist-info/METADATA,sha256=j0IKIyc6qAgz9059Z4-b46hi7qTn_ffI7oLw_OrD_Tk,3833
89
- imt_ring-1.6.17.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
90
- imt_ring-1.6.17.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
91
- imt_ring-1.6.17.dist-info/RECORD,,
89
+ imt_ring-1.6.19.dist-info/METADATA,sha256=BZiSbypG96pTKHUSc1NDZYcyiRf1-ynsnu6rAsVdiRU,3833
90
+ imt_ring-1.6.19.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
91
+ imt_ring-1.6.19.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
+ imt_ring-1.6.19.dist-info/RECORD,,
ring/ml/ml_utils.py CHANGED
@@ -3,6 +3,7 @@ from functools import partial
3
3
  import os
4
4
  from pathlib import Path
5
5
  import pickle
6
+ import shutil
6
7
  import time
7
8
  from typing import Optional, Protocol
8
9
  import warnings
@@ -107,7 +108,7 @@ class WandbLogger(MixinLogger):
107
108
  wandb.log(data)
108
109
 
109
110
  def log_params(self, path: str):
110
- wandb.save(path, policy="now")
111
+ self.wandb_save(path)
111
112
 
112
113
  def log_video(
113
114
  self,
@@ -117,7 +118,7 @@ class WandbLogger(MixinLogger):
117
118
  step: Optional[int] = None,
118
119
  ):
119
120
  # TODO >>>
120
- wandb.save(path, policy="now")
121
+ self.wandb_save(path)
121
122
  return
122
123
  # <<<
123
124
  data = {"video": wandb.Video(path, caption=caption, fps=fps)}
@@ -127,10 +128,10 @@ class WandbLogger(MixinLogger):
127
128
 
128
129
  def log_image(self, path: str, caption: Optional[str] = None):
129
130
  # wandb.log({"image": wandb.Image(path, caption=caption)})
130
- wandb.save(path, policy="now")
131
+ self.wandb_save(path)
131
132
 
132
133
  def log_txt(self, path: str, wait: bool = True):
133
- wandb.save(path, policy="now")
134
+ self.wandb_save(path)
134
135
  # TODO: `wandb` is not async at all?
135
136
  if wait:
136
137
  time.sleep(3)
@@ -138,6 +139,23 @@ class WandbLogger(MixinLogger):
138
139
  def close(self):
139
140
  wandb.run.finish()
140
141
 
142
@staticmethod
def wandb_save(path):
    """Register the file at ``path`` with W&B through ``wandb.save``.

    Online (or no active run): ``path`` is handed to ``wandb.save`` directly
    with ``policy="now"``.

    Offline run: the file is first copied (``shutil.copy2``) into
    ``<wandb.run.dir>/copied_files/`` under its basename. That copy, not the
    original path, is passed to ``wandb.save``, so the run holds a true copy
    of the file.

    Args:
        path: Path of the file to save.
    """
    run = wandb.run
    # NOTE(review): `settings._offline` is a private wandb attribute and may
    # change between wandb releases -- confirm against the pinned version.
    is_offline_run = run is not None and run.settings._offline

    if not is_offline_run:
        wandb.save(path, policy="now")
        return

    snapshot_dir = os.path.join(run.dir, "copied_files")
    os.makedirs(snapshot_dir, exist_ok=True)
    snapshot_path = os.path.join(snapshot_dir, os.path.basename(path))
    # Files sharing a basename overwrite each other inside `copied_files`.
    shutil.copy2(path, snapshot_path)
    wandb.save(snapshot_path)
158
+
141
159
 
142
160
  def _flatten_convert_filter_nested_dict(
143
161
  metrices: NestedDict, filter_nan_inf: bool = True
@@ -0,0 +1,62 @@
1
+ import os
2
+
3
+ import jax
4
+ import torch
5
+ from torch.utils.data import DataLoader
6
+ from torch.utils.data import Dataset
7
+ from tree_utils import PyTree
8
+
9
+ from ring.utils import parse_path
10
+ from ring.utils import pickle_load
11
+
12
+
13
class FolderOfPickleFilesDataset(Dataset):
    """Map-style torch dataset with one element per file in a folder.

    Each element is produced by ``pickle_load`` on one file, optionally
    followed by ``transform``.

    Args:
        path: Folder whose regular files make up the dataset. Subdirectories
            are ignored.
        transform: Optional callable applied to every loaded element.
    """

    def __init__(self, path, transform=None):
        self.files = self.listdir(path)
        self.transform = transform
        # Kept as a public attribute for backward compatibility.
        self.N = len(self.files)

    def __len__(self):
        return self.N

    def __getitem__(self, idx: int):
        # NOTE(review): `pickle_load` presumably unpickles the file; only use
        # this dataset on folders whose contents you trust.
        element = pickle_load(self.files[idx])
        if self.transform is not None:
            element = self.transform(element)
        return element

    @staticmethod
    def listdir(path: str) -> list:
        """Return paths of the regular files directly inside ``path``, sorted.

        ``os.listdir`` returns entries in arbitrary, filesystem-dependent
        order. Sorting makes a given index map to the same file on every
        machine and run, so seeded shuffling stays reproducible.
        Subdirectories are skipped because they cannot be loaded as elements.
        """
        names = sorted(
            name
            for name in os.listdir(path)
            if os.path.isfile(os.path.join(path, name))
        )
        return [parse_path(path, name) for name in names]
31
+
32
+
33
def dataset_to_generator(
    dataset: Dataset,
    batch_size: int,
    shuffle=True,
    seed: int = 1,
    **kwargs,
):
    """Wrap a torch dataset into an endless, numpy-producing batch function.

    Args:
        dataset: Map-style torch dataset to draw samples from.
        batch_size: Number of samples per batch.
        shuffle: Forwarded to ``DataLoader``; if True, each new epoch is
            reshuffled.
        seed: Passed to ``torch.manual_seed`` before the ``DataLoader`` is
            built. Note that this reseeds torch's *global* RNG.
        **kwargs: Forwarded to ``torch.utils.data.DataLoader``. If
            ``multiprocessing_context`` is not given, it defaults to
            ``"spawn"`` when ``num_workers > 0`` and to ``None`` otherwise.

    Returns:
        A function ``generator(_)`` that ignores its argument and returns the
        next batch, with every ``torch.Tensor`` leaf converted to a numpy
        array. Other leaves are returned unchanged. When an epoch is exhausted
        it transparently starts a new one. If the loader yields no batches at
        all (e.g. an empty dataset), ``StopIteration`` propagates.
    """
    torch.manual_seed(seed)

    # `setdefault` rather than an explicit keyword argument: a caller passing
    # `multiprocessing_context` in **kwargs would otherwise trigger
    # "got multiple values for keyword argument".
    kwargs.setdefault(
        "multiprocessing_context",
        "spawn" if kwargs.get("num_workers", 0) > 0 else None,
    )
    dl = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)
    dl_iter = iter(dl)

    def to_numpy(tree: "PyTree[torch.Tensor]"):
        # `jax.tree_map` is deprecated and removed in recent jax releases;
        # `jax.tree_util.tree_map` works across versions. Non-tensor leaves
        # (e.g. lists of strings from the default collate function) have no
        # `.numpy()` and are passed through as-is.
        return jax.tree_util.tree_map(
            lambda leaf: leaf.numpy() if isinstance(leaf, torch.Tensor) else leaf,
            tree,
        )

    def generator(_):
        nonlocal dl_iter
        try:
            batch = next(dl_iter)
        except StopIteration:
            # Epoch finished: start a fresh pass over the data.
            dl_iter = iter(dl)
            batch = next(dl_iter)
        return to_numpy(batch)

    return generator