egogym 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. baselines/pi_policy.py +110 -0
  2. baselines/rum/__init__.py +1 -0
  3. baselines/rum/loss_fns/__init__.py +37 -0
  4. baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
  5. baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
  6. baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
  7. baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
  8. baselines/rum/models/__init__.py +1 -0
  9. baselines/rum/models/bet/__init__.py +3 -0
  10. baselines/rum/models/bet/bet.py +347 -0
  11. baselines/rum/models/bet/gpt.py +277 -0
  12. baselines/rum/models/bet/tokenized_bet.py +454 -0
  13. baselines/rum/models/bet/utils.py +124 -0
  14. baselines/rum/models/bet/vqbet.py +410 -0
  15. baselines/rum/models/bet/vqvae/__init__.py +3 -0
  16. baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
  17. baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
  18. baselines/rum/models/bet/vqvae/vqvae.py +313 -0
  19. baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
  20. baselines/rum/models/custom.py +33 -0
  21. baselines/rum/models/encoders/__init__.py +0 -0
  22. baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
  23. baselines/rum/models/encoders/identity.py +45 -0
  24. baselines/rum/models/encoders/timm_encoders.py +82 -0
  25. baselines/rum/models/policies/diffusion_policy.py +881 -0
  26. baselines/rum/models/policies/open_loop.py +122 -0
  27. baselines/rum/models/policies/simple_open_loop.py +108 -0
  28. baselines/rum/molmo/server.py +144 -0
  29. baselines/rum/policy.py +293 -0
  30. baselines/rum/utils/__init__.py +212 -0
  31. baselines/rum/utils/action_transforms.py +22 -0
  32. baselines/rum/utils/decord_transforms.py +135 -0
  33. baselines/rum/utils/rpc.py +249 -0
  34. baselines/rum/utils/schedulers.py +71 -0
  35. baselines/rum/utils/trajectory_vis.py +128 -0
  36. baselines/rum/utils/zmq_utils.py +281 -0
  37. baselines/rum_policy.py +108 -0
  38. egogym/__init__.py +8 -0
  39. egogym/assets/constants.py +1804 -0
  40. egogym/components/__init__.py +1 -0
  41. egogym/components/object.py +94 -0
  42. egogym/egogym.py +106 -0
  43. egogym/embodiments/__init__.py +10 -0
  44. egogym/embodiments/arms/__init__.py +4 -0
  45. egogym/embodiments/arms/arm.py +65 -0
  46. egogym/embodiments/arms/droid.py +49 -0
  47. egogym/embodiments/grippers/__init__.py +4 -0
  48. egogym/embodiments/grippers/floating_gripper.py +58 -0
  49. egogym/embodiments/grippers/rum.py +6 -0
  50. egogym/embodiments/robot.py +95 -0
  51. egogym/evaluate.py +216 -0
  52. egogym/managers/__init__.py +2 -0
  53. egogym/managers/objects_managers.py +30 -0
  54. egogym/managers/textures_manager.py +21 -0
  55. egogym/misc/molmo_client.py +49 -0
  56. egogym/misc/molmo_server.py +197 -0
  57. egogym/policies/__init__.py +1 -0
  58. egogym/policies/base_policy.py +13 -0
  59. egogym/scripts/analayze.py +834 -0
  60. egogym/scripts/plot.py +87 -0
  61. egogym/scripts/plot_correlation.py +392 -0
  62. egogym/scripts/plot_correlation_hardcoded.py +338 -0
  63. egogym/scripts/plot_failure.py +248 -0
  64. egogym/scripts/plot_failure_hardcoded.py +195 -0
  65. egogym/scripts/plot_failure_vlm.py +257 -0
  66. egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
  67. egogym/scripts/plot_line.py +303 -0
  68. egogym/scripts/plot_line_hardcoded.py +285 -0
  69. egogym/scripts/plot_pi0_bars.py +169 -0
  70. egogym/tasks/close.py +84 -0
  71. egogym/tasks/open.py +85 -0
  72. egogym/tasks/pick.py +121 -0
  73. egogym/utils.py +969 -0
  74. egogym/wrappers/__init__.py +20 -0
  75. egogym/wrappers/episode_monitor.py +282 -0
  76. egogym/wrappers/unprivileged_chatgpt.py +163 -0
  77. egogym/wrappers/unprivileged_gemini.py +157 -0
  78. egogym/wrappers/unprivileged_molmo.py +88 -0
  79. egogym/wrappers/unprivileged_moondream.py +121 -0
  80. egogym-0.1.0.dist-info/METADATA +52 -0
  81. egogym-0.1.0.dist-info/RECORD +83 -0
  82. egogym-0.1.0.dist-info/WHEEL +5 -0
  83. egogym-0.1.0.dist-info/top_level.txt +2 -0
baselines/rum/utils/trajectory_vis.py ADDED
@@ -0,0 +1,128 @@
+ import collections
+
+ import einops
+ import matplotlib.gridspec as gridspec
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+
+
+ def generate_plots(ground_truths, predictions, sampled_images, to_plot=8, traj_index=0):
+     """
+     Generates plots comparing ground truth vs predictions and displays sampled images.
+
+     Parameters:
+     - ground_truths: A numpy array of shape (T, 7) containing ground truth actions.
+     - predictions: A numpy array of shape (T, 7) with model's predictions.
+     - sampled_images: A numpy array containing M sampled images.
+     """
+
+     # Initialize the figure
+     fig = plt.figure(figsize=(15, 15))
+     outer_grid = gridspec.GridSpec(8, 1, hspace=0.25, wspace=0.0)
+
+     # Top grid for images
+     M = sampled_images.shape[0]
+     chosen_indices = np.linspace(0, M - 1, to_plot, dtype=int)
+     chosen_images = sampled_images[chosen_indices]
+     top_grid = gridspec.GridSpecFromSubplotSpec(
+         1, len(chosen_images), subplot_spec=outer_grid[0, :], wspace=0.0
+     )
+
+     for i, img in enumerate(chosen_images):
+         ax = plt.Subplot(fig, top_grid[i])
+         ax.imshow(img)
+         ax.axis("off")
+         if i > 0:  # To further remove any potential whitespace
+             ax.set_yticklabels([])
+             ax.set_xticklabels([])
+             ax.set_xticks([])
+             ax.set_yticks([])
+         fig.add_subplot(ax)
+
+     # For each action dimension, plot ground truth vs prediction
+     dim = ground_truths.shape[1]
+     for i in range(dim):
+         ax = plt.Subplot(fig, outer_grid[i + 1, :])
+
+         gt_values = ground_truths[:, i]
+         pred_values = predictions[:, i]
+
+         ax.plot(gt_values, "g", label="Ground Truth")
+         ax.plot(pred_values, "r--", label="Prediction")
+
+         ax.legend(loc="upper right")
+         ax.set_title(f"Act Dim {i + 1}", fontsize=16)
+         fig.add_subplot(ax)
+
+     # save the plot
+     fig.savefig(f"trajectory_{traj_index}.png")
+
+
+ @torch.no_grad()
+ def visualize_trajectory(
+     model,
+     test_dataset,
+     device,
+     buffer_size=6,
+     n_visualized_trajectories=5,
+     goal_conditional=False,
+ ):
+     """
+     Visualizes the trajectory of the model on the test dataset.
+     M images are sampled from the test dataset and displayed.
+     T = end - start is the length of the trajectory.
+     start and end (index) indicate the window of the trajectory to visualize.
+     """
+     print(goal_conditional)
+     action_preds = []
+     ground_truth = []
+     images = []
+     image_buffer = collections.deque(maxlen=buffer_size)
+     test_dataset.set_include_trajectory_end(True)
+
+     i = 0
+     done_visualizing = 0
+     while (done_visualizing < n_visualized_trajectories) and (i < len(test_dataset)):
+         if goal_conditional:
+             (input_images, terminate), goals, *_, gt_actions = test_dataset[i]
+         else:
+             (input_images, terminate), *_, gt_actions = test_dataset[i]
+         input_images = input_images.float() / 255.0
+         image_buffer.append(input_images[-1])
+         img = input_images[-1]
+         images.append(einops.rearrange(img, "c h w -> h w c").cpu().detach().numpy())
+         ground_truth.append(gt_actions[-1])
+         if goal_conditional:
+             model_input = (
+                 torch.stack(tuple(image_buffer), dim=0).unsqueeze(0).to(device),
+                 torch.tensor(goals).unsqueeze(0).to(device),
+                 torch.tensor(gt_actions).unsqueeze(0).to(device),
+             )
+         else:
+             model_input = (
+                 torch.stack(tuple(image_buffer), dim=0).unsqueeze(0).to(device),
+                 torch.tensor(gt_actions).unsqueeze(0).to(device),
+             )
+
+         out, _ = model.step(model_input)
+         action_preds.append(out.squeeze().cpu().detach().numpy())
+
+         if terminate:
+             action_preds = np.array(action_preds)
+             ground_truth = np.array(ground_truth)
+             images = np.array(images)
+
+             print(action_preds.shape, ground_truth.shape, images.shape)
+
+             generate_plots(
+                 ground_truth, action_preds, images, traj_index=done_visualizing
+             )
+             done_visualizing += 1
+             # Reset everything.
+             action_preds = []
+             ground_truth = []
+             images = []
+             image_buffer = collections.deque(maxlen=buffer_size)
+
+         i += 1
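For orientation, generate_plots above only needs three arrays: per-step ground-truth actions and predictions of shape (T, 7), plus a stack of sampled frames. A minimal sketch with synthetic inputs follows; the shapes and the trajectory_0.png output name come from the code above, while the random data, image size, and import path (mirroring this wheel's file layout) are illustrative assumptions.

    import numpy as np
    from baselines.rum.utils.trajectory_vis import generate_plots  # path assumed from the wheel layout

    T, M = 50, 12  # trajectory length and number of sampled frames (synthetic)
    ground_truths = np.random.uniform(-1, 1, size=(T, 7))       # fake 7-DoF actions
    predictions = ground_truths + 0.05 * np.random.randn(T, 7)  # fake model outputs
    sampled_images = np.random.randint(0, 255, size=(M, 224, 224, 3), dtype=np.uint8)

    # Writes trajectory_0.png: one row of frames on top, one subplot per action dimension below.
    generate_plots(ground_truths, predictions, sampled_images, to_plot=8, traj_index=0)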
baselines/rum/utils/zmq_utils.py ADDED
@@ -0,0 +1,281 @@
+ import zmq
+ import cv2
+ import numpy as np
+ import pickle
+ import blosc as bl
+ import threading
+ import time
+ from abc import ABC
+
+
+ # ZMQ Sockets
+ def create_push_socket(host, port):
+     context = zmq.Context()
+     socket = context.socket(zmq.PUSH)
+     socket.bind("tcp://{}:{}".format(host, port))
+     return socket
+
+
+ def create_pull_socket(host, port):
+     context = zmq.Context()
+     socket = context.socket(zmq.PULL)
+     socket.setsockopt(zmq.CONFLATE, 1)
+     socket.bind("tcp://{}:{}".format(host, port))
+     return socket
+
+
+ def create_response_socket(host, port):
+     content = zmq.Context()
+     socket = content.socket(zmq.REP)
+     socket.bind("tcp://{}:{}".format(host, port))
+     return socket
+
+
+ def create_request_socket(host, port):
+     context = zmq.Context()
+     socket = context.socket(zmq.REQ)
+     socket.connect("tcp://{}:{}".format(host, port))
+     return socket
+
+
+ # Pub/Sub classes for Keypoints
+ class ZMQKeypointPublisher:
+     def __init__(self, host, port):
+         self._host, self._port = host, port
+         self._init_publisher()
+
+     def _init_publisher(self):
+         self.context = zmq.Context()
+         self.socket = self.context.socket(zmq.PUB)
+         self.socket.bind("tcp://{}:{}".format(self._host, self._port))
+
+     def pub_keypoints(self, keypoint_array, topic_name):
+         """
+         Process the keypoints into a byte stream and input them in this function
+         """
+         buffer = pickle.dumps(keypoint_array, protocol=-1)
+         self.socket.send(bytes("{} ".format(topic_name), "utf-8") + buffer)
+
+     def stop(self):
+         print("Closing the publisher socket in {}:{}.".format(self._host, self._port))
+         self.socket.close()
+         self.context.term()
+
+
+ # Keypoint Subscriber
+ class ZMQKeypointSubscriber(threading.Thread):
+     def __init__(self, host, port, topic):
+         self._host, self._port, self._topic = host, port, topic
+         self._init_subscriber()
+
+         # Topic chars to remove
+         self.strip_value = bytes("{} ".format(self._topic), "utf-8")
+
+     def _init_subscriber(self):
+         self.context = zmq.Context()
+         self.socket = self.context.socket(zmq.SUB)
+         self.socket.setsockopt(zmq.CONFLATE, 1)
+         self.socket.connect("tcp://{}:{}".format(self._host, self._port))
+         self.socket.setsockopt(zmq.SUBSCRIBE, bytes(self._topic, "utf-8"))
+
+     def recv_keypoints(self, flags=None):
+         if flags is None:
+             raw_data = self.socket.recv()
+             raw_array = raw_data.lstrip(self.strip_value)
+             return pickle.loads(raw_array)
+         else:  # For possible usage of no blocking zmq subscriber
+             try:
+                 raw_data = self.socket.recv(flags)
+                 raw_array = raw_data.lstrip(self.strip_value)
+                 return pickle.loads(raw_array)
+             except zmq.Again:
+                 return None
+
+     def stop(self):
+         print("Closing the subscriber socket in {}:{}.".format(self._host, self._port))
+         self.socket.close()
+         self.context.term()
+
+
+ # Pub/Sub classes for storing data from Realsense Cameras
+ class ZMQCameraPublisher:
+     def __init__(self, host, port):
+         self._host, self._port = host, port
+         self._init_publisher()
+
+     def _init_publisher(self):
+         self.context = zmq.Context()
+         self.socket = self.context.socket(zmq.PUB)
+         print("tcp://{}:{}".format(self._host, self._port))
+         self.socket.bind("tcp://{}:{}".format(self._host, self._port))
+
+     def pub_intrinsics(self, array):
+         self.socket.send(b"intrinsics " + pickle.dumps(array, protocol=-1))
+
+     def pub_rgb_image(self, rgb_image, timestamp):
+         _, buffer = cv2.imencode(".jpg", rgb_image, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
+         data = dict(timestamp=timestamp, rgb_image=buffer.tobytes())
+         self.socket.send(b"rgb_image " + pickle.dumps(data, protocol=-1))
+
+     def pub_depth_image(self, depth_image, timestamp):
+         compressed_depth = bl.pack_array(
+             depth_image, cname="zstd", clevel=1, shuffle=bl.NOSHUFFLE
+         )
+         data = dict(timestamp=timestamp, depth_image=compressed_depth)
+         self.socket.send(b"depth_image " + pickle.dumps(data, protocol=-1))
+
+     def pub_image_and_depth(self, rgb_image, depth_image, timestamp):
+         _, buffer = cv2.imencode(".jpg", rgb_image, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
+         compressed_depth = bl.pack_array(
+             depth_image, cname="zstd", clevel=1, shuffle=bl.NOSHUFFLE
+         )
+         data = dict(
+             timestamp=timestamp,
+             rgb_image=buffer.tobytes(),
+             depth_image=compressed_depth,
+         )
+         self.socket.send(b"image_and_depth " + pickle.dumps(data, protocol=-1))
+
+     def stop(self):
+         print("Closing the publisher socket in {}:{}.".format(self._host, self._port))
+         self.socket.close()
+         self.context.term()
+
+
+ class ZMQCameraSubscriber:
+     def __init__(self, host, port, topic_type):
+         self._host, self._port, self._topic_type = host, port, topic_type
+         self._init_subscriber()
+
+     def _init_subscriber(self):
+         self.context = zmq.Context()
+         self.socket = self.context.socket(zmq.SUB)
+         self.socket.setsockopt(zmq.CONFLATE, 1)
+         self.socket.connect("tcp://{}:{}".format(self._host, self._port))
+
+         if self._topic_type == "Intrinsics":
+             self.socket.setsockopt(zmq.SUBSCRIBE, b"intrinsics")
+         elif self._topic_type == "RGB":
+             self.socket.setsockopt(zmq.SUBSCRIBE, b"rgb_image")
+         elif self._topic_type == "Depth":
+             self.socket.setsockopt(zmq.SUBSCRIBE, b"depth_image")
+         elif self._topic_type == "RGBD":
+             self.socket.setsockopt(zmq.SUBSCRIBE, b"image_and_depth")
+
+     def recv_intrinsics(self):
+         raw_data = self.socket.recv()
+         raw_array = raw_data.lstrip(b"intrinsics ")
+         return pickle.loads(raw_array)
+
+     def recv_rgb_image(self):
+         raw_data = self.socket.recv()
+         data = raw_data.lstrip(b"rgb_image ")
+         data = pickle.loads(data)
+         encoded_data = np.frombuffer(data["rgb_image"], np.uint8)
+         return cv2.imdecode(encoded_data, 1), data["timestamp"]
+
+     def recv_depth_image(self):
+         raw_data = self.socket.recv()
+         striped_data = raw_data.lstrip(b"depth_image ")
+         data = pickle.loads(striped_data)
+         depth_image = bl.unpack_array(data["depth_image"])
+         return np.array(depth_image, dtype=np.float32), data["timestamp"]
+
+     def recv_image_and_depth(self):
+         raw_data = self.socket.recv()
+         striped_data = raw_data.lstrip(b"image_and_depth ")
+         data = pickle.loads(striped_data)
+         encoded_data = np.frombuffer(data["rgb_image"], np.uint8)
+         rgb_image = cv2.imdecode(encoded_data, 1)
+         depth_image = bl.unpack_array(data["depth_image"])
+         return rgb_image, np.array(depth_image, dtype=np.float32), data["timestamp"]
+
+     def stop(self):
+         print("Closing the subscriber socket in {}:{}.".format(self._host, self._port))
+         self.socket.close()
+         self.context.term()
+
+
+ # Publisher for image visualizers
+ class ZMQCompressedImageTransmitter(object):
+     def __init__(self, host, port):
+         self._host, self._port = host, port
+         # self._init_push_socket()
+         self._init_publisher()
+
+     def _init_publisher(self):
+         self.context = zmq.Context()
+         self.socket = self.context.socket(zmq.PUB)
+         self.socket.bind("tcp://{}:{}".format(self._host, self._port))
+
+     def _init_push_socket(self):
+         self.context = zmq.Context()
+         self.socket = self.context.socket(zmq.PUSH)
+         self.socket.bind("tcp://{}:{}".format(self._host, self._port))
+
+     def send_image(self, rgb_image):
+         _, buffer = cv2.imencode(".jpg", rgb_image, [int(cv2.IMWRITE_JPEG_QUALITY), 50])
+         self.socket.send(buffer.tobytes())
+
+     def stop(self):
+         print("Closing the publisher in {}:{}.".format(self._host, self._port))
+         self.socket.close()
+         self.context.term()
+
+
+ class ZMQCompressedImageReciever(threading.Thread):
+     def __init__(self, host, port):
+         self._host, self._port = host, port
+         # self._init_pull_socket()
+         self._init_subscriber()
+
+     def _init_subscriber(self):
+         self.context = zmq.Context()
+         self.socket = self.context.socket(zmq.SUB)
+         self.socket.setsockopt(zmq.CONFLATE, 1)
+         self.socket.connect("tcp://{}:{}".format(self._host, self._port))
+         self.socket.subscribe("")
+
+     def _init_pull_socket(self):
+         self.context = zmq.Context()
+         self.socket = self.context.socket(zmq.PULL)
+         self.socket.setsockopt(zmq.CONFLATE, 1)
+         self.socket.connect("tcp://{}:{}".format(self._host, self._port))
+
+     def recv_image(self):
+         raw_data = self.socket.recv()
+         encoded_data = np.frombuffer(raw_data, np.uint8)
+         decoded_frame = cv2.imdecode(encoded_data, 1)
+         return decoded_frame
+
+     def stop(self):
+         print("Closing the subscriber socket in {}:{}.".format(self._host, self._port))
+         self.socket.close()
+         self.context.term()
+
+
+ class FrequencyTimer:
+     FREQ_1KHZ = 1e3
+
+     def __init__(self, frequency_rate):
+         self.time_available = 1e9 / frequency_rate
+
+     def start_loop(self):
+         self.start_time = time.time_ns()
+
+     def end_loop(self):
+         wait_time = self.time_available + self.start_time
+
+         while time.time_ns() < wait_time:
+             time.sleep(1 / FrequencyTimer.FREQ_1KHZ)
+
+
+ class ProcessInstantiator(ABC):
+     def __init__(self):
+         self.processes = []
+
+     def _start_component(self, configs):
+         raise NotImplementedError("Function not implemented!")
+
+     def get_processes(self):
+         return self.processes
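As a quick sanity check of the pub/sub helpers above, the sketch below pairs a ZMQKeypointPublisher with a ZMQKeypointSubscriber on a local port. The host, port, and payload are made-up values; the repeated publish only works around ZeroMQ's slow-joiner behaviour, and recv_keypoints() blocks until a message arrives.

    import time
    from baselines.rum.utils.zmq_utils import ZMQKeypointPublisher, ZMQKeypointSubscriber

    HOST, PORT = "127.0.0.1", 10005  # any free local TCP port (illustrative)

    publisher = ZMQKeypointPublisher(HOST, PORT)
    subscriber = ZMQKeypointSubscriber(HOST, PORT, topic="keypoints")

    # PUB/SUB delivers nothing until the subscriber has connected, so publish a few times.
    for _ in range(5):
        publisher.pub_keypoints([0.1, 0.2, 0.3], topic_name="keypoints")
        time.sleep(0.1)

    print(subscriber.recv_keypoints())  # blocking receive of the latest (CONFLATE) message

    subscriber.stop()
    publisher.stop()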
baselines/rum_policy.py ADDED
@@ -0,0 +1,108 @@
+ import cv2
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from scipy.spatial.transform import Rotation as R
+ from egogym.policies.base_policy import BasePolicy
+ from baselines.rum.policy import Policy
+
+ class RUMPolicy(BasePolicy):
+
+     def __init__(self, config=None, policy=None, **kwargs):
+         super().__init__()
+
+         if config is None:
+             config = get_config()
+
+         for key, value in kwargs.items():
+             config[key] = value
+
+         self.config = config
+         self.name = "rum"
+         self.device = config.get("device", "cuda" if torch.cuda.is_available() else "cpu")
+         if policy is not None:
+             self.policy = policy
+         else:
+             self.policy = Policy(model_path=config["model_path"], device=self.device)
+         self.gripper_threshold = config["gripper_threshold"]
+         self.grasping_style = config["grasping_style"]
+         self.num_envs = None
+         self.grasped = None
+         self.desired_axis_t = torch.tensor([
+             [-1, 0, 0],
+             [ 0, 0, -1],
+             [ 0, -1, 0]
+         ], dtype=torch.float32, device=self.device)
+         self.grasp_override_t = torch.tensor([0.00, 0.18, 0.04], dtype=torch.float32, device=self.device)
+
+     def get_action(self, obs):
+
+         if self.num_envs is None:
+             if len(obs["rgb_ego"].shape) == 4:
+                 self.num_envs = obs["rgb_ego"].shape[0]
+             else:
+                 self.num_envs = 1
+             self.grasped = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device)
+
+         if self.num_envs == 1:
+             obs = {key: obs[key][np.newaxis, ...] if isinstance(obs[key], np.ndarray) else obs[key] for key in obs}
+
+         T_world_camera = torch.from_numpy(obs["camera_pose"]).view(self.num_envs, 4, 4).to(self.device)
+         if "object_pose" in obs:
+             T_world_object = torch.from_numpy(obs["object_pose"]).view(self.num_envs, 4, 4).to(self.device)
+         else:
+             T_world_object = torch.from_numpy(obs["handle_pose"]).view(self.num_envs, 4, 4).to(self.device)
+
+         R_world_camera = T_world_camera[:, :3, :3]
+         t_world_camera = T_world_camera[:, :3, 3]
+         p_world = T_world_object[:, :3, 3]
+         R_camera_world = R_world_camera.transpose(1, 2)
+         p_camera = torch.bmm(R_camera_world, (p_world - t_world_camera).unsqueeze(-1)).squeeze(-1)
+         object_3d_position = p_camera @ self.desired_axis_t.T
+         object_3d_position[self.grasped] = self.grasp_override_t
+
+         rgb_tensor = torch.from_numpy(obs["rgb_ego"]).permute(0, 3, 1, 2).to(self.device).float()
+         rgb_resized = F.interpolate(rgb_tensor, size=(224, 224), mode='bilinear', align_corners=False)
+         rgb_ego = rgb_resized.permute(0, 2, 3, 1).byte().cpu().numpy()
+
+         rum_obs = {
+             "rgb_ego": rgb_ego,
+             "object_3d_position": object_3d_position.cpu().numpy(),
+         }
+         action_tensors = self.policy.infer(rum_obs)
+
+         action_tensors_t = torch.from_numpy(action_tensors).to(self.device)
+         R_action = R.from_euler("xyz", action_tensors[:, 3:6]).as_matrix()
+         R_action_t = torch.from_numpy(R_action).to(self.device, dtype=torch.float32)
+
+         action_pose = torch.eye(4, device=self.device).unsqueeze(0).repeat(self.num_envs, 1, 1)
+         action_pose[:, :3, :3] = R_action_t
+         action_pose[:, :3, 3] = action_tensors_t[:, :3]
+         action_pose[:, :3, :3] = self.desired_axis_t.T @ action_pose[:, :3, :3] @ self.desired_axis_t
+         action_pose[:, :3, 3] = action_pose[:, :3, 3] @ self.desired_axis_t
+
+         self.grasped = torch.maximum(self.grasped, action_tensors_t[:, 6] < self.gripper_threshold)
+         grasp_action = self.grasped.unsqueeze(1).float()
+         if self.grasping_style == "continuous":
+             grasp_action = 1 - action_tensors_t[:, 6].unsqueeze(1)
+
+         actions = torch.cat([action_pose.view(self.num_envs, -1), grasp_action], dim=1)
+
+         return actions.cpu().numpy()
+
+     def reset(self, indicies=None):
+         self.policy.reset(indicies)
+         if indicies is not None and self.num_envs is not None:
+             self.grasped[indicies] = False
+         elif self.num_envs is not None:
+             self.grasped = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device)
+
+
+ def get_config():
+     config = {
+         "model_path": "checkpoints/rum_pick.pt",
+         "gripper_threshold": 0.7,
+         "grasping_style": "binary",
+         "device": None,
+     }
+     return config
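RUMPolicy above consumes numpy observations and returns, per environment, a flattened 4x4 end-effector pose followed by one gripper value (17 numbers in total). A hedged usage sketch follows; it assumes the default checkpoint from get_config() exists on disk, and the observation values are synthetic placeholders with the shapes the code reads (an rgb_ego frame, a 4x4 camera_pose, and either object_pose or handle_pose).

    import numpy as np
    from baselines.rum_policy import RUMPolicy

    # Keyword arguments override the defaults from get_config();
    # the default checkpoint "checkpoints/rum_pick.pt" is assumed to be present locally.
    policy = RUMPolicy(grasping_style="continuous", device="cpu")

    # Synthetic single-environment observation with the keys get_action() expects.
    obs = {
        "rgb_ego": np.zeros((240, 320, 3), dtype=np.uint8),  # ego-view RGB frame
        "camera_pose": np.eye(4, dtype=np.float32),          # world-from-camera transform
        "object_pose": np.eye(4, dtype=np.float32),          # world-from-object transform
    }

    action = policy.get_action(obs)  # shape (1, 17): flattened pose + gripper command
    policy.reset()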
egogym/__init__.py ADDED
@@ -0,0 +1,8 @@
+ import gymnasium as gym
+ from .tasks.pick import PickTask
+ from .tasks.open import OpenTask
+ from .tasks.close import CloseTask
+
+ gym.register(id="Egogym-Pick-v0", entry_point=PickTask)
+ gym.register(id="Egogym-Open-v0", entry_point=OpenTask)
+ gym.register(id="Egogym-Close-v0", entry_point=CloseTask)
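Because egogym/__init__.py registers the three tasks with Gymnasium at import time, environments can presumably be created through the standard gym.make API. A minimal sketch, assuming the task constructors need no required arguments (their signatures are not shown in this diff):

    import gymnasium as gym
    import egogym  # importing the package runs the gym.register calls above

    env = gym.make("Egogym-Pick-v0")
    obs, info = env.reset()
    # Step with a sampled action; the actual action layout is defined by the task class.
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    env.close()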