imt-ring 1.6.17__py3-none-any.whl → 1.6.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.6.17.dist-info → imt_ring-1.6.19.dist-info}/METADATA +1 -1
- {imt_ring-1.6.17.dist-info → imt_ring-1.6.19.dist-info}/RECORD +6 -5
- ring/ml/ml_utils.py +22 -4
- ring/utils/dataloader_torch.py +62 -0
- {imt_ring-1.6.17.dist-info → imt_ring-1.6.19.dist-info}/WHEEL +0 -0
- {imt_ring-1.6.17.dist-info → imt_ring-1.6.19.dist-info}/top_level.txt +0 -0
@@ -54,7 +54,7 @@ ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
|
|
54
54
|
ring/ml/__init__.py,sha256=nbh48gaswWeY4S4vT1sply_3ROj2DQ7agjoLR4Ho3T8,1517
|
55
55
|
ring/ml/base.py,sha256=lfwEZLBDglOSRWChUHoH1kezefhttPV9TMEpNIqsMNw,9972
|
56
56
|
ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
|
57
|
-
ring/ml/ml_utils.py,sha256=
|
57
|
+
ring/ml/ml_utils.py,sha256=xqy9BnLy8IKVqkFS9mlZsGJXSbThI9zZxZ5rhl8LSI8,7144
|
58
58
|
ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
|
59
59
|
ring/ml/ringnet.py,sha256=mef7jyN2QcApJmQGH3HYZyTV-00q8YpsYOKhW0-ku1k,8973
|
60
60
|
ring/ml/rnno_v1.py,sha256=2qE08OIvTJ5PvSxKpYGzGSrvEImWrdAT_qslZ7jP5tA,1372
|
@@ -78,6 +78,7 @@ ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
|
|
78
78
|
ring/utils/batchsize.py,sha256=FbOii7MDP4oPZd9GJOKehFatfnb6WZ0b9z349iZYs1A,1786
|
79
79
|
ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
|
80
80
|
ring/utils/dataloader.py,sha256=2CcsbUY2AZs8LraS5HTJXlEseuF-1gKmfyBkSsib-tE,3748
|
81
|
+
ring/utils/dataloader_torch.py,sha256=DR2uUiA9x49_6EBjnbVLfWu7GBX7wtKjgHSIlF80HO0,1502
|
81
82
|
ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
|
82
83
|
ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
|
83
84
|
ring/utils/path.py,sha256=zRPfxYNesvgefkddd26oar6f9433LkMGkhp9dF3rPUs,1926
|
@@ -85,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
|
|
85
86
|
ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
|
86
87
|
ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
|
87
88
|
ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
|
88
|
-
imt_ring-1.6.
|
89
|
-
imt_ring-1.6.
|
90
|
-
imt_ring-1.6.
|
91
|
-
imt_ring-1.6.
|
89
|
+
imt_ring-1.6.19.dist-info/METADATA,sha256=BZiSbypG96pTKHUSc1NDZYcyiRf1-ynsnu6rAsVdiRU,3833
|
90
|
+
imt_ring-1.6.19.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
91
|
+
imt_ring-1.6.19.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
92
|
+
imt_ring-1.6.19.dist-info/RECORD,,
|
ring/ml/ml_utils.py
CHANGED
@@ -3,6 +3,7 @@ from functools import partial
|
|
3
3
|
import os
|
4
4
|
from pathlib import Path
|
5
5
|
import pickle
|
6
|
+
import shutil
|
6
7
|
import time
|
7
8
|
from typing import Optional, Protocol
|
8
9
|
import warnings
|
@@ -107,7 +108,7 @@ class WandbLogger(MixinLogger):
|
|
107
108
|
wandb.log(data)
|
108
109
|
|
109
110
|
def log_params(self, path: str):
    """Hand the parameter file at `path` to wandb.

    All the work happens in ``self.wandb_save``. That helper copies the file
    into the run directory first when the run is offline.

    Args:
        path: Filesystem path of the parameter file to save.
    """
    self.wandb_save(path)
|
111
112
|
|
112
113
|
def log_video(
|
113
114
|
self,
|
@@ -117,7 +118,7 @@ class WandbLogger(MixinLogger):
|
|
117
118
|
step: Optional[int] = None,
|
118
119
|
):
|
119
120
|
# TODO >>>
|
120
|
-
|
121
|
+
self.wandb_save(path)
|
121
122
|
return
|
122
123
|
# <<<
|
123
124
|
data = {"video": wandb.Video(path, caption=caption, fps=fps)}
|
@@ -127,10 +128,10 @@ class WandbLogger(MixinLogger):
|
|
127
128
|
|
128
129
|
def log_image(self, path: str, caption: Optional[str] = None):
    """Save the image file at `path` with wandb via ``self.wandb_save``.

    Args:
        path: Filesystem path of the image.
        caption: Currently unused. The file is saved as-is rather than being
            logged as a ``wandb.Image`` with this caption.
    """
    self.wandb_save(path)
|
131
132
|
|
132
133
|
def log_txt(self, path: str, wait: bool = True):
    """Save the text file at `path` with wandb via ``self.wandb_save``.

    Args:
        path: Filesystem path of the text file.
        wait: When True, block for 3 seconds after saving. The original author
            left a TODO questioning whether ``wandb`` saves asynchronously.
            The pause presumably gives it time to pick the file up; this is
            not confirmed.
    """
    self.wandb_save(path)
    if not wait:
        return
    time.sleep(3)
|
@@ -138,6 +139,23 @@ class WandbLogger(MixinLogger):
|
|
138
139
|
def close(self):
    """Finish the active wandb run, if there is one.

    When ``wandb.run`` is None (no active run), this is a no-op. The previous
    code raised ``AttributeError`` by calling ``.finish()`` on ``None``.
    """
    if wandb.run is not None:
        wandb.run.finish()
|
140
141
|
|
142
|
+
@staticmethod
def wandb_save(path):
    """Register the file at `path` with wandb.

    Online (or when no run is active), the file is saved directly with
    ``policy="now"``.

    When the active run is offline, the file is first copied (with metadata,
    via ``shutil.copy2``) into ``<run dir>/copied_files/<basename>``, and that
    copy is what gets saved. The intent is that wandb holds a true copy
    rather than a reference to `path`. Note that files with the same basename
    overwrite each other in that folder.

    Args:
        path: Filesystem path of the file to save.
    """
    run = wandb.run
    offline = run is not None and run.settings._offline
    if not offline:
        wandb.save(path, policy="now")
        return

    # Dedicated folder inside the run directory for the copied files.
    target_dir = os.path.join(run.dir, "copied_files")
    os.makedirs(target_dir, exist_ok=True)

    target = os.path.join(target_dir, os.path.basename(path))
    shutil.copy2(path, target)
    wandb.save(target)
|
158
|
+
|
141
159
|
|
142
160
|
def _flatten_convert_filter_nested_dict(
|
143
161
|
metrices: NestedDict, filter_nan_inf: bool = True
|
@@ -0,0 +1,62 @@
|
|
1
|
+
import os
|
2
|
+
|
3
|
+
import jax
|
4
|
+
import torch
|
5
|
+
from torch.utils.data import DataLoader
|
6
|
+
from torch.utils.data import Dataset
|
7
|
+
from tree_utils import PyTree
|
8
|
+
|
9
|
+
from ring.utils import parse_path
|
10
|
+
from ring.utils import pickle_load
|
11
|
+
|
12
|
+
|
13
|
+
class FolderOfPickleFilesDataset(Dataset):
    """Map-style dataset with one element per entry of a directory.

    Element ``idx`` is produced by loading the ``idx``-th file with
    ``pickle_load`` and, if given, passing it through ``transform``. Files are
    unpickled, so only point this at trusted data.

    Args:
        path: Directory whose entries make up the dataset. Every entry returned
            by ``os.listdir`` is used, so the directory should contain only
            loadable files.
        transform: Optional callable applied to each loaded element.
    """

    def __init__(self, path, transform=None):
        self.files = self.listdir(path)
        self.transform = transform
        # Number of elements; kept as a public attribute for existing callers.
        self.N = len(self.files)

    def __len__(self):
        return self.N

    def __getitem__(self, idx: int):
        element = pickle_load(self.files[idx])
        if self.transform is not None:
            element = self.transform(element)
        return element

    @staticmethod
    def listdir(path: str) -> list:
        """Return the entries of `path`, each combined with `path` via ``parse_path``.

        ``parse_path`` is presumed to build the full file path; confirm
        against its definition.

        Entries are sorted by name. ``os.listdir`` returns them in arbitrary,
        filesystem-dependent order, which made the index -> file mapping (and
        hence seeded shuffling) irreproducible across machines.
        """
        return [parse_path(path, file) for file in sorted(os.listdir(path))]
|
31
|
+
|
32
|
+
|
33
|
+
def dataset_to_generator(
    dataset: Dataset,
    batch_size: int,
    shuffle: bool = True,
    seed: int = 1,
    **kwargs,
):
    """Wrap `dataset` in an endless batch function that returns numpy pytrees.

    A ``torch.utils.data.DataLoader`` is built over `dataset`. Each call of the
    returned function yields the loader's next batch, with every tensor leaf
    converted via ``Tensor.numpy()``. When a pass over the data is exhausted,
    a fresh iterator is started (reshuffled if ``shuffle=True``).

    Args:
        dataset: Map-style torch dataset.
        batch_size: Number of elements per batch.
        shuffle: Whether the DataLoader shuffles each epoch.
        seed: Passed to ``torch.manual_seed`` before the loader is created.
            Note that this seeds torch's global RNG.
        **kwargs: Forwarded to ``DataLoader``. If no
            ``multiprocessing_context`` is given, ``"spawn"`` is used when
            ``num_workers > 0`` and ``None`` otherwise.

    Returns:
        A function taking one (ignored) argument and returning the next batch.

    Raises:
        StopIteration: From the returned function if the loader yields no
            batches at all (e.g. an empty dataset).
    """
    torch.manual_seed(seed)

    # Respect a caller-supplied context. Passing it both explicitly and via
    # **kwargs used to raise "got multiple values for keyword argument".
    # "spawn" is presumably chosen to avoid fork() in a process running JAX.
    if "multiprocessing_context" not in kwargs:
        kwargs["multiprocessing_context"] = (
            "spawn" if kwargs.get("num_workers", 0) > 0 else None
        )

    dl = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)
    dl_iter = iter(dl)

    # String annotation: avoids evaluating `PyTree[...]` on every call.
    def to_numpy(tree: "PyTree[torch.Tensor]"):
        # `jax.tree_map` is deprecated (removed in recent JAX releases).
        return jax.tree_util.tree_map(lambda tensor: tensor.numpy(), tree)

    def generator(_):
        nonlocal dl_iter
        try:
            batch = next(dl_iter)
        except StopIteration:
            # Epoch finished: start a new pass over the data.
            dl_iter = iter(dl)
            batch = next(dl_iter)
        return to_numpy(batch)

    return generator
|
File without changes
|
File without changes
|