egogym 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. baselines/pi_policy.py +110 -0
  2. baselines/rum/__init__.py +1 -0
  3. baselines/rum/loss_fns/__init__.py +37 -0
  4. baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
  5. baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
  6. baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
  7. baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
  8. baselines/rum/models/__init__.py +1 -0
  9. baselines/rum/models/bet/__init__.py +3 -0
  10. baselines/rum/models/bet/bet.py +347 -0
  11. baselines/rum/models/bet/gpt.py +277 -0
  12. baselines/rum/models/bet/tokenized_bet.py +454 -0
  13. baselines/rum/models/bet/utils.py +124 -0
  14. baselines/rum/models/bet/vqbet.py +410 -0
  15. baselines/rum/models/bet/vqvae/__init__.py +3 -0
  16. baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
  17. baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
  18. baselines/rum/models/bet/vqvae/vqvae.py +313 -0
  19. baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
  20. baselines/rum/models/custom.py +33 -0
  21. baselines/rum/models/encoders/__init__.py +0 -0
  22. baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
  23. baselines/rum/models/encoders/identity.py +45 -0
  24. baselines/rum/models/encoders/timm_encoders.py +82 -0
  25. baselines/rum/models/policies/diffusion_policy.py +881 -0
  26. baselines/rum/models/policies/open_loop.py +122 -0
  27. baselines/rum/models/policies/simple_open_loop.py +108 -0
  28. baselines/rum/molmo/server.py +144 -0
  29. baselines/rum/policy.py +293 -0
  30. baselines/rum/utils/__init__.py +212 -0
  31. baselines/rum/utils/action_transforms.py +22 -0
  32. baselines/rum/utils/decord_transforms.py +135 -0
  33. baselines/rum/utils/rpc.py +249 -0
  34. baselines/rum/utils/schedulers.py +71 -0
  35. baselines/rum/utils/trajectory_vis.py +128 -0
  36. baselines/rum/utils/zmq_utils.py +281 -0
  37. baselines/rum_policy.py +108 -0
  38. egogym/__init__.py +8 -0
  39. egogym/assets/constants.py +1804 -0
  40. egogym/components/__init__.py +1 -0
  41. egogym/components/object.py +94 -0
  42. egogym/egogym.py +106 -0
  43. egogym/embodiments/__init__.py +10 -0
  44. egogym/embodiments/arms/__init__.py +4 -0
  45. egogym/embodiments/arms/arm.py +65 -0
  46. egogym/embodiments/arms/droid.py +49 -0
  47. egogym/embodiments/grippers/__init__.py +4 -0
  48. egogym/embodiments/grippers/floating_gripper.py +58 -0
  49. egogym/embodiments/grippers/rum.py +6 -0
  50. egogym/embodiments/robot.py +95 -0
  51. egogym/evaluate.py +216 -0
  52. egogym/managers/__init__.py +2 -0
  53. egogym/managers/objects_managers.py +30 -0
  54. egogym/managers/textures_manager.py +21 -0
  55. egogym/misc/molmo_client.py +49 -0
  56. egogym/misc/molmo_server.py +197 -0
  57. egogym/policies/__init__.py +1 -0
  58. egogym/policies/base_policy.py +13 -0
  59. egogym/scripts/analayze.py +834 -0
  60. egogym/scripts/plot.py +87 -0
  61. egogym/scripts/plot_correlation.py +392 -0
  62. egogym/scripts/plot_correlation_hardcoded.py +338 -0
  63. egogym/scripts/plot_failure.py +248 -0
  64. egogym/scripts/plot_failure_hardcoded.py +195 -0
  65. egogym/scripts/plot_failure_vlm.py +257 -0
  66. egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
  67. egogym/scripts/plot_line.py +303 -0
  68. egogym/scripts/plot_line_hardcoded.py +285 -0
  69. egogym/scripts/plot_pi0_bars.py +169 -0
  70. egogym/tasks/close.py +84 -0
  71. egogym/tasks/open.py +85 -0
  72. egogym/tasks/pick.py +121 -0
  73. egogym/utils.py +969 -0
  74. egogym/wrappers/__init__.py +20 -0
  75. egogym/wrappers/episode_monitor.py +282 -0
  76. egogym/wrappers/unprivileged_chatgpt.py +163 -0
  77. egogym/wrappers/unprivileged_gemini.py +157 -0
  78. egogym/wrappers/unprivileged_molmo.py +88 -0
  79. egogym/wrappers/unprivileged_moondream.py +121 -0
  80. egogym-0.1.0.dist-info/METADATA +52 -0
  81. egogym-0.1.0.dist-info/RECORD +83 -0
  82. egogym-0.1.0.dist-info/WHEEL +5 -0
  83. egogym-0.1.0.dist-info/top_level.txt +2 -0
baselines/rum/utils/__init__.py
@@ -0,0 +1,212 @@
+ import random
+ from collections import OrderedDict
+ from datetime import timedelta
+
+ import accelerate
+ import numpy as np
+ import torch
+ import json
+
+ from . import *
+
+ # import accelerate
+
+
+ def set_seed_everywhere(seed):
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+
+
+ def get_demos(cfg):
+     train_files_list = cfg.dataset.train.config.trajectory_roots.trajectory_root_path
+     val_files_list = cfg.dataset.test.config.trajectory_roots.trajectory_root_path
+
+     train_demos = json.load(open(train_files_list))
+     val_demos = json.load(open(val_files_list))
+
+     return {"train": train_demos, "val": val_demos}
+
+
+ def setup_accelerate(experiment, num_epochs, task_title, save_every):
+     ddp_kwargs = accelerate.DistributedDataParallelKwargs(find_unused_parameters=True)
+
+     accelerator = accelerate.Accelerator(
+         log_with=["wandb"],
+         kwargs_handlers=[
+             accelerate.InitProcessGroupKwargs(timeout=timedelta(hours=1.5)),
+             ddp_kwargs,
+         ],
+     )
+     accelerator.init_trackers(
+         "experiment",
+         config={
+             "num_epochs": 1000,
+             "task_title": task_title,
+             "save_every": 20,
+         },
+     )
+
+     return accelerator, ddp_kwargs
+
+
+ class AverageMeter:
+     def __init__(self, value=None):
+         super().__init__()
+         if value is None:
+             self._sum = 0
+             self._count = 0
+         else:
+             self._sum = value
+             self._count = 1
+
+     def update(self, value, n=1):
+         self._sum += value * n
+         self._count += n
+         self._avg = self._sum / self._count
+
+     @property
+     def avg(self):
+         return self._avg
+
+
+ class AveragingDict(OrderedDict):
+     def __init__(self, name="train", *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.name = name
+
+     def update(self, other):
+         for k, v in other.items():
+             if k not in self:
+                 self[k] = AverageMeter(v)
+             self[k].update(v)
+
+     def __str__(self):
+         return ", ".join(f"{self.name}/{k}: {v.avg:.2E}" for k, v in self.items())
+
+     @property
+     def full_summary(self):
+         return {f"{self.name}/{k}": v.avg for k, v in self.items()}
+
+     @property
+     def summary(self):
+         return {f"{self.name}/{k}": f"{v.avg:.2E}" for k, v in self.items()}
+
+
+ class Callbacks:
+     """
+     This class is used to implement callback functions inside training framework that
+     passes, training params to callback functions inside model and loss objects.
+
+     This is useful for custom logging inside models and losses like adding image
+     visualisation, plotting hyperparameters like learning rate, etc.
+     """
+
+     def __init__(self):
+         pass
+
+     def set_workspace(self, workspace, accelerator):
+         self.workspace = workspace
+         self.accelerator = accelerator
+         self._model_has_begin_epoch = hasattr(
+             self.accelerator.unwrap_model(self.workspace.model), "_begin_epoch"
+         )
+         self._model_has_begin_batch = hasattr(
+             self.accelerator.unwrap_model(self.workspace.model), "_begin_batch"
+         )
+
+         self._loss_fn_has_begin_epoch = hasattr(
+             self.accelerator.unwrap_model(self.workspace.loss_fn), "_begin_epoch"
+         )
+         self._loss_fn_has_begin_batch = hasattr(
+             self.accelerator.unwrap_model(self.workspace.loss_fn), "_begin_batch"
+         )
+
+     def begin_epoch(self):
+         self.batch_step = 0
+
+         if self._model_has_begin_epoch:
+             model = self.accelerator.unwrap_model(self.workspace.model)
+             is_main_process = self.accelerator.is_local_main_process
+             begin_epoch_log = model._begin_epoch(
+                 epoch=self.workspace._epoch,
+                 epochs=self.workspace._cfg.num_epochs,
+                 train_dataloader=self.workspace._train_dataloaders,
+                 test_dataloader=self.workspace._test_dataloaders,
+                 optimizer=self.workspace.optimizer,
+                 scheduler=self.workspace.scheduler,
+                 is_main_process=is_main_process,
+             )
+             if begin_epoch_log is not None:
+                 self.accelerator.log(begin_epoch_log)
+
+         if self._loss_fn_has_begin_epoch:
+             loss_fn = self.accelerator.unwrap_model(self.workspace.loss_fn)
+             is_main_process = self.accelerator.is_local_main_process
+             begin_epoch_log = loss_fn._begin_epoch(
+                 epoch=self.workspace._epoch,
+                 epochs=self.workspace._cfg.num_epochs,
+                 train_dataloader=self.workspace._train_dataloaders,
+                 test_dataloader=self.workspace._test_dataloaders,
+                 optimizer=self.workspace.optimizer,
+                 scheduler=self.workspace.scheduler,
+                 is_main_process=is_main_process,
+             )
+             if begin_epoch_log is not None:
+                 self.accelerator.log(begin_epoch_log)
+
+     def begin_batch(self):
+         if self._model_has_begin_batch:
+             model = self.accelerator.unwrap_model(self.workspace.model)
+             is_main_process = self.accelerator.is_local_main_process
+             begin_batch_log = model._begin_batch(
+                 epoch=self.workspace._epoch,
+                 batch_step=self.batch_step,
+                 epochs=self.workspace._cfg.num_epochs,
+                 train_dataloader=self.workspace._train_dataloaders,
+                 test_dataloader=self.workspace._test_dataloaders,
+                 optimizer=self.workspace.optimizer,
+                 scheduler=self.workspace.scheduler,
+                 is_main_process=is_main_process,
+             )
+             if begin_batch_log is not None:
+                 self.accelerator.log(begin_batch_log)
+
+         if self._loss_fn_has_begin_batch:
+             loss_fn = self.accelerator.unwrap_model(self.workspace.loss_fn)
+             is_main_process = self.accelerator.is_local_main_process
+             begin_batch_log = loss_fn._begin_batch(
+                 epoch=self.workspace._epoch,
+                 batch_step=self.batch_step,
+                 epochs=self.workspace._cfg.num_epochs,
+                 train_dataloader=self.workspace._train_dataloaders,
+                 test_dataloader=self.workspace._test_dataloaders,
+                 optimizer=self.workspace.optimizer,
+                 scheduler=self.workspace.scheduler,
+                 is_main_process=is_main_process,
+             )
+             if begin_batch_log is not None:
+                 self.accelerator.log(begin_batch_log)
+
+         self.batch_step += 1
+
+
+ def log_wandb_image(image, label):
+     """
+     image: torch.Tensor or np.ndarray
+     label: str, label for the image to categorize it in wandb
+
+     returns: dict, {label: wandb.Image(image)}
+     """
+     import wandb
+
+     if type(image) == torch.Tensor:
+         image = image.cpu().numpy()
+     while image.shape[0] == 1:
+         image = image.squeeze(0)
+     if image.shape[0] == 3:
+         image = image.transpose(1, 2, 0)
+     image = wandb.Image(image)
+     return {label: image}
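For orientation, here is a minimal usage sketch of the metric helpers above; the "loss" key and the per-batch values are made-up placeholders, not defaults from the package.

    metrics = AveragingDict(name="train")
    for batch_loss in (0.9, 0.7, 0.5):            # stand-in per-batch scalars
        metrics.update({"loss": batch_loss})      # first use of a key creates an AverageMeter
    print(metrics)               # "train/loss: ..." with the running average in .2E format
    print(metrics.full_summary)  # {"train/loss": <float>}, ready to pass to accelerator.log(...)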
baselines/rum/utils/action_transforms.py
@@ -0,0 +1,22 @@
+ import numpy as np
+
+ # permutation transformation matrix to go from record3d axis to personal camera axis
+ P = np.array([[-1, 0, 0, 0], [0, 0, -1, 0], [0, -1, 0, 0], [0, 0, 0, 1]])
+ # end effector transformation matrix to go from personal camera axis to end effector axis. Corrects the 15 degree offset along the x axis
+ EFT = np.asarray([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
+
+
+ def apply_permutation_transform(matrix):
+     return P @ matrix @ P.T
+
+
+ def invert_permutation_transform(matrix):
+     return P.T @ matrix @ P
+
+
+ def apply_end_effector_transform(matrix):
+     return EFT @ matrix
+
+
+ def invert_end_effector_transform(matrix):
+     return EFT.T @ matrix
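Since P is a signed permutation matrix (each row holds a single ±1, so P.T is its inverse), apply_permutation_transform and invert_permutation_transform round-trip exactly. A quick check, with an arbitrary stand-in pose in place of a real camera pose:

    import numpy as np

    M = np.eye(4)
    M[:3, 3] = [0.1, 0.2, 0.3]                     # translation-only stand-in pose

    camera_frame = apply_permutation_transform(M)  # P @ M @ P.T
    recovered = invert_permutation_transform(camera_frame)
    assert np.allclose(recovered, M)               # P.T @ (P @ M @ P.T) @ P == M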
baselines/rum/utils/decord_transforms.py
@@ -0,0 +1,135 @@
+ import logging
+ from typing import Iterable, Optional, Tuple, Union
+
+ import torch
+ from timm.data.random_erasing import RandomErasing
+ from torchvision import transforms
+ from torchvision.transforms import InterpolationMode
+
+ DEFAULT_CROP_PCT = 0.875
+ DEFAULT_CROP_MODE = "center"
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+ IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
+ IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
+ IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
+ IMAGENET_DPN_STD = tuple([1 / (0.0167 * 255)] * 3)
+ OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
+ OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
+
+
+ def create_transform(
+     input_size,
+     is_training: bool = False,
+     scale: Optional[Tuple[float, float]] = None,
+     ratio: Optional[Tuple[float, float]] = None,
+     hflip: float = 0.5,
+     vflip: float = 0.0,
+     grayscale: float = 0.0,
+     gaussblr: float = 0.0,
+     gaussblr_kernel: int = 3,
+     gaussblr_sigma: tuple = (1.0, 2.0),
+     color_jitter: Optional[Union[Iterable[float], float]] = 0.4,
+     interpolation: Union[str, InterpolationMode] = "bilinear",
+     mean: Union[Iterable, torch.Tensor] = IMAGENET_DEFAULT_MEAN,
+     std: Union[Iterable, torch.Tensor] = IMAGENET_DEFAULT_STD,
+     re_prob=0.0,
+     re_mode="const",
+     re_count=1,
+     re_num_splits=0,
+     crop_pct=None,
+     crop_mode=None,
+     *args,
+     **kwargs,
+ ):
+     if isinstance(input_size, (tuple, list)):
+         img_size = input_size[-2:]
+     else:
+         img_size = input_size
+     transform_list = []
+
+     interpolation_mode = InterpolationMode.BILINEAR
+     try:
+         interpolation_mode = InterpolationMode[interpolation.upper()]
+     except AttributeError:
+         logging.warning(
+             f"Interpolation mode {interpolation} is not recognized, "
+             f"using bilinear instead."
+         )
+     if is_training:
+         scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
+         ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0))  # default imagenet ratio range
+
+         transform_list.append(
+             transforms.RandomResizedCrop(
+                 img_size,
+                 scale,
+                 ratio,
+                 antialias=False,
+                 interpolation=interpolation_mode,
+             )
+         )
+         if hflip > 0.0:
+             transform_list.append(transforms.RandomHorizontalFlip(hflip))
+         if vflip > 0.0:
+             transform_list.append(transforms.RandomVerticalFlip(vflip))
+
+         if color_jitter is not None:
+             if isinstance(color_jitter, (list, tuple)):
+                 assert len(color_jitter) in (
+                     3,
+                     4,
+                 ), "expected either 3 or 4 values for color jitter"
+             else:
+                 color_jitter = (float(color_jitter),) * 3
+             transform_list.append(transforms.ColorJitter(*color_jitter))
+
+         if grayscale > 0.0:
+             transform_list.append(transforms.RandomGrayscale(p=grayscale))
+         if gaussblr > 0.0:
+             transform_list.append(
+                 transforms.RandomApply(
+                     [
+                         transforms.GaussianBlur(
+                             (gaussblr_kernel, gaussblr_kernel), gaussblr_sigma
+                         )
+                     ],
+                     p=gaussblr,
+                 )
+             )
+         transform_list.append(transforms.Normalize(mean, std))
+
+         if re_prob > 0.0:
+             transform_list.append(
+                 RandomErasing(
+                     re_prob,
+                     mode=re_mode,
+                     max_count=re_count,
+                     num_splits=re_num_splits,
+                     device="cuda",
+                 )
+             )
+
+     else:
+         if crop_pct is None:
+             crop_pct = DEFAULT_CROP_PCT
+         if crop_mode is None:
+             crop_mode = DEFAULT_CROP_MODE
+         rescaled_size = (
+             int(img_size / crop_pct)
+             if isinstance(img_size, (int, float))
+             else ([int(size / crop_pct) for size in img_size])
+         )
+         # TODO (haritheja): we are no longer cropping, but figure this out
+         # transform_list.append(
+         #     transforms.Resize(rescaled_size, interpolation_mode, antialias=False)
+         # )
+         # if crop_mode == "center":
+         #     transform_list.append(transforms.CenterCrop(img_size))
+         # elif crop_mode == "random":
+         #     transform_list.append(transforms.RandomCrop(img_size))
+         # else:
+         #     raise ValueError(f"crop_mode '{crop_mode}' not recognized")
+         transform_list.append(transforms.Normalize(mean, std))
+
+     return transforms.Compose(transform_list)
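A minimal sketch of how create_transform can be applied to decoded frames; the input shape, augmentation probabilities, and the random tensor below are illustrative choices, not package defaults:

    import torch

    train_tfm = create_transform((3, 224, 224), is_training=True, hflip=0.5, gaussblr=0.2)
    eval_tfm = create_transform((3, 224, 224), is_training=False)   # normalization only

    frame = torch.rand(3, 256, 256)      # stand-in decoded frame, float values in [0, 1]
    augmented = train_tfm(frame)         # random crop / flip / jitter / blur, then Normalize
    normalized = eval_tfm(frame)         # mean/std normalization, no cropping (see TODO above)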
baselines/rum/utils/rpc.py
@@ -0,0 +1,249 @@
+ import zmq
+ import time
+ import pickle
+ import threading
+ import traceback
+
+
+ class RPCServer:
+     def __init__(self, obj, port: int = 5000, threaded=False):
+         """
+         obj: object with methods to expose
+         port: port to listen on
+         """
+         self.obj = obj
+         self.context = zmq.Context()
+         self.socket: zmq.socket.Socket = self.context.socket(zmq.REP)
+         self.socket.bind(f"tcp://*:{port}")
+         self.threaded = threaded
+         if threaded:
+             self.thread = threading.Thread(target=self.run)
+             self.stop_event = threading.Event()
+         else:
+             self.stop_event = False
+
+     def _send_exception(self, e):
+         """
+         Serialize an exception and send it over the socket.
+         Only the exception type, message, and traceback are sent.
+         """
+         exception = {
+             "type": "exception",
+             "content": {
+                 "exception": str(type(e)),
+                 "message": str(e),
+                 "traceback": traceback.format_exc(),
+             },
+         }
+         self.socket.send(pickle.dumps(exception))
+
+     def _send_result(self, result):
+         """
+         Serialize a result and send it over the socket.
+         """
+         result = {"type": "result", "content": result}
+         self.socket.send(pickle.dumps(result))
+
+     def run(self):
+         """
+         Run the server.
+         """
+         if self.threaded:
+             while not self.stop_event.is_set():
+                 message = self.socket.recv()
+                 message = pickle.loads(message)
+                 self._handle_message(message)
+         else:
+             while not self.stop_event:
+                 try:
+                     message = self.socket.recv(flags=zmq.NOBLOCK)
+                     message = pickle.loads(message)
+                 except zmq.Again:
+                     time.sleep(0.001)
+                     continue
+                 self._handle_message(message)
+
+     def _is_callable(self, attr):
+         return hasattr(self.obj, attr) and callable(getattr(self.obj, attr))
+
+     def _handle_message(self, message):
+         """
+         Handles a dictionary of {
+             "req": str,  # request type
+             "attr": str,
+             "args": list,
+             "kwargs": dict,
+         }
+         from the socket.
+         If req == "is_callable", return whether the attribute is callable.
+         If req == "get", return the attribute.
+         If the attribute is not found, return an error message.
+         If the attribute is callable, call with args and kwargs.
+         If there are any errors in the callable, return the pickled error
+         If the callable is found and there are no errors, return the pickled result.
+         If the attribute is not callable, return the attribute.
+         If req == "set", set the attribute to the value.
+         If req == "dir", return a list of attributes.
+         If req == "stop", stop the server.
+         """
+         if message["req"] == "is_callable":
+             result = self._is_callable(message["attr"])
+             self._send_result(result)
+         elif message["req"] == "get":
+             try:
+                 attribute = getattr(self.obj, message["attr"])
+                 args = message["args"]
+                 kwargs = message["kwargs"]
+                 if not callable(attribute):
+                     self._send_result(attribute)
+                 else:
+                     result = attribute(*args, **kwargs)
+                     self._send_result(result)
+             except Exception as e:
+                 self._send_exception(e)
+         elif message["req"] == "set":
+             try:
+                 setattr(self.obj, message["attr"], message["value"])
+                 self._send_result(None)
+             except Exception as e:
+                 self._send_exception(e)
+         elif message["req"] == "dir":
+             result = dir(self.obj)
+             self._send_result(result)
+         elif message["req"] == "stop":
+             self.stop()
+
+     def close(self):
+         self.socket.close()
+         self.context.term()
+
+     def start(self):
+         if self.threaded:
+             self.stop_event.clear()
+             self.thread.start()
+         else:
+             self.run()
+
+     def stop(self):
+         if self.threaded:
+             self.stop_event.set()
+             self.thread.join()
+         else:
+             self.stop_event = True
+         self.close()
+
+
+ class RPCException(Exception):
+     def __init__(self, exception_type: str, message: str, traceback: str):
+         self.exception_type = exception_type
+         self.message = message
+         self.traceback = traceback
+
+     def __str__(self):
+         return f"{self.exception_type}: {self.message}\n{self.traceback}"
+
+
+ class RPCClient:
+     def __init__(self, host: str, port: int = 5000):
+         """
+         host: host to connect to
+         port: port to connect to
+         """
+         self.__dict__["context"] = zmq.Context()
+         self.__dict__["socket"] = self.context.socket(zmq.REQ)
+         self.socket.connect(f"tcp://{host}:{port}")
+         self.__dict__["_is_callable_cache"] = {}
+
+     def __setattr__(self, attr: str, value):
+         """
+         Set the attribute of the same name.
+         Attribute must not be callable on remote.
+         """
+         if self._is_callable(attr):
+             raise AttributeError(f"Overwriting a callable attribute: {attr}")
+         self._send_set(attr, value)
+
+     def _send_get(self, attr: str, args: list, kwargs: dict):
+         """
+         Send a get request over the socket.
+         """
+         req = {"req": "get", "attr": attr, "args": args, "kwargs": kwargs}
+         self.socket.send(pickle.dumps(req))
+         return self._recv_result()
+
+     def _send_set(self, attr: str, value):
+         """
+         Send a set request over the socket.
+         """
+         req = {"req": "set", "attr": attr, "value": value}
+         self.socket.send(pickle.dumps(req))
+         return self._recv_result()
+
+     def _recv_result(self):
+         """
+         Receive a dictionary of {
+             "type": str,
+             "content": object,
+         }
+         if type == "exception", content is a dictionary of {
+             "exception": str,
+             "message": str,
+             "traceback": str,
+         }; re-raise the exception on the client side
+         if type == "result", content is the result
+         """
+         result = self.socket.recv()
+         result = pickle.loads(result)
+         if result["type"] == "exception":
+             raise RPCException(
+                 result["content"]["exception"],
+                 result["content"]["message"],
+                 result["content"]["traceback"],
+             )
+         return result["content"]
+
+     def _is_callable(self, attr: str) -> bool:
+         """
+         Send a request to check if the attribute is callable.
+         Returns False if the attribute is not found.
+         """
+         if attr not in self._is_callable_cache:
+             req = {"req": "is_callable", "attr": attr}
+             self.socket.send(pickle.dumps(req))
+             result = self._recv_result()
+             self._is_callable_cache[attr] = result
+         return self._is_callable_cache[attr]
+
+     def __getattr__(self, attr: str):
+         """
+         Return the attribute of the same name.
+         If the attribute is a callable, return a function that sends the call over the socket.
+         Else, return the attribute value.
+         """
+         if self._is_callable(attr):
+             return lambda *args, **kwargs: self._send_get(attr, args, kwargs)
+         else:
+             return self._send_get(attr, [], {})
+
+     def __dir__(self):
+         """
+         Return a list of attributes.
+         """
+         req = {"req": "dir"}
+         self.socket.send(pickle.dumps(req))
+         result = self._recv_result()
+         return result + ["stop_server"]
+
+     def stop_server(self) -> bool:
+         """
+         Send a stop request to the server.
+         If the server is stopped, close the socket and terminate the context.
+         Returns a bool for success.
+         """
+         req = {"req": "stop"}
+         self.socket.send(pickle.dumps(req))
+         stopped = self._recv_result()
+         if stopped:
+             self.socket.close()
+             self.context.term()
+         return stopped
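To make the request/response protocol above concrete, here is a hedged usage sketch; Counter, the port number, and the host are placeholders, and clean shutdown is omitted:

    # server side
    class Counter:                      # hypothetical object to expose
        def __init__(self):
            self.value = 0

        def add(self, n):
            self.value += n
            return self.value

    server = RPCServer(Counter(), port=5555, threaded=True)
    server.start()                      # serves pickled REQ/REP messages on a background thread

    # client side
    client = RPCClient("localhost", port=5555)
    client.add(3)                       # callable attribute -> "get" request with args/kwargs
    client.value                        # non-callable attribute -> its value is returned
    client.value = 10                   # routed through __setattr__ as a "set" request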
baselines/rum/utils/schedulers.py
@@ -0,0 +1,71 @@
+ import math
+ import warnings
+
+ from torch.optim.lr_scheduler import LRScheduler
+
+
+ class CosineAnnealWithWarmupLR(LRScheduler):
+     """Sets the learning rate of each parameter group to the initial lr
+     times a given function. When last_epoch=-1, sets initial lr as lr.
+
+     Adapted from karpathy/nanoGPT
+     https://github.com/karpathy/nanoGPT/blob/eba36e84649f3c6d840a93092cb779a260544d08/train.py#L228
+
+     Args:
+         optimizer (Optimizer): Wrapped optimizer.
+         warmup_epochs (int): The number of epochs for warmup
+         lr_decay_epoch (int): The index of lr_decay epoch, at which we reach min_LR. Default: -1.
+         verbose (bool): If ``True``, prints a message to stdout for
+             each update. Default: ``False``.
+         min_lr_multiplier: The lowest lr multiplier the schedule will ever get to.
+
+     Example:
+         >>> warmup_epochs = 10
+         >>> lr_decay_epochs = 600
+         >>> scheduler = CosineAnnealWithWarmupLR(optimizer, warmup_epochs, lr_decay_epochs)
+         >>> for epoch in range(100):
+         >>>     train(...)
+         >>>     validate(...)
+         >>>     scheduler.step()
+     """
+
+     def __init__(
+         self,
+         optimizer,
+         warmup_epochs,
+         lr_decay_epochs,
+         last_epoch=-1,
+         min_lr_multiplier=0.1,
+         verbose=False,
+     ):
+         self.optimizer = optimizer
+         self.warmup_epochs = warmup_epochs
+         self.last_epoch = last_epoch
+         self.lr_decay_epoch = lr_decay_epochs
+         self.min_lr_multiplier = min_lr_multiplier
+
+         super().__init__(optimizer, last_epoch, verbose)
+
+     def get_lr(self):
+         if not self._get_lr_called_within_step:
+             warnings.warn(
+                 "To get the last learning rate computed by the scheduler, "
+                 "please use `get_last_lr()`."
+             )
+
+         return [
+             base_lr * self._calculate_lr_formula(self.last_epoch)
+             for base_lr in self.base_lrs
+         ]
+
+     def _calculate_lr_formula(self, last_epoch):
+         if last_epoch < self.warmup_epochs:
+             return (last_epoch + 1) / self.warmup_epochs
+         elif last_epoch > self.lr_decay_epoch:
+             return self.min_lr_multiplier
+         decay_ratio = (last_epoch - self.warmup_epochs) / (
+             self.lr_decay_epoch - self.warmup_epochs
+         )
+         assert 0 <= decay_ratio <= 1
+         coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
+         return self.min_lr_multiplier + coeff * (1 - self.min_lr_multiplier)
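As a sanity check of the schedule's shape, a small sketch with a throwaway model and optimizer (the base learning rate and epoch counts are arbitrary; the annotated values follow directly from _calculate_lr_formula):

    import torch

    opt = torch.optim.SGD(torch.nn.Linear(4, 4).parameters(), lr=1e-3)
    sched = CosineAnnealWithWarmupLR(opt, warmup_epochs=10, lr_decay_epochs=600)

    for epoch in range(700):
        opt.step()
        sched.step()
    # epochs 0-9:  multiplier ramps as (epoch + 1) / 10, so lr climbs from 1e-4 to 1e-3
    # epoch 305:   decay_ratio = 0.5, coeff = 0.5, multiplier = 0.1 + 0.5 * 0.9 = 0.55
    # epoch > 600: multiplier stays at min_lr_multiplier = 0.1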