egogym 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- baselines/pi_policy.py +110 -0
- baselines/rum/__init__.py +1 -0
- baselines/rum/loss_fns/__init__.py +37 -0
- baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
- baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
- baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
- baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
- baselines/rum/models/__init__.py +1 -0
- baselines/rum/models/bet/__init__.py +3 -0
- baselines/rum/models/bet/bet.py +347 -0
- baselines/rum/models/bet/gpt.py +277 -0
- baselines/rum/models/bet/tokenized_bet.py +454 -0
- baselines/rum/models/bet/utils.py +124 -0
- baselines/rum/models/bet/vqbet.py +410 -0
- baselines/rum/models/bet/vqvae/__init__.py +3 -0
- baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
- baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
- baselines/rum/models/bet/vqvae/vqvae.py +313 -0
- baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
- baselines/rum/models/custom.py +33 -0
- baselines/rum/models/encoders/__init__.py +0 -0
- baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
- baselines/rum/models/encoders/identity.py +45 -0
- baselines/rum/models/encoders/timm_encoders.py +82 -0
- baselines/rum/models/policies/diffusion_policy.py +881 -0
- baselines/rum/models/policies/open_loop.py +122 -0
- baselines/rum/models/policies/simple_open_loop.py +108 -0
- baselines/rum/molmo/server.py +144 -0
- baselines/rum/policy.py +293 -0
- baselines/rum/utils/__init__.py +212 -0
- baselines/rum/utils/action_transforms.py +22 -0
- baselines/rum/utils/decord_transforms.py +135 -0
- baselines/rum/utils/rpc.py +249 -0
- baselines/rum/utils/schedulers.py +71 -0
- baselines/rum/utils/trajectory_vis.py +128 -0
- baselines/rum/utils/zmq_utils.py +281 -0
- baselines/rum_policy.py +108 -0
- egogym/__init__.py +8 -0
- egogym/assets/constants.py +1804 -0
- egogym/components/__init__.py +1 -0
- egogym/components/object.py +94 -0
- egogym/egogym.py +106 -0
- egogym/embodiments/__init__.py +10 -0
- egogym/embodiments/arms/__init__.py +4 -0
- egogym/embodiments/arms/arm.py +65 -0
- egogym/embodiments/arms/droid.py +49 -0
- egogym/embodiments/grippers/__init__.py +4 -0
- egogym/embodiments/grippers/floating_gripper.py +58 -0
- egogym/embodiments/grippers/rum.py +6 -0
- egogym/embodiments/robot.py +95 -0
- egogym/evaluate.py +216 -0
- egogym/managers/__init__.py +2 -0
- egogym/managers/objects_managers.py +30 -0
- egogym/managers/textures_manager.py +21 -0
- egogym/misc/molmo_client.py +49 -0
- egogym/misc/molmo_server.py +197 -0
- egogym/policies/__init__.py +1 -0
- egogym/policies/base_policy.py +13 -0
- egogym/scripts/analayze.py +834 -0
- egogym/scripts/plot.py +87 -0
- egogym/scripts/plot_correlation.py +392 -0
- egogym/scripts/plot_correlation_hardcoded.py +338 -0
- egogym/scripts/plot_failure.py +248 -0
- egogym/scripts/plot_failure_hardcoded.py +195 -0
- egogym/scripts/plot_failure_vlm.py +257 -0
- egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
- egogym/scripts/plot_line.py +303 -0
- egogym/scripts/plot_line_hardcoded.py +285 -0
- egogym/scripts/plot_pi0_bars.py +169 -0
- egogym/tasks/close.py +84 -0
- egogym/tasks/open.py +85 -0
- egogym/tasks/pick.py +121 -0
- egogym/utils.py +969 -0
- egogym/wrappers/__init__.py +20 -0
- egogym/wrappers/episode_monitor.py +282 -0
- egogym/wrappers/unprivileged_chatgpt.py +163 -0
- egogym/wrappers/unprivileged_gemini.py +157 -0
- egogym/wrappers/unprivileged_molmo.py +88 -0
- egogym/wrappers/unprivileged_moondream.py +121 -0
- egogym-0.1.0.dist-info/METADATA +52 -0
- egogym-0.1.0.dist-info/RECORD +83 -0
- egogym-0.1.0.dist-info/WHEEL +5 -0
- egogym-0.1.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
import random
|
|
2
|
+
from collections import OrderedDict
|
|
3
|
+
from datetime import timedelta
|
|
4
|
+
|
|
5
|
+
import accelerate
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
import json
|
|
9
|
+
|
|
10
|
+
from . import *
|
|
11
|
+
|
|
12
|
+
# import accelerate
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def set_seed_everywhere(seed):
    """Seed every RNG in use (Python, NumPy, PyTorch CPU/CUDA) for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def get_demos(cfg):
    """Load the train/val demonstration lists referenced by the config.

    Args:
        cfg: config object exposing
            ``dataset.{train,test}.config.trajectory_roots.trajectory_root_path``,
            each a path to a JSON file.

    Returns:
        dict with keys ``"train"`` and ``"val"`` holding the decoded JSON
        content of the respective files.
    """
    train_files_list = cfg.dataset.train.config.trajectory_roots.trajectory_root_path
    val_files_list = cfg.dataset.test.config.trajectory_roots.trajectory_root_path

    # Use context managers so the handles are closed deterministically
    # (the previous bare open() calls leaked them until garbage collection).
    with open(train_files_list) as f:
        train_demos = json.load(f)
    with open(val_files_list) as f:
        val_demos = json.load(f)

    return {"train": train_demos, "val": val_demos}
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def setup_accelerate(experiment, num_epochs, task_title, save_every):
    """Create a wandb-tracked Accelerator configured for DDP training.

    Args:
        experiment: wandb project name to log under.
        num_epochs: total number of training epochs (recorded in the tracker).
        task_title: human-readable task name (recorded in the tracker).
        save_every: checkpoint interval in epochs (recorded in the tracker).

    Returns:
        (accelerator, ddp_kwargs) tuple.
    """
    # find_unused_parameters=True: some submodules may not contribute to the
    # loss every step, which would otherwise make DDP raise at backward time.
    ddp_kwargs = accelerate.DistributedDataParallelKwargs(find_unused_parameters=True)

    accelerator = accelerate.Accelerator(
        log_with=["wandb"],
        kwargs_handlers=[
            # Generous timeout: slow ranks (dataset loading, checkpoints) can
            # otherwise trip the default process-group timeout.
            accelerate.InitProcessGroupKwargs(timeout=timedelta(hours=1.5)),
            ddp_kwargs,
        ],
    )
    # Fix: the project name and config values were previously hard-coded
    # ("experiment", num_epochs=1000, save_every=20), silently ignoring the
    # function's arguments. Pass the actual parameters through.
    accelerator.init_trackers(
        experiment,
        config={
            "num_epochs": num_epochs,
            "task_title": task_title,
            "save_every": save_every,
        },
    )

    return accelerator, ddp_kwargs
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class AverageMeter:
    """Running average of a scalar metric.

    Args:
        value: optional first observation; when given, the meter starts with
            one sample already recorded.
    """

    def __init__(self, value=None):
        super().__init__()
        if value is None:
            self._sum = 0
            self._count = 0
        else:
            self._sum = value
            self._count = 1
        # Fix: keep `avg` well-defined even before the first update() call
        # (previously reading .avg raised AttributeError until update() ran).
        self._avg = value if value is not None else 0

    def update(self, value, n=1):
        """Record `value` observed `n` times and refresh the running average."""
        self._sum += value * n
        self._count += n
        self._avg = self._sum / self._count

    @property
    def avg(self):
        # Current running average; 0 if nothing has been recorded yet.
        return self._avg
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class AveragingDict(OrderedDict):
    """Ordered mapping from metric name to a running AverageMeter.

    `name` (e.g. "train" or "val") prefixes every key in the string and
    summary representations.
    """

    def __init__(self, name="train", *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.name = name

    def update(self, other):
        """Fold a dict of scalar metric values into the running averages.

        Fix: the first observation of each key used to be counted twice
        (once by `AverageMeter(v)` and once by the unconditional `update(v)`),
        biasing later averages toward the first value. Create an empty meter
        and record the value exactly once instead.
        """
        for k, v in other.items():
            if k not in self:
                self[k] = AverageMeter()
            self[k].update(v)

    def __str__(self):
        return ", ".join(f"{self.name}/{k}: {v.avg:.2E}" for k, v in self.items())

    @property
    def full_summary(self):
        # Prefixed keys -> raw float averages (suitable for tracker logging).
        return {f"{self.name}/{k}": v.avg for k, v in self.items()}

    @property
    def summary(self):
        # Prefixed keys -> scientific-notation strings (for progress display).
        return {f"{self.name}/{k}": f"{v.avg:.2E}" for k, v in self.items()}
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class Callbacks:
    """
    This class is used to implement callback functions inside training framework that
    passes, training params to callback functions inside model and loss objects.

    This is useful for custom logging inside models and losses like adding image
    visualisation, plotting hyperparameters like learning rate, etc.

    Models/losses opt in by defining `_begin_epoch` and/or `_begin_batch`
    methods; any non-None return value is forwarded to `accelerator.log`.
    """

    def __init__(self):
        pass

    def set_workspace(self, workspace, accelerator):
        """Bind the training workspace/accelerator and cache which hooks exist."""
        self.workspace = workspace
        self.accelerator = accelerator
        # hasattr checks are done once on the unwrapped (non-DDP) modules.
        model = accelerator.unwrap_model(workspace.model)
        loss_fn = accelerator.unwrap_model(workspace.loss_fn)
        self._model_has_begin_epoch = hasattr(model, "_begin_epoch")
        self._model_has_begin_batch = hasattr(model, "_begin_batch")
        self._loss_fn_has_begin_epoch = hasattr(loss_fn, "_begin_epoch")
        self._loss_fn_has_begin_batch = hasattr(loss_fn, "_begin_batch")

    def _hook_kwargs(self):
        # Shared training state handed to every callback hook.
        return dict(
            epoch=self.workspace._epoch,
            epochs=self.workspace._cfg.num_epochs,
            train_dataloader=self.workspace._train_dataloaders,
            test_dataloader=self.workspace._test_dataloaders,
            optimizer=self.workspace.optimizer,
            scheduler=self.workspace.scheduler,
            is_main_process=self.accelerator.is_local_main_process,
        )

    def _run_hook(self, target, hook_name, **extra):
        # Invoke `hook_name` on the unwrapped target and log a non-None result.
        # This replaces four near-identical copies of the same call sequence.
        hook = getattr(self.accelerator.unwrap_model(target), hook_name)
        log = hook(**self._hook_kwargs(), **extra)
        if log is not None:
            self.accelerator.log(log)

    def begin_epoch(self):
        """Reset the per-epoch batch counter and fire `_begin_epoch` hooks."""
        self.batch_step = 0

        if self._model_has_begin_epoch:
            self._run_hook(self.workspace.model, "_begin_epoch")
        if self._loss_fn_has_begin_epoch:
            self._run_hook(self.workspace.loss_fn, "_begin_epoch")

    def begin_batch(self):
        """Fire `_begin_batch` hooks, then advance the batch counter."""
        if self._model_has_begin_batch:
            self._run_hook(
                self.workspace.model, "_begin_batch", batch_step=self.batch_step
            )
        if self._loss_fn_has_begin_batch:
            self._run_hook(
                self.workspace.loss_fn, "_begin_batch", batch_step=self.batch_step
            )

        self.batch_step += 1
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def log_wandb_image(image, label):
    """
    image: torch.Tensor or np.ndarray; leading singleton dims are squeezed
        and a channels-first (3, H, W) layout is transposed to (H, W, 3).
    label: str, label for the image to categorize it in wandb

    returns: dict, {label: wandb.Image(image)}
    """
    import wandb

    # isinstance instead of `type(...) ==` so Tensor subclasses are handled;
    # detach() so tensors that still require grad can be converted to numpy.
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()
    while image.shape[0] == 1:
        image = image.squeeze(0)
    if image.shape[0] == 3:
        # CHW -> HWC, the layout wandb.Image expects for numpy input.
        image = image.transpose(1, 2, 0)
    image = wandb.Image(image)
    return {label: image}
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
# Axis-permutation matrix (homogeneous 4x4) mapping record3d coordinates to
# the personal camera frame.
P = np.array([[-1, 0, 0, 0], [0, 0, -1, 0], [0, -1, 0, 0], [0, 0, 0, 1]])
# End-effector transform from the personal camera frame to the end-effector
# frame. NOTE(review): this is currently the identity matrix, so the
# "15 degree offset along the x axis" mentioned in the original comment is
# not actually applied — confirm whether that is intentional.
EFT = np.asarray([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])


def apply_permutation_transform(matrix):
    """Conjugate `matrix` by the axis permutation P (record3d -> camera)."""
    permuted = P @ matrix
    return permuted @ P.T


def invert_permutation_transform(matrix):
    """Undo apply_permutation_transform (camera -> record3d)."""
    unpermuted = P.T @ matrix
    return unpermuted @ P


def apply_end_effector_transform(matrix):
    """Left-multiply by EFT (camera frame -> end-effector frame)."""
    return EFT @ matrix


def invert_end_effector_transform(matrix):
    """Left-multiply by EFT.T, undoing apply_end_effector_transform."""
    return EFT.T @ matrix
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Iterable, Optional, Tuple, Union
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from timm.data.random_erasing import RandomErasing
|
|
6
|
+
from torchvision import transforms
|
|
7
|
+
from torchvision.transforms import InterpolationMode
|
|
8
|
+
|
|
9
|
+
# Eval-time resize/crop defaults (see the disabled crop branch in
# create_transform below).
DEFAULT_CROP_PCT = 0.875
DEFAULT_CROP_MODE = "center"
# Per-channel normalization statistics for common pretraining recipes.
# Values appear to be vendored from timm's constants — TODO confirm they
# stay in sync with the installed timm version.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = tuple([1 / (0.0167 * 255)] * 3)
OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def create_transform(
    input_size,
    is_training: bool = False,
    scale: Optional[Tuple[float, float]] = None,
    ratio: Optional[Tuple[float, float]] = None,
    hflip: float = 0.5,
    vflip: float = 0.0,
    grayscale: float = 0.0,
    gaussblr: float = 0.0,
    gaussblr_kernel: int = 3,
    gaussblr_sigma: tuple = (1.0, 2.0),
    color_jitter: Optional[Union[Iterable[float], float]] = 0.4,
    interpolation: Union[str, InterpolationMode] = "bilinear",
    mean: Union[Iterable, torch.Tensor] = IMAGENET_DEFAULT_MEAN,
    std: Union[Iterable, torch.Tensor] = IMAGENET_DEFAULT_STD,
    re_prob=0.0,
    re_mode="const",
    re_count=1,
    re_num_splits=0,
    crop_pct=None,
    crop_mode=None,
    *args,
    **kwargs,
):
    """Build a torchvision transform pipeline for tensor images.

    Training: RandomResizedCrop + optional flips / color jitter / grayscale /
    gaussian blur, then Normalize, then optional (timm) RandomErasing.
    Eval: Normalize only — resize/crop is currently disabled (see TODO).

    Returns:
        torchvision.transforms.Compose
    """
    if isinstance(input_size, (tuple, list)):
        img_size = input_size[-2:]
    else:
        img_size = input_size
    transform_list = []

    # Resolve the interpolation argument. Accept an InterpolationMode enum
    # directly — previously an enum input silently fell back to bilinear
    # (AttributeError on .upper()) — and catch KeyError so an unknown
    # *string* warns instead of raising (KeyError was previously uncaught).
    if isinstance(interpolation, InterpolationMode):
        interpolation_mode = interpolation
    else:
        interpolation_mode = InterpolationMode.BILINEAR
        try:
            interpolation_mode = InterpolationMode[interpolation.upper()]
        except (AttributeError, KeyError):
            logging.warning(
                f"Interpolation mode {interpolation} is not recognized, "
                f"using bilinear instead."
            )
    if is_training:
        scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
        ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0))  # default imagenet ratio range

        transform_list.append(
            transforms.RandomResizedCrop(
                img_size,
                scale,
                ratio,
                antialias=False,
                interpolation=interpolation_mode,
            )
        )
        if hflip > 0.0:
            transform_list.append(transforms.RandomHorizontalFlip(hflip))
        if vflip > 0.0:
            transform_list.append(transforms.RandomVerticalFlip(vflip))

        if color_jitter is not None:
            # A scalar means the same jitter strength for brightness,
            # contrast, and saturation; a 3/4-tuple is passed through as-is.
            if isinstance(color_jitter, (list, tuple)):
                assert len(color_jitter) in (
                    3,
                    4,
                ), "expected either 3 or 4 values for color jitter"
            else:
                color_jitter = (float(color_jitter),) * 3
            transform_list.append(transforms.ColorJitter(*color_jitter))

        if grayscale > 0.0:
            transform_list.append(transforms.RandomGrayscale(p=grayscale))
        if gaussblr > 0.0:
            transform_list.append(
                transforms.RandomApply(
                    [
                        transforms.GaussianBlur(
                            (gaussblr_kernel, gaussblr_kernel), gaussblr_sigma
                        )
                    ],
                    p=gaussblr,
                )
            )
        transform_list.append(transforms.Normalize(mean, std))

        if re_prob > 0.0:
            # timm's RandomErasing operates on tensors; device="cuda" means
            # it runs after the batch is moved to the GPU.
            transform_list.append(
                RandomErasing(
                    re_prob,
                    mode=re_mode,
                    max_count=re_count,
                    num_splits=re_num_splits,
                    device="cuda",
                )
            )

    else:
        if crop_pct is None:
            crop_pct = DEFAULT_CROP_PCT
        if crop_mode is None:
            crop_mode = DEFAULT_CROP_MODE
        # Currently unused; kept for when resize/crop is re-enabled below.
        rescaled_size = (
            int(img_size / crop_pct)
            if isinstance(img_size, (int, float))
            else ([int(size / crop_pct) for size in img_size])
        )
        # TODO (haritheja): we are no longer cropping, but figure this out
        # transform_list.append(
        #     transforms.Resize(rescaled_size, interpolation_mode, antialias=False)
        # )
        # if crop_mode == "center":
        #     transform_list.append(transforms.CenterCrop(img_size))
        # elif crop_mode == "random":
        #     transform_list.append(transforms.RandomCrop(img_size))
        # else:
        #     raise ValueError(f"crop_mode '{crop_mode}' not recognized")
        transform_list.append(transforms.Normalize(mean, std))

    return transforms.Compose(transform_list)
|
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
import zmq
|
|
2
|
+
import time
|
|
3
|
+
import pickle
|
|
4
|
+
import threading
|
|
5
|
+
import traceback
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class RPCServer:
    """Exposes an object's attributes and methods over a ZMQ REP socket."""

    def __init__(self, obj, port: int = 5000, threaded=False):
        """
        obj: object with methods to expose
        port: port to listen on
        threaded: if True, run() executes in a background thread
        """
        self.obj = obj
        self.context = zmq.Context()
        self.socket: zmq.socket.Socket = self.context.socket(zmq.REP)
        self.socket.bind(f"tcp://*:{port}")
        self.threaded = threaded
        if threaded:
            self.thread = threading.Thread(target=self.run)
            self.stop_event = threading.Event()
        else:
            # Plain bool flag in the single-threaded case.
            self.stop_event = False

    def _send_exception(self, e):
        """
        Serialize an exception and send it over the socket.
        Only the exception type, message, and traceback are sent.
        """
        exception = {
            "type": "exception",
            "content": {
                "exception": str(type(e)),
                "message": str(e),
                "traceback": traceback.format_exc(),
            },
        }
        self.socket.send(pickle.dumps(exception))

    def _send_result(self, result):
        """
        Serialize a result and send it over the socket.
        """
        result = {"type": "result", "content": result}
        self.socket.send(pickle.dumps(result))

    def run(self):
        """
        Serve requests until stopped.

        Threaded mode uses a blocking recv (the stop flag is only re-checked
        after a message arrives); single-threaded mode polls with NOBLOCK so
        stop() can take effect between messages.
        """
        if self.threaded:
            while not self.stop_event.is_set():
                message = self.socket.recv()
                message = pickle.loads(message)
                self._handle_message(message)
        else:
            while not self.stop_event:
                try:
                    message = self.socket.recv(flags=zmq.NOBLOCK)
                    message = pickle.loads(message)
                except zmq.Again:
                    time.sleep(0.001)
                    continue
                self._handle_message(message)

    def _is_callable(self, attr):
        return hasattr(self.obj, attr) and callable(getattr(self.obj, attr))

    def _handle_message(self, message):
        """
        Handles a dictionary of {
            "req": str,  # request type
            "attr": str,
            "args": list,
            "kwargs": dict,
        }
        from the socket.
        If req == "is_callable", return whether the attribute is callable.
        If req == "get", return the attribute; if it is callable, call it
        with args/kwargs and return the (pickled) result or exception.
        If req == "set", set the attribute to the value.
        If req == "dir", return a list of attributes.
        If req == "stop", acknowledge and stop the server.
        """
        if message["req"] == "is_callable":
            result = self._is_callable(message["attr"])
            self._send_result(result)
        elif message["req"] == "get":
            try:
                attribute = getattr(self.obj, message["attr"])
                args = message["args"]
                kwargs = message["kwargs"]
                if not callable(attribute):
                    self._send_result(attribute)
                else:
                    result = attribute(*args, **kwargs)
                    self._send_result(result)
            except Exception as e:
                self._send_exception(e)
        elif message["req"] == "set":
            try:
                setattr(self.obj, message["attr"], message["value"])
                self._send_result(None)
            except Exception as e:
                self._send_exception(e)
        elif message["req"] == "dir":
            result = dir(self.obj)
            self._send_result(result)
        elif message["req"] == "stop":
            # Fix: acknowledge BEFORE shutting down. The client's REQ socket
            # blocks in recv() waiting for this reply; previously nothing was
            # sent, so RPCClient.stop_server() hung forever.
            self._send_result(True)
            self.stop()

    def close(self):
        """Release the socket and terminate the ZMQ context."""
        self.socket.close()
        self.context.term()

    def start(self):
        """Start serving — in a background thread if threaded, else inline."""
        if self.threaded:
            self.stop_event.clear()
            self.thread.start()
        else:
            self.run()

    def stop(self):
        """Signal the serve loop to exit and release resources."""
        if self.threaded:
            self.stop_event.set()
            # Fix: stop() may be invoked from inside the server thread itself
            # (via a remote "stop" request); a thread cannot join itself.
            if threading.current_thread() is not self.thread:
                self.thread.join()
        else:
            self.stop_event = True
        self.close()
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class RPCException(Exception):
    """Client-side re-raise of an exception that occurred on the RPC server.

    Carries the remote exception's type name, message, and formatted
    traceback as plain strings (the original object is never transferred).
    """

    def __init__(self, exception_type: str, message: str, traceback: str):
        self.exception_type = exception_type
        self.message = message
        self.traceback = traceback

    def __str__(self):
        return "{}: {}\n{}".format(self.exception_type, self.message, self.traceback)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class RPCClient:
    """Proxy that forwards attribute reads/writes and method calls to a
    remote RPCServer over a ZMQ REQ socket.

    NOTE: because __getattr__/__setattr__ are overridden, nearly every
    attribute access on an instance results in a network round trip.
    """

    def __init__(self, host: str, port: int = 5000):
        """
        host: host to connect to
        port: port to connect to
        """
        # Assign through __dict__ directly: a normal `self.x = ...` would be
        # routed through our __setattr__ and forwarded to the remote server.
        self.__dict__["context"] = zmq.Context()
        self.__dict__["socket"] = self.context.socket(zmq.REQ)
        self.socket.connect(f"tcp://{host}:{port}")
        # attr-name -> bool cache so repeated lookups of the same method
        # don't each need an "is_callable" round trip.
        self.__dict__["_is_callable_cache"] = {}

    def __setattr__(self, attr: str, value):
        """
        Set the attribute of the same name.
        Attribute must not be callable on remote.
        """
        if self._is_callable(attr):
            raise AttributeError(f"Overwriting a callable attribute: {attr}")
        self._send_set(attr, value)

    def _send_get(self, attr: str, args: list, kwargs: dict):
        """
        Send a get request over the socket.
        """
        req = {"req": "get", "attr": attr, "args": args, "kwargs": kwargs}
        self.socket.send(pickle.dumps(req))
        return self._recv_result()

    def _send_set(self, attr: str, value):
        """
        Send a set request over the socket.
        """
        req = {"req": "set", "attr": attr, "value": value}
        self.socket.send(pickle.dumps(req))
        return self._recv_result()

    def _recv_result(self):
        """
        Receive a dictionary of {
            "type": str,
            "content": object,
        }
        if type == "exception", content is a dictionary of {
            "exception": str,
            "message": str,
            "traceback": str,
        }; re-raise the exception on the client side
        if type == "result", content is the result
        """
        # Blocking recv: REQ/REP requires strictly alternating send/recv, so
        # this must be called exactly once after every send.
        result = self.socket.recv()
        result = pickle.loads(result)
        if result["type"] == "exception":
            raise RPCException(
                result["content"]["exception"],
                result["content"]["message"],
                result["content"]["traceback"],
            )
        return result["content"]

    def _is_callable(self, attr: str) -> bool:
        """
        Send a request to check if the attribute is callable.
        Returns False if the attribute is not found.
        """
        if attr not in self._is_callable_cache:
            req = {"req": "is_callable", "attr": attr}
            self.socket.send(pickle.dumps(req))
            result = self._recv_result()
            self._is_callable_cache[attr] = result
        return self._is_callable_cache[attr]

    def __getattr__(self, attr: str):
        """
        Return the attribute of the same name.
        If the attribute is a callable, return a function that sends the call over the socket.
        Else, return the attribute value.
        """
        # __getattr__ only fires for names missing from __dict__, so the
        # locally-stored context/socket/cache are never forwarded remotely.
        if self._is_callable(attr):
            # Defer the round trip until the returned function is invoked.
            return lambda *args, **kwargs: self._send_get(attr, args, kwargs)
        else:
            return self._send_get(attr, [], {})

    def __dir__(self):
        """
        Return a list of attributes.
        """
        req = {"req": "dir"}
        self.socket.send(pickle.dumps(req))
        result = self._recv_result()
        # stop_server is defined locally but masked by the remote dir().
        return result + ["stop_server"]

    def stop_server(self) -> bool:
        """
        Send a stop request to the server.
        If the server is stopped, close the socket and terminate the context.
        Returns a bool for success.
        """
        req = {"req": "stop"}
        self.socket.send(pickle.dumps(req))
        stopped = self._recv_result()
        if stopped:
            self.socket.close()
            self.context.term()
        return stopped
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import warnings
|
|
3
|
+
|
|
4
|
+
from torch.optim.lr_scheduler import LRScheduler
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class CosineAnnealWithWarmupLR(LRScheduler):
    """Cosine-annealing LR schedule with a linear warmup phase.

    The multiplier ramps linearly from 1/warmup_epochs up to 1 over the first
    `warmup_epochs` epochs, follows a cosine decay down to `min_lr_multiplier`
    at epoch `lr_decay_epochs`, and stays at the minimum afterwards.

    Adapted from karpathy/nanoGPT
    https://github.com/karpathy/nanoGPT/blob/eba36e84649f3c6d840a93092cb779a260544d08/train.py#L228

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        warmup_epochs (int): The number of epochs for warmup.
        lr_decay_epochs (int): Epoch index at which min LR is reached.
        last_epoch (int): Index of the last epoch. Default: -1.
        min_lr_multiplier (float): The lowest lr multiplier the schedule
            will ever get to.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> scheduler = CosineAnnealWithWarmupLR(optimizer, 10, 600)
        >>> for epoch in range(600):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self,
        optimizer,
        warmup_epochs,
        lr_decay_epochs,
        last_epoch=-1,
        min_lr_multiplier=0.1,
        verbose=False,
    ):
        # These must be set BEFORE super().__init__, which immediately calls
        # step() -> get_lr() -> _calculate_lr_formula().
        self.warmup_epochs = warmup_epochs
        self.lr_decay_epoch = lr_decay_epochs
        self.min_lr_multiplier = min_lr_multiplier

        # Note: self.optimizer and self.last_epoch are assigned by the base
        # class; the previous redundant assignments here were removed.
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`."
            )

        multiplier = self._calculate_lr_formula(self.last_epoch)
        return [base_lr * multiplier for base_lr in self.base_lrs]

    def _calculate_lr_formula(self, last_epoch):
        """Return the LR multiplier for the given epoch index."""
        if last_epoch < self.warmup_epochs:
            # Linear warmup; +1 so the very first step is non-zero.
            return (last_epoch + 1) / self.warmup_epochs
        elif last_epoch > self.lr_decay_epoch:
            return self.min_lr_multiplier
        decay_ratio = (last_epoch - self.warmup_epochs) / (
            self.lr_decay_epoch - self.warmup_epochs
        )
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
        return self.min_lr_multiplier + coeff * (1 - self.min_lr_multiplier)
|