egogym-0.1.0-py3-none-any.whl
- baselines/pi_policy.py +110 -0
- baselines/rum/__init__.py +1 -0
- baselines/rum/loss_fns/__init__.py +37 -0
- baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
- baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
- baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
- baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
- baselines/rum/models/__init__.py +1 -0
- baselines/rum/models/bet/__init__.py +3 -0
- baselines/rum/models/bet/bet.py +347 -0
- baselines/rum/models/bet/gpt.py +277 -0
- baselines/rum/models/bet/tokenized_bet.py +454 -0
- baselines/rum/models/bet/utils.py +124 -0
- baselines/rum/models/bet/vqbet.py +410 -0
- baselines/rum/models/bet/vqvae/__init__.py +3 -0
- baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
- baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
- baselines/rum/models/bet/vqvae/vqvae.py +313 -0
- baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
- baselines/rum/models/custom.py +33 -0
- baselines/rum/models/encoders/__init__.py +0 -0
- baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
- baselines/rum/models/encoders/identity.py +45 -0
- baselines/rum/models/encoders/timm_encoders.py +82 -0
- baselines/rum/models/policies/diffusion_policy.py +881 -0
- baselines/rum/models/policies/open_loop.py +122 -0
- baselines/rum/models/policies/simple_open_loop.py +108 -0
- baselines/rum/molmo/server.py +144 -0
- baselines/rum/policy.py +293 -0
- baselines/rum/utils/__init__.py +212 -0
- baselines/rum/utils/action_transforms.py +22 -0
- baselines/rum/utils/decord_transforms.py +135 -0
- baselines/rum/utils/rpc.py +249 -0
- baselines/rum/utils/schedulers.py +71 -0
- baselines/rum/utils/trajectory_vis.py +128 -0
- baselines/rum/utils/zmq_utils.py +281 -0
- baselines/rum_policy.py +108 -0
- egogym/__init__.py +8 -0
- egogym/assets/constants.py +1804 -0
- egogym/components/__init__.py +1 -0
- egogym/components/object.py +94 -0
- egogym/egogym.py +106 -0
- egogym/embodiments/__init__.py +10 -0
- egogym/embodiments/arms/__init__.py +4 -0
- egogym/embodiments/arms/arm.py +65 -0
- egogym/embodiments/arms/droid.py +49 -0
- egogym/embodiments/grippers/__init__.py +4 -0
- egogym/embodiments/grippers/floating_gripper.py +58 -0
- egogym/embodiments/grippers/rum.py +6 -0
- egogym/embodiments/robot.py +95 -0
- egogym/evaluate.py +216 -0
- egogym/managers/__init__.py +2 -0
- egogym/managers/objects_managers.py +30 -0
- egogym/managers/textures_manager.py +21 -0
- egogym/misc/molmo_client.py +49 -0
- egogym/misc/molmo_server.py +197 -0
- egogym/policies/__init__.py +1 -0
- egogym/policies/base_policy.py +13 -0
- egogym/scripts/analayze.py +834 -0
- egogym/scripts/plot.py +87 -0
- egogym/scripts/plot_correlation.py +392 -0
- egogym/scripts/plot_correlation_hardcoded.py +338 -0
- egogym/scripts/plot_failure.py +248 -0
- egogym/scripts/plot_failure_hardcoded.py +195 -0
- egogym/scripts/plot_failure_vlm.py +257 -0
- egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
- egogym/scripts/plot_line.py +303 -0
- egogym/scripts/plot_line_hardcoded.py +285 -0
- egogym/scripts/plot_pi0_bars.py +169 -0
- egogym/tasks/close.py +84 -0
- egogym/tasks/open.py +85 -0
- egogym/tasks/pick.py +121 -0
- egogym/utils.py +969 -0
- egogym/wrappers/__init__.py +20 -0
- egogym/wrappers/episode_monitor.py +282 -0
- egogym/wrappers/unprivileged_chatgpt.py +163 -0
- egogym/wrappers/unprivileged_gemini.py +157 -0
- egogym/wrappers/unprivileged_molmo.py +88 -0
- egogym/wrappers/unprivileged_moondream.py +121 -0
- egogym-0.1.0.dist-info/METADATA +52 -0
- egogym-0.1.0.dist-info/RECORD +83 -0
- egogym-0.1.0.dist-info/WHEEL +5 -0
- egogym-0.1.0.dist-info/top_level.txt +2 -0
baselines/rum/utils/trajectory_vis.py
ADDED
@@ -0,0 +1,128 @@
+import collections
+
+import einops
+import matplotlib.gridspec as gridspec
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+
+def generate_plots(ground_truths, predictions, sampled_images, to_plot=8, traj_index=0):
+    """
+    Generates plots comparing ground truth vs. predictions and displays sampled images.
+
+    Parameters:
+    - ground_truths: A numpy array of shape (T, 7) containing ground truth actions.
+    - predictions: A numpy array of shape (T, 7) with the model's predictions.
+    - sampled_images: A numpy array containing M sampled images.
+    - to_plot: Number of images to show in the top strip.
+    - traj_index: Index used to name the saved figure.
+    """
+
+    # Initialize the figure: one row for images, one row per action dimension.
+    fig = plt.figure(figsize=(15, 15))
+    outer_grid = gridspec.GridSpec(ground_truths.shape[1] + 1, 1, hspace=0.25, wspace=0.0)
+
+    # Top grid for images: pick `to_plot` frames evenly spaced over the trajectory.
+    M = sampled_images.shape[0]
+    chosen_indices = np.linspace(0, M - 1, to_plot, dtype=int)
+    chosen_images = sampled_images[chosen_indices]
+    top_grid = gridspec.GridSpecFromSubplotSpec(
+        1, len(chosen_images), subplot_spec=outer_grid[0, :], wspace=0.0
+    )
+
+    for i, img in enumerate(chosen_images):
+        ax = plt.Subplot(fig, top_grid[i])
+        ax.imshow(img)
+        ax.axis("off")
+        if i > 0:  # Further remove any potential whitespace.
+            ax.set_yticklabels([])
+            ax.set_xticklabels([])
+            ax.set_xticks([])
+            ax.set_yticks([])
+        fig.add_subplot(ax)
+
+    # For each action dimension, plot ground truth vs. prediction.
+    dim = ground_truths.shape[1]
+    for i in range(dim):
+        ax = plt.Subplot(fig, outer_grid[i + 1, :])
+
+        gt_values = ground_truths[:, i]
+        pred_values = predictions[:, i]
+
+        ax.plot(gt_values, "g", label="Ground Truth")
+        ax.plot(pred_values, "r--", label="Prediction")
+
+        ax.legend(loc="upper right")
+        ax.set_title(f"Act Dim {i + 1}", fontsize=16)
+        fig.add_subplot(ax)
+
+    # Save the plot.
+    fig.savefig(f"trajectory_{traj_index}.png")
+
+
+@torch.no_grad()
+def visualize_trajectory(
+    model,
+    test_dataset,
+    device,
+    buffer_size=6,
+    n_visualized_trajectories=5,
+    goal_conditional=False,
+):
+    """
+    Visualizes the model's behavior on the test dataset: steps through up to
+    n_visualized_trajectories episodes, feeding the model a rolling buffer of
+    the last buffer_size frames, and saves one ground-truth-vs-prediction
+    plot per episode via generate_plots.
+    """
+    action_preds = []
+    ground_truth = []
+    images = []
+    image_buffer = collections.deque(maxlen=buffer_size)
+    test_dataset.set_include_trajectory_end(True)
+
+    i = 0
+    done_visualizing = 0
+    while (done_visualizing < n_visualized_trajectories) and (i < len(test_dataset)):
+        if goal_conditional:
+            (input_images, terminate), goals, *_, gt_actions = test_dataset[i]
+        else:
+            (input_images, terminate), *_, gt_actions = test_dataset[i]
+        input_images = input_images.float() / 255.0
+        image_buffer.append(input_images[-1])
+        img = input_images[-1]
+        images.append(einops.rearrange(img, "c h w -> h w c").cpu().detach().numpy())
+        ground_truth.append(gt_actions[-1])
+        if goal_conditional:
+            model_input = (
+                torch.stack(tuple(image_buffer), dim=0).unsqueeze(0).to(device),
+                torch.tensor(goals).unsqueeze(0).to(device),
+                torch.tensor(gt_actions).unsqueeze(0).to(device),
+            )
+        else:
+            model_input = (
+                torch.stack(tuple(image_buffer), dim=0).unsqueeze(0).to(device),
+                torch.tensor(gt_actions).unsqueeze(0).to(device),
+            )
+
+        out, _ = model.step(model_input)
+        action_preds.append(out.squeeze().cpu().detach().numpy())
+
+        if terminate:
+            action_preds = np.array(action_preds)
+            ground_truth = np.array(ground_truth)
+            images = np.array(images)
+
+            print(action_preds.shape, ground_truth.shape, images.shape)
+
+            generate_plots(
+                ground_truth, action_preds, images, traj_index=done_visualizing
+            )
+            done_visualizing += 1
+            # Reset everything for the next episode.
+            action_preds = []
+            ground_truth = []
+            images = []
+            image_buffer = collections.deque(maxlen=buffer_size)
+
+        i += 1
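
A quick smoke test for generate_plots above, using synthetic data (array shapes and values here are illustrative, not from the package):

import numpy as np

# Fake 7-DoF trajectory of length T=40 plus dummy frames, just to exercise
# generate_plots end to end.
T = 40
ground_truths = np.cumsum(np.random.randn(T, 7) * 0.01, axis=0)
predictions = ground_truths + np.random.randn(T, 7) * 0.02
sampled_images = np.random.randint(0, 255, size=(T, 64, 64, 3), dtype=np.uint8)

generate_plots(ground_truths, predictions, sampled_images, traj_index=0)
# Writes trajectory_0.png: one image strip plus seven GT-vs-prediction rows.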
baselines/rum/utils/zmq_utils.py
ADDED
@@ -0,0 +1,281 @@
+import zmq
+import cv2
+import numpy as np
+import pickle
+import blosc as bl
+import threading
+import time
+from abc import ABC
+
+
+# ZMQ Sockets
+def create_push_socket(host, port):
+    context = zmq.Context()
+    socket = context.socket(zmq.PUSH)
+    socket.bind("tcp://{}:{}".format(host, port))
+    return socket
+
+
+def create_pull_socket(host, port):
+    context = zmq.Context()
+    socket = context.socket(zmq.PULL)
+    socket.setsockopt(zmq.CONFLATE, 1)
+    socket.bind("tcp://{}:{}".format(host, port))
+    return socket
+
+
+def create_response_socket(host, port):
+    context = zmq.Context()
+    socket = context.socket(zmq.REP)
+    socket.bind("tcp://{}:{}".format(host, port))
+    return socket
+
+
+def create_request_socket(host, port):
+    context = zmq.Context()
+    socket = context.socket(zmq.REQ)
+    socket.connect("tcp://{}:{}".format(host, port))
+    return socket
+
+
+# Pub/Sub classes for Keypoints
+class ZMQKeypointPublisher:
+    def __init__(self, host, port):
+        self._host, self._port = host, port
+        self._init_publisher()
+
+    def _init_publisher(self):
+        self.context = zmq.Context()
+        self.socket = self.context.socket(zmq.PUB)
+        self.socket.bind("tcp://{}:{}".format(self._host, self._port))
+
+    def pub_keypoints(self, keypoint_array, topic_name):
+        """
+        Pickle the keypoint array and publish it under the given topic.
+        """
+        buffer = pickle.dumps(keypoint_array, protocol=-1)
+        self.socket.send(bytes("{} ".format(topic_name), "utf-8") + buffer)
+
+    def stop(self):
+        print("Closing the publisher socket in {}:{}.".format(self._host, self._port))
+        self.socket.close()
+        self.context.term()
+
+
+# Keypoint Subscriber
+class ZMQKeypointSubscriber(threading.Thread):
+    def __init__(self, host, port, topic):
+        threading.Thread.__init__(self)
+        self._host, self._port, self._topic = host, port, topic
+        self._init_subscriber()
+
+        # Topic prefix to strip from each received message.
+        self.strip_value = bytes("{} ".format(self._topic), "utf-8")
+
+    def _init_subscriber(self):
+        self.context = zmq.Context()
+        self.socket = self.context.socket(zmq.SUB)
+        self.socket.setsockopt(zmq.CONFLATE, 1)
+        self.socket.connect("tcp://{}:{}".format(self._host, self._port))
+        self.socket.setsockopt(zmq.SUBSCRIBE, bytes(self._topic, "utf-8"))
+
+    def recv_keypoints(self, flags=None):
+        if flags is None:
+            raw_data = self.socket.recv()
+            # Slice off the exact topic prefix (bytes.lstrip would strip a
+            # character *set*, not a prefix).
+            raw_array = raw_data[len(self.strip_value):]
+            return pickle.loads(raw_array)
+        else:  # Non-blocking receive, e.g. with flags=zmq.NOBLOCK
+            try:
+                raw_data = self.socket.recv(flags)
+                raw_array = raw_data[len(self.strip_value):]
+                return pickle.loads(raw_array)
+            except zmq.Again:
+                return None
+
+    def stop(self):
+        print("Closing the subscriber socket in {}:{}.".format(self._host, self._port))
+        self.socket.close()
+        self.context.term()
+
+
+# Pub/Sub classes for storing data from Realsense Cameras
+class ZMQCameraPublisher:
+    def __init__(self, host, port):
+        self._host, self._port = host, port
+        self._init_publisher()
+
+    def _init_publisher(self):
+        self.context = zmq.Context()
+        self.socket = self.context.socket(zmq.PUB)
+        print("tcp://{}:{}".format(self._host, self._port))
+        self.socket.bind("tcp://{}:{}".format(self._host, self._port))
+
+    def pub_intrinsics(self, array):
+        self.socket.send(b"intrinsics " + pickle.dumps(array, protocol=-1))
+
+    def pub_rgb_image(self, rgb_image, timestamp):
+        _, buffer = cv2.imencode(".jpg", rgb_image, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
+        data = dict(timestamp=timestamp, rgb_image=buffer.tobytes())
+        self.socket.send(b"rgb_image " + pickle.dumps(data, protocol=-1))
+
+    def pub_depth_image(self, depth_image, timestamp):
+        # blosc + zstd keeps the raw depth frames small on the wire.
+        compressed_depth = bl.pack_array(
+            depth_image, cname="zstd", clevel=1, shuffle=bl.NOSHUFFLE
+        )
+        data = dict(timestamp=timestamp, depth_image=compressed_depth)
+        self.socket.send(b"depth_image " + pickle.dumps(data, protocol=-1))
+
+    def pub_image_and_depth(self, rgb_image, depth_image, timestamp):
+        _, buffer = cv2.imencode(".jpg", rgb_image, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
+        compressed_depth = bl.pack_array(
+            depth_image, cname="zstd", clevel=1, shuffle=bl.NOSHUFFLE
+        )
+        data = dict(
+            timestamp=timestamp,
+            rgb_image=buffer.tobytes(),
+            depth_image=compressed_depth,
+        )
+        self.socket.send(b"image_and_depth " + pickle.dumps(data, protocol=-1))
+
+    def stop(self):
+        print("Closing the publisher socket in {}:{}.".format(self._host, self._port))
+        self.socket.close()
+        self.context.term()
+
+
+class ZMQCameraSubscriber:
+    def __init__(self, host, port, topic_type):
+        self._host, self._port, self._topic_type = host, port, topic_type
+        self._init_subscriber()
+
+    def _init_subscriber(self):
+        self.context = zmq.Context()
+        self.socket = self.context.socket(zmq.SUB)
+        self.socket.setsockopt(zmq.CONFLATE, 1)
+        self.socket.connect("tcp://{}:{}".format(self._host, self._port))
+
+        if self._topic_type == "Intrinsics":
+            self.socket.setsockopt(zmq.SUBSCRIBE, b"intrinsics")
+        elif self._topic_type == "RGB":
+            self.socket.setsockopt(zmq.SUBSCRIBE, b"rgb_image")
+        elif self._topic_type == "Depth":
+            self.socket.setsockopt(zmq.SUBSCRIBE, b"depth_image")
+        elif self._topic_type == "RGBD":
+            self.socket.setsockopt(zmq.SUBSCRIBE, b"image_and_depth")
+
+    def recv_intrinsics(self):
+        raw_data = self.socket.recv()
+        raw_array = raw_data[len(b"intrinsics "):]
+        return pickle.loads(raw_array)
+
+    def recv_rgb_image(self):
+        raw_data = self.socket.recv()
+        data = raw_data[len(b"rgb_image "):]
+        data = pickle.loads(data)
+        encoded_data = np.frombuffer(data["rgb_image"], np.uint8)
+        return cv2.imdecode(encoded_data, 1), data["timestamp"]
+
+    def recv_depth_image(self):
+        raw_data = self.socket.recv()
+        stripped_data = raw_data[len(b"depth_image "):]
+        data = pickle.loads(stripped_data)
+        depth_image = bl.unpack_array(data["depth_image"])
+        return np.array(depth_image, dtype=np.float32), data["timestamp"]
+
+    def recv_image_and_depth(self):
+        raw_data = self.socket.recv()
+        stripped_data = raw_data[len(b"image_and_depth "):]
+        data = pickle.loads(stripped_data)
+        encoded_data = np.frombuffer(data["rgb_image"], np.uint8)
+        rgb_image = cv2.imdecode(encoded_data, 1)
+        depth_image = bl.unpack_array(data["depth_image"])
+        return rgb_image, np.array(depth_image, dtype=np.float32), data["timestamp"]
+
+    def stop(self):
+        print("Closing the subscriber socket in {}:{}.".format(self._host, self._port))
+        self.socket.close()
+        self.context.term()
+
+
+# Publisher for image visualizers
+class ZMQCompressedImageTransmitter(object):
+    def __init__(self, host, port):
+        self._host, self._port = host, port
+        # self._init_push_socket()
+        self._init_publisher()
+
+    def _init_publisher(self):
+        self.context = zmq.Context()
+        self.socket = self.context.socket(zmq.PUB)
+        self.socket.bind("tcp://{}:{}".format(self._host, self._port))
+
+    def _init_push_socket(self):
+        self.context = zmq.Context()
+        self.socket = self.context.socket(zmq.PUSH)
+        self.socket.bind("tcp://{}:{}".format(self._host, self._port))
+
+    def send_image(self, rgb_image):
+        _, buffer = cv2.imencode(".jpg", rgb_image, [int(cv2.IMWRITE_JPEG_QUALITY), 50])
+        self.socket.send(buffer.tobytes())
+
+    def stop(self):
+        print("Closing the publisher in {}:{}.".format(self._host, self._port))
+        self.socket.close()
+        self.context.term()
+
+
+class ZMQCompressedImageReciever(threading.Thread):
+    def __init__(self, host, port):
+        threading.Thread.__init__(self)
+        self._host, self._port = host, port
+        # self._init_pull_socket()
+        self._init_subscriber()
+
+    def _init_subscriber(self):
+        self.context = zmq.Context()
+        self.socket = self.context.socket(zmq.SUB)
+        self.socket.setsockopt(zmq.CONFLATE, 1)
+        self.socket.connect("tcp://{}:{}".format(self._host, self._port))
+        self.socket.subscribe("")
+
+    def _init_pull_socket(self):
+        self.context = zmq.Context()
+        self.socket = self.context.socket(zmq.PULL)
+        self.socket.setsockopt(zmq.CONFLATE, 1)
+        self.socket.connect("tcp://{}:{}".format(self._host, self._port))
+
+    def recv_image(self):
+        raw_data = self.socket.recv()
+        encoded_data = np.frombuffer(raw_data, np.uint8)
+        decoded_frame = cv2.imdecode(encoded_data, 1)
+        return decoded_frame
+
+    def stop(self):
+        print("Closing the subscriber socket in {}:{}.".format(self._host, self._port))
+        self.socket.close()
+        self.context.term()
+
+
+class FrequencyTimer:
+    FREQ_1KHZ = 1e3
+
+    def __init__(self, frequency_rate):
+        # Time budget per loop iteration, in nanoseconds.
+        self.time_available = 1e9 / frequency_rate
+
+    def start_loop(self):
+        self.start_time = time.time_ns()
+
+    def end_loop(self):
+        wait_time = self.time_available + self.start_time
+
+        # Sleep in ~1 ms increments until the iteration budget is used up.
+        while time.time_ns() < wait_time:
+            time.sleep(1 / FrequencyTimer.FREQ_1KHZ)
+
+
+class ProcessInstantiator(ABC):
+    def __init__(self):
+        self.processes = []
+
+    def _start_component(self, configs):
+        raise NotImplementedError("Function not implemented!")
+
+    def get_processes(self):
+        return self.processes
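
A minimal round-trip sketch for the camera pub/sub pair above (host, port, and frame contents are placeholder assumptions; ZMQ PUB/SUB gives no delivery guarantee, so the sleep gives the subscriber time to join before the single publish):

import time
import numpy as np

pub = ZMQCameraPublisher(host="127.0.0.1", port=5555)
sub = ZMQCameraSubscriber(host="127.0.0.1", port=5555, topic_type="RGB")
time.sleep(0.5)  # SUB sockets join slowly; without this the frame may be lost

frame = np.zeros((480, 640, 3), dtype=np.uint8)
pub.pub_rgb_image(frame, timestamp=time.time())
image, ts = sub.recv_rgb_image()  # blocks until a frame arrives

pub.stop()
sub.stop()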
baselines/rum_policy.py
ADDED
@@ -0,0 +1,108 @@
+import cv2
+import torch
+import torch.nn.functional as F
+import numpy as np
+from scipy.spatial.transform import Rotation as R
+from egogym.policies.base_policy import BasePolicy
+from baselines.rum.policy import Policy
+
+
+class RUMPolicy(BasePolicy):
+
+    def __init__(self, config=None, policy=None, **kwargs):
+        super().__init__()
+
+        if config is None:
+            config = get_config()
+
+        for key, value in kwargs.items():
+            config[key] = value
+
+        self.config = config
+        self.name = "rum"
+        # config["device"] defaults to None, so fall back explicitly; dict.get's
+        # second argument is only used for *missing* keys, not None values.
+        self.device = config.get("device") or ("cuda" if torch.cuda.is_available() else "cpu")
+        if policy is not None:
+            self.policy = policy
+        else:
+            self.policy = Policy(model_path=config["model_path"], device=self.device)
+        self.gripper_threshold = config["gripper_threshold"]
+        self.grasping_style = config["grasping_style"]
+        self.num_envs = None
+        self.grasped = None
+        # Signed axis permutation between the simulator camera frame and the
+        # frame the RUM model expects.
+        self.desired_axis_t = torch.tensor([
+            [-1,  0,  0],
+            [ 0,  0, -1],
+            [ 0, -1,  0]
+        ], dtype=torch.float32, device=self.device)
+        self.grasp_override_t = torch.tensor([0.00, 0.18, 0.04], dtype=torch.float32, device=self.device)
+
+    def get_action(self, obs):
+        # Lazily infer the batch size from the first observation.
+        if self.num_envs is None:
+            if len(obs["rgb_ego"].shape) == 4:
+                self.num_envs = obs["rgb_ego"].shape[0]
+            else:
+                self.num_envs = 1
+            self.grasped = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device)
+
+        if self.num_envs == 1:
+            obs = {key: obs[key][np.newaxis, ...] if isinstance(obs[key], np.ndarray) else obs[key] for key in obs}
+
+        T_world_camera = torch.from_numpy(obs["camera_pose"]).view(self.num_envs, 4, 4).to(self.device)
+        if "object_pose" in obs:
+            T_world_object = torch.from_numpy(obs["object_pose"]).view(self.num_envs, 4, 4).to(self.device)
+        else:
+            T_world_object = torch.from_numpy(obs["handle_pose"]).view(self.num_envs, 4, 4).to(self.device)
+
+        # Express the object position in the camera frame, then rotate it into
+        # the model's frame convention.
+        R_world_camera = T_world_camera[:, :3, :3]
+        t_world_camera = T_world_camera[:, :3, 3]
+        p_world = T_world_object[:, :3, 3]
+        R_camera_world = R_world_camera.transpose(1, 2)
+        p_camera = torch.bmm(R_camera_world, (p_world - t_world_camera).unsqueeze(-1)).squeeze(-1)
+        object_3d_position = p_camera @ self.desired_axis_t.T
+        # Once grasped, pin the conditioning point to a fixed in-hand position.
+        object_3d_position[self.grasped] = self.grasp_override_t
+
+        rgb_tensor = torch.from_numpy(obs["rgb_ego"]).permute(0, 3, 1, 2).to(self.device).float()
+        rgb_resized = F.interpolate(rgb_tensor, size=(224, 224), mode="bilinear", align_corners=False)
+        rgb_ego = rgb_resized.permute(0, 2, 3, 1).byte().cpu().numpy()
+
+        rum_obs = {
+            "rgb_ego": rgb_ego,
+            "object_3d_position": object_3d_position.cpu().numpy(),
+        }
+        action_tensors = self.policy.infer(rum_obs)
+
+        action_tensors_t = torch.from_numpy(action_tensors).to(self.device)
+        R_action = R.from_euler("xyz", action_tensors[:, 3:6]).as_matrix()
+        R_action_t = torch.from_numpy(R_action).to(self.device, dtype=torch.float32)
+
+        action_pose = torch.eye(4, device=self.device).unsqueeze(0).repeat(self.num_envs, 1, 1)
+        action_pose[:, :3, :3] = R_action_t
+        action_pose[:, :3, 3] = action_tensors_t[:, :3]
+        # Conjugate the pose by the axis permutation to map it back into the
+        # simulator frame.
+        action_pose[:, :3, :3] = self.desired_axis_t.T @ action_pose[:, :3, :3] @ self.desired_axis_t
+        action_pose[:, :3, 3] = action_pose[:, :3, 3] @ self.desired_axis_t
+
+        # Latch the grasp (elementwise logical OR) once the predicted gripper
+        # value crosses the threshold.
+        self.grasped = torch.maximum(self.grasped, action_tensors_t[:, 6] < self.gripper_threshold)
+        grasp_action = self.grasped.unsqueeze(1).float()
+        if self.grasping_style == "continuous":
+            grasp_action = 1 - action_tensors_t[:, 6].unsqueeze(1)
+
+        actions = torch.cat([action_pose.view(self.num_envs, -1), grasp_action], dim=1)
+
+        return actions.cpu().numpy()
+
+    def reset(self, indicies=None):
+        self.policy.reset(indicies)
+        if indicies is not None and self.num_envs is not None:
+            self.grasped[indicies] = False
+        elif self.num_envs is not None:
+            self.grasped = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device)
+
+
+def get_config():
+    config = {
+        "model_path": "checkpoints/rum_pick.pt",
+        "gripper_threshold": 0.7,
+        "grasping_style": "binary",
+        "device": None,
+    }
+    return config
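
The two desired_axis_t products in get_action conjugate the predicted pose from the model's axis convention back into the simulator camera frame. A standalone sketch of that change of basis (values are hypothetical; only the matrix matches the code above):

import torch

# Signed axis permutation: camera frame -> model frame. It is orthogonal,
# so its inverse is its transpose.
P = torch.tensor([[-1., 0., 0.],
                  [ 0., 0., -1.],
                  [ 0., -1., 0.]])

p_camera = torch.tensor([0.1, 0.2, 0.3])
p_model = P @ p_camera            # matches `p_camera @ desired_axis_t.T`
assert torch.allclose(P.T @ p_model, p_camera)

# A rotation expressed in model coordinates acts in camera coordinates as
# P^T R P -- the conjugation applied to action_pose in get_action.
R_model = torch.eye(3)
R_camera = P.T @ R_model @ P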
egogym/__init__.py
ADDED
@@ -0,0 +1,8 @@
+import gymnasium as gym
+from .tasks.pick import PickTask
+from .tasks.open import OpenTask
+from .tasks.close import CloseTask
+
+gym.register(id="Egogym-Pick-v0", entry_point=PickTask)
+gym.register(id="Egogym-Open-v0", entry_point=OpenTask)
+gym.register(id="Egogym-Close-v0", entry_point=CloseTask)
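
Since the tasks are registered on import, environments can be created by id; a minimal sketch (whether the task constructors need extra kwargs is an assumption to verify against each task class):

import gymnasium as gym
import egogym  # noqa: F401 -- importing runs the gym.register calls above

env = gym.make("Egogym-Pick-v0")
obs, info = env.reset()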