egogym 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- baselines/pi_policy.py +110 -0
- baselines/rum/__init__.py +1 -0
- baselines/rum/loss_fns/__init__.py +37 -0
- baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
- baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
- baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
- baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
- baselines/rum/models/__init__.py +1 -0
- baselines/rum/models/bet/__init__.py +3 -0
- baselines/rum/models/bet/bet.py +347 -0
- baselines/rum/models/bet/gpt.py +277 -0
- baselines/rum/models/bet/tokenized_bet.py +454 -0
- baselines/rum/models/bet/utils.py +124 -0
- baselines/rum/models/bet/vqbet.py +410 -0
- baselines/rum/models/bet/vqvae/__init__.py +3 -0
- baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
- baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
- baselines/rum/models/bet/vqvae/vqvae.py +313 -0
- baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
- baselines/rum/models/custom.py +33 -0
- baselines/rum/models/encoders/__init__.py +0 -0
- baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
- baselines/rum/models/encoders/identity.py +45 -0
- baselines/rum/models/encoders/timm_encoders.py +82 -0
- baselines/rum/models/policies/diffusion_policy.py +881 -0
- baselines/rum/models/policies/open_loop.py +122 -0
- baselines/rum/models/policies/simple_open_loop.py +108 -0
- baselines/rum/molmo/server.py +144 -0
- baselines/rum/policy.py +293 -0
- baselines/rum/utils/__init__.py +212 -0
- baselines/rum/utils/action_transforms.py +22 -0
- baselines/rum/utils/decord_transforms.py +135 -0
- baselines/rum/utils/rpc.py +249 -0
- baselines/rum/utils/schedulers.py +71 -0
- baselines/rum/utils/trajectory_vis.py +128 -0
- baselines/rum/utils/zmq_utils.py +281 -0
- baselines/rum_policy.py +108 -0
- egogym/__init__.py +8 -0
- egogym/assets/constants.py +1804 -0
- egogym/components/__init__.py +1 -0
- egogym/components/object.py +94 -0
- egogym/egogym.py +106 -0
- egogym/embodiments/__init__.py +10 -0
- egogym/embodiments/arms/__init__.py +4 -0
- egogym/embodiments/arms/arm.py +65 -0
- egogym/embodiments/arms/droid.py +49 -0
- egogym/embodiments/grippers/__init__.py +4 -0
- egogym/embodiments/grippers/floating_gripper.py +58 -0
- egogym/embodiments/grippers/rum.py +6 -0
- egogym/embodiments/robot.py +95 -0
- egogym/evaluate.py +216 -0
- egogym/managers/__init__.py +2 -0
- egogym/managers/objects_managers.py +30 -0
- egogym/managers/textures_manager.py +21 -0
- egogym/misc/molmo_client.py +49 -0
- egogym/misc/molmo_server.py +197 -0
- egogym/policies/__init__.py +1 -0
- egogym/policies/base_policy.py +13 -0
- egogym/scripts/analayze.py +834 -0
- egogym/scripts/plot.py +87 -0
- egogym/scripts/plot_correlation.py +392 -0
- egogym/scripts/plot_correlation_hardcoded.py +338 -0
- egogym/scripts/plot_failure.py +248 -0
- egogym/scripts/plot_failure_hardcoded.py +195 -0
- egogym/scripts/plot_failure_vlm.py +257 -0
- egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
- egogym/scripts/plot_line.py +303 -0
- egogym/scripts/plot_line_hardcoded.py +285 -0
- egogym/scripts/plot_pi0_bars.py +169 -0
- egogym/tasks/close.py +84 -0
- egogym/tasks/open.py +85 -0
- egogym/tasks/pick.py +121 -0
- egogym/utils.py +969 -0
- egogym/wrappers/__init__.py +20 -0
- egogym/wrappers/episode_monitor.py +282 -0
- egogym/wrappers/unprivileged_chatgpt.py +163 -0
- egogym/wrappers/unprivileged_gemini.py +157 -0
- egogym/wrappers/unprivileged_molmo.py +88 -0
- egogym/wrappers/unprivileged_moondream.py +121 -0
- egogym-0.1.0.dist-info/METADATA +52 -0
- egogym-0.1.0.dist-info/RECORD +83 -0
- egogym-0.1.0.dist-info/WHEEL +5 -0
- egogym-0.1.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torchvision.transforms as T
|
|
4
|
+
from tqdm import tqdm
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class OpenLoopReplay(nn.Module):
    """Open-loop replay of dataset actions, optionally seeded by a nearest neighbour.

    With ``use_vinn=False`` each call to :meth:`step` returns the next action of
    ``train_dataset``, starting at index 0.  With ``use_vinn=True`` the first call
    embeds the observation with ``encoder``, picks the closest stored training
    frame, and returns a zero action whose last element is 1.  Later calls replay
    the dataset actions that follow that frame.  Once the replay index runs past
    the end of the dataset, a zero action with last element 1 is returned.

    Args:
        encoder: module producing embeddings.  It is called on ``(image, label)``
            in :meth:`set_dataset` and on the raw observation in :meth:`__call__`.
        k: default number of neighbours used by :meth:`__call__`.
        enc_weight_pth: optional checkpoint path whose ``"model"`` entry is
            loaded into ``encoder``.
        use_vinn: enable nearest-neighbour selection of the replay start.
        cfg: stored on the instance unchanged.
    """

    def __init__(self, encoder, k=5, enc_weight_pth=None, use_vinn=False, cfg=None):
        super().__init__()
        self.encoder = encoder
        self.cfg = cfg
        self.use_vinn = use_vinn

        self.k = k

        if enc_weight_pth is not None:
            self.encoder.load_state_dict(
                torch.load(enc_weight_pth, map_location="cpu")["model"]
            )

        # Populated by set_dataset() when use_vinn is enabled.
        self.representations = None
        self.actions = None
        self.imgs = None
        softmax = nn.Softmax(dim=1)
        # Smaller distance -> larger weight; the weights over the k neighbours sum to 1.
        self.dist_scale_func = lambda x: (softmax(-x))
        self.encoder.eval()
        self.device = "cpu"
        self.encoder.to(self.device)
        self.img_transform = T.Resize((256, 256), antialias=True)

        self.open_loop = False
        self.idx = 0
        self.neighbor_1_idx = 0

    def to(self, device):
        """Move the encoder and this module to ``device`` and use it for queries."""
        self.device = device
        self.encoder.to(device)

        return super().to(device)

    def set_dataset(self, dataloader):
        """Keep ``dataloader.dataset`` for replay and, with VINN, cache its embeddings.

        For every batch it stores the encoder output, the squeezed labels (used as
        actions) and the images as NumPy arrays.  Repeated calls append to the cache.

        NOTE(review): neighbour indices from :meth:`__call__` are later used to index
        ``train_dataset``.  This presumably requires the dataloader to iterate the
        dataset in order (no shuffling) — TODO confirm.
        """
        self.train_dataset = dataloader.dataset
        if not self.use_vinn:
            return
        for image, label in tqdm(dataloader):
            image = image.float() / 255.0
            image = image.to(self.device)
            label = torch.Tensor(label).to("cpu").detach().squeeze()

            representation = (
                self.encoder((image, label)).to("cpu").detach().squeeze(dim=1)
            )
            frames = list(image.to("cpu").detach().numpy())
            if self.representations is None:
                self.representations = representation
                self.actions = label
                self.imgs = frames
            else:
                self.representations = torch.cat(
                    (self.representations, representation), 0
                )
                self.actions = torch.cat((self.actions, label), 0)
                self.imgs.extend(frames)

    @staticmethod
    def _hold_action():
        """Zero action with the last element set to 1 (used at start and past the end)."""
        action_tensor = torch.zeros(7)
        action_tensor[-1] = 1
        return action_tensor

    def step(self, img, **kwargs):
        """Return ``(action_tensor, logs)`` for the next replay step.

        Args:
            img: observation passed to :meth:`__call__` on the first VINN step.
            **kwargs: accepted for interface compatibility and ignored.
        """
        logs = {}
        if self.use_vinn and not self.open_loop:
            # First step: find the nearest training frame and start replaying from it.
            self.encoder.eval()
            with torch.no_grad():
                _, indices = self(img, return_indices=True)
            self.neighbor_1_idx = int(indices[0][0])
            self.open_loop = True
            return self._hold_action(), logs

        offset = self.neighbor_1_idx if self.use_vinn else 0
        replay_idx = offset + self.idx
        if replay_idx >= len(self.train_dataset):
            # Past the end of the recording: hold instead of raising IndexError
            # (same fallback as SimpleReplay.step).
            return self._hold_action(), logs

        _, action = self.train_dataset[replay_idx]
        self.idx += 1
        return torch.tensor(action).squeeze(), logs

    def __call__(self, batch_images, k=None, return_indices=False):
        """Predict actions as the distance-weighted mean of the nearest stored actions.

        Args:
            batch_images: encoder input for a batch of observations.
            k: number of neighbours; defaults to ``self.k``.  It is clamped to the
                number of stored frames.
            return_indices: also return the ``(batch, k)`` neighbour indices.

        Raises:
            RuntimeError: if no representations were cached by :meth:`set_dataset`.
        """
        if self.representations is None:
            raise RuntimeError(
                "No cached representations; call set_dataset() with use_vinn=True first."
            )
        if k is None:
            k = self.k
        k = min(k, self.representations.shape[0])

        batch_rep = self.encoder(batch_images).squeeze(dim=1).detach().to(self.device)
        dat_rep = self.representations.to(self.device)
        all_distances = torch.cdist(batch_rep, dat_rep).to("cpu")

        top_k_distances, indices = torch.topk(all_distances, k, dim=1, largest=False)
        top_k_actions = self.actions[indices].to(self.device)

        weights = self.dist_scale_func(top_k_distances).to(self.device)

        pred = torch.sum(top_k_actions * weights.unsqueeze(-1), dim=1)  # weighted average

        if return_indices:
            return pred, indices

        return pred

    def reset(self):
        """Restart the replay (the next VINN step searches for a new neighbour)."""
        self.open_loop = False
        self.idx = 0
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import numpy as np
|
|
3
|
+
from scipy.spatial.transform import Rotation as R
|
|
4
|
+
from quaternion import (
|
|
5
|
+
as_rotation_matrix,
|
|
6
|
+
quaternion,
|
|
7
|
+
)
|
|
8
|
+
import torch
|
|
9
|
+
# Signed axis permutation in homogeneous form: new x = -x, new y = -z, new z = -y.
P = np.array(
    [
        [-1, 0, 0, 0],
        [0, 0, -1, 0],
        [0, -1, 0, 0],
        [0, 0, 0, 1],
    ]
)


def apply_permutation_transform(matrix):
    """Express ``matrix`` in the permuted frame, i.e. return ``P @ matrix @ P.T``."""
    return np.matmul(np.matmul(P, matrix), P.T)
|
|
13
|
+
|
|
14
|
+
class SimpleReplay:
    """Replay a recorded pose trajectory as relative 7-D actions.

    Each non-blank line of ``poses_file_path`` is evaluated as a Python sequence.
    Its first element is a string containing an integer timestamp between ``<``
    and ``>``; the remaining seven elements are ``qx, qy, qz, qw, tx, ty, tz``.
    Poses are made relative to the first pose and re-based with
    :func:`apply_permutation_transform`.  ``step(step_no)`` returns the relative
    motion between frames ``timeskip`` apart.
    """

    def __init__(self, poses_file_path, timeskip=8):
        self.poses_file_path = poses_file_path
        self.timeskip = timeskip
        self.idx = 0
        self.transforms = None

        self.process_poses()

    def to(self, device):
        """No-op: the replay keeps its data in NumPy, so there is nothing to move."""
        pass

    def eval(self):
        """No-op: the replay has no train/eval modes."""
        pass

    def get_poses(self):
        """Parse the pose file.

        Returns:
            ``(poses, timestamps)``: an ``(N, 7)`` array of the per-line values
            after the first element, and an ``(N,)`` integer array of timestamps.
            Blank lines are skipped.
        """
        poses, timestamps = [], []
        with open(self.poses_file_path, "r") as f:
            for line in f:
                if not line.strip():
                    continue
                # SECURITY NOTE(review): eval() runs arbitrary code from the pose file;
                # only replay trusted recordings.  ast.literal_eval would suffice if every
                # line is a plain literal — confirm before switching.
                line_list = eval(line)
                timestamps.append(int(line_list[0].split("<")[1].split(">")[0]))
                poses.append(np.array(line_list[1:]))
        return np.array(poses), np.array(timestamps)

    def process_poses(self):
        """Fill ``self.transforms`` with rows ``[tx, ty, tz, qx, qy, qz, qw]``.

        Each row is the pose relative to the first one, after the permutation
        transform.  The quaternion is in SciPy's scalar-last order.  An empty pose
        file yields an empty ``(0, 7)`` array.
        """
        poses, _ = self.get_poses()
        translations, quaternions = [], []
        init_inv = None

        for qx, qy, qz, qw, tx, ty, tz in poses:
            ext_matrix = np.eye(4)
            ext_matrix[:3, :3] = as_rotation_matrix(quaternion(qw, qx, qy, qz))
            ext_matrix[:3, 3] = tx, ty, tz

            if init_inv is None:
                # Invert the reference pose once instead of on every frame.
                init_inv = np.linalg.inv(ext_matrix)
            relative_pose = apply_permutation_transform(init_inv @ ext_matrix)

            translations.append(relative_pose[:3, -1])
            quaternions.append(R.from_matrix(relative_pose[:3, :3]).as_quat())

        self.transforms = np.concatenate(
            [
                np.array(translations).reshape(-1, 3),
                np.array(quaternions).reshape(-1, 4),
            ],
            axis=1,
        )

    @staticmethod
    def _pose_to_matrix(transform):
        """Build a 4x4 homogeneous matrix from ``[tx, ty, tz, qx, qy, qz, qw]``."""
        matrix = np.eye(4)
        matrix[:3, :3] = R.from_quat(transform[3:]).as_matrix()
        matrix[:3, 3] = transform[:3]
        return matrix

    def get_action(self, idx):
        """Relative action from frame ``idx`` to frame ``idx + timeskip``.

        Returns:
            float32 array ``[dx, dy, dz, rx, ry, rz, gripper]``.  The translation
            and rotation vector are expressed in the frame of pose ``idx``, and
            ``gripper`` is always 1.0.
        """
        prior = self._pose_to_matrix(self.transforms[idx])
        following = self._pose_to_matrix(self.transforms[idx + self.timeskip])

        relative = np.matmul(np.linalg.inv(prior), following)
        relative_translation = relative[:3, 3]
        relative_rotation = R.from_matrix(relative[:3, :3]).as_rotvec()

        gripper = 1.0

        return np.concatenate(
            [relative_translation, relative_rotation, [gripper]], dtype=np.float32
        )

    def step(self, img, step_no):
        """Return ``(action_tensor, {})`` for replay step ``step_no``; ``img`` is ignored.

        Past the end of the recording, it returns a zero action whose gripper value is 1.
        """
        start_idx = self.timeskip * step_no
        if start_idx + self.timeskip >= len(self.transforms):
            print("INDEX OUT OF BOUNDS")
            action_tensor = torch.zeros(7)
            action_tensor[-1] = 1
            return action_tensor, {}

        return torch.tensor(self.get_action(start_idx)), {}

    def reset(self):
        """No-op: the replay position is derived from ``step_no`` on every call."""
        pass
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import asyncio
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import re
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
import websockets
|
|
9
|
+
from PIL import Image
|
|
10
|
+
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
|
|
11
|
+
|
|
12
|
+
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)


def extract_molmo_points(molmo_output: str):
    """Parse ``x="…" y="…"`` coordinate pairs out of Molmo's decoded text.

    Attribute names may carry an index (``x1``/``y1``, ``x2``/``y2``, …).  Each
    coordinate is divided by 100 (presumably Molmo's percent-of-image units —
    confirm), and any pair with a value above 100 is dropped.

    Args:
        molmo_output: raw text produced by the model.

    Returns:
        A list of ``float32`` arrays ``[x, y]``, in order of appearance.
    """
    coordinate_pairs = re.findall(
        r'x\d*="\s*([0-9]+(?:\.[0-9]+)?)"\s+y\d*="\s*([0-9]+(?:\.[0-9]+)?)"',
        molmo_output,
    )
    # The pattern admits only unsigned decimals, so float() always succeeds here.
    candidates = (
        np.asarray([float(x_text), float(y_text)], dtype=np.float32)
        for x_text, y_text in coordinate_pairs
    )
    return [xy / 100.0 for xy in candidates if xy.max() <= 100]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class Molmo:
    """Wrapper around a Hugging Face Molmo checkpoint for text + pointing queries."""

    def __init__(self, model_name="allenai/Molmo-7B-D-0924"):
        # Both the processor and the model are loaded with the same options.
        load_options = dict(
            trust_remote_code=True,
            torch_dtype="auto",
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(model_name, **load_options)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)

    def infer(self, rgb: np.ndarray, prompt: str) -> str:
        """Generate a reply for one image and prompt; return only the new text.

        Generation is capped at 200 new tokens and stops at ``<|endoftext|>``.
        """
        processed = self.processor.process(images=[Image.fromarray(rgb)], text=prompt)
        target_device = self.model.device
        batch = {
            name: tensor.to(target_device).unsqueeze(0)
            for name, tensor in processed.items()
        }

        generation_config = GenerationConfig(
            max_new_tokens=200, stop_strings="<|endoftext|>"
        )
        with torch.inference_mode():
            output = self.model.generate_from_batch(
                batch,
                generation_config,
                tokenizer=self.processor.tokenizer,
            )

        # Drop the prompt tokens so only the generated continuation is decoded.
        prompt_length = batch["input_ids"].size(1)
        new_tokens = output[0, prompt_length:]
        return self.processor.tokenizer.decode(new_tokens, skip_special_tokens=True)

    def infer_point(self, rgb: np.ndarray, prompt: str) -> np.ndarray:
        """Return the first point parsed from Molmo's answer.

        Falls back to the image centre ``(0.5, 0.5)`` when no point can be parsed.
        """
        text = self.infer(rgb, prompt)
        log.info(f"Molmo output: {text}")
        points = extract_molmo_points(text)
        if not points:
            return np.array([0.5, 0.5], dtype=np.float32)
        return points[0]
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class MolmoWebSocketServer:
    """WebSocket front-end answering ``infer_point`` requests with :class:`Molmo`.

    Each incoming message must be a JSON object with ``"action": "infer_point"``,
    an ``"rgb"`` nested list (converted to ``uint8``), and either ``"object_name"``
    or ``"prompt"``.  Every message gets exactly one JSON reply: either
    ``{"status": "ok", "action": "infer_point", "point": [x, y]}`` or
    ``{"status": "error", "message": ...}``.
    """

    def __init__(self, host="0.0.0.0", port=8765):
        from concurrent.futures import ThreadPoolExecutor

        self.host = host
        self.port = port
        self.molmo = Molmo()
        # Inference is blocking.  It runs on one worker thread so the event loop stays
        # responsive (other clients, keepalive pings) and model calls are serialized.
        self._executor = ThreadPoolExecutor(max_workers=1)

    async def handle_client(self, websocket):
        """Serve one connection until it closes, replying to every message."""
        client_id = id(websocket)
        log.info(f"Client connected: {client_id}")
        loop = asyncio.get_running_loop()

        try:
            async for msg in websocket:
                try:
                    req = json.loads(msg)
                    if not isinstance(req, dict):
                        raise ValueError("Request must be a JSON object")
                    action = req.get("action")

                    if action != "infer_point":
                        raise ValueError("Only 'infer_point' action is supported")

                    rgb = np.array(req["rgb"], dtype=np.uint8)

                    if "object_name" in req:
                        prompt = f"Point to the center of the {req['object_name']}."
                    elif "prompt" in req:
                        prompt = req["prompt"]
                    else:
                        raise ValueError("Provide either 'object_name' or 'prompt'")

                    point = await loop.run_in_executor(
                        self._executor, self.molmo.infer_point, rgb, prompt
                    )

                    resp = {
                        "status": "ok",
                        "action": "infer_point",
                        "point": point.tolist(),
                    }

                except Exception as e:
                    log.exception("Request error")
                    resp = {"status": "error", "message": str(e)}

                await websocket.send(json.dumps(resp))

        except websockets.exceptions.ConnectionClosed:
            # Abnormal closure (or a send after close); there is nobody left to reply to.
            pass
        finally:
            # Also reached on a clean close, where iteration ends without an exception.
            log.info(f"Client disconnected: {client_id}")

    async def start(self):
        """Listen on ``host:port`` until cancelled (messages up to 50 MiB)."""
        log.info(f"Starting Molmo server on ws://{self.host}:{self.port}")
        try:
            async with websockets.serve(
                self.handle_client,
                self.host,
                self.port,
                max_size=50 * 1024 * 1024,
            ):
                await asyncio.Future()
        finally:
            self._executor.shutdown(wait=False)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def main():
    """Parse ``--host``/``--port`` and serve Molmo pointing requests indefinitely."""
    cli = argparse.ArgumentParser("Molmo Pointing Server")
    cli.add_argument("--host", default="0.0.0.0")
    cli.add_argument("--port", type=int, default=8765)
    options = cli.parse_args()

    molmo_server = MolmoWebSocketServer(host=options.host, port=options.port)
    asyncio.run(molmo_server.start())


if __name__ == "__main__":
    main()
|
baselines/rum/policy.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
import cv2
|
|
4
|
+
from PIL import Image
|
|
5
|
+
from scipy.spatial.transform import Rotation as R
|
|
6
|
+
import torchvision.transforms as T
|
|
7
|
+
from hydra import initialize, compose
|
|
8
|
+
import hydra
|
|
9
|
+
|
|
10
|
+
# Signed axis permutation in homogeneous form (float32):
# new x = -x, new y = -z, new z = -y; the homogeneous coordinate is unchanged.
P = np.zeros((4, 4), dtype=np.float32)
P[0, 0] = -1
P[1, 2] = -1
P[2, 1] = -1
P[3, 3] = 1
|
|
13
|
+
|
|
14
|
+
def init_model_loss_fn(cfg):
    """Instantiate the model and loss function from ``cfg`` and load their weights.

    Args:
        cfg: Hydra config providing ``device``, ``model``, ``loss_fn`` and
            ``model_weight_pth``.

    Returns:
        ``(model, loss_fn)``, both placed on ``cfg.device``, or on ``"cpu"`` when
        CUDA is unavailable.

    Raises:
        ValueError: if ``model_weight_pth`` is not set in the config.
    """
    device = cfg.device if torch.cuda.is_available() else "cpu"
    model = hydra.utils.instantiate(cfg.model).to(device)

    weights_path = cfg.get("model_weight_pth")
    if weights_path is None:
        raise ValueError("Model weight path is not specified in the config.")

    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
    # only load trusted checkpoints.
    checkpoint = torch.load(weights_path, map_location=device, weights_only=False)

    model_state = checkpoint["model"]
    try:
        model.load_state_dict(model_state)
    except RuntimeError:
        # Retry without the "_orig_mod." key prefix (typically added by torch.compile).
        stripped_state = {
            key.replace("_orig_mod.", ""): value for key, value in model_state.items()
        }
        model.load_state_dict(stripped_state)

    loss_fn = hydra.utils.instantiate(cfg.loss_fn, model=model)
    loss_fn.load_state_dict(checkpoint["loss_fn"])

    return model, loss_fn.to(device)
|
|
38
|
+
|
|
39
|
+
class VectorizedBuffer:
    """Per-environment sliding windows of recent images, goals and actions.

    Each of the ``batch_size`` rows keeps the last ``buffer_size`` images and goals
    and the last ``buffer_size - 1`` actions.  A row whose counter is zero (a new
    buffer, or a row after :meth:`reset`) is filled by repeating the next image or
    goal it receives.  All rows are updated together with tensor operations rather
    than a Python loop over the batch.
    """

    def __init__(self, batch_size, buffer_size, act_dim, device):
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        self.act_dim = act_dim
        self.device = device
        self.image_buffer = None
        self.goal_buffer = None
        self.action_buffer = None
        # Entries received per row (clamped to buffer_size); 0 marks a fresh row.
        self.image_buffers_sizes = torch.zeros(batch_size, device=device)
        self.goal_buffer_size = torch.zeros(batch_size, device=device)
        self.action_buffers_sizes = torch.zeros(batch_size, device=device)

    @staticmethod
    def _push_window(buffer, sizes, new):
        """Return ``buffer`` shifted left by one step along dim 1, with ``new`` appended.

        Rows whose ``sizes`` entry is 0 are instead filled entirely with ``new``.
        """
        new = new.to(device=buffer.device, dtype=buffer.dtype)
        shifted = torch.roll(buffer, shifts=-1, dims=1)
        shifted[:, -1] = new
        repeated = new.unsqueeze(1).expand_as(buffer)
        fresh = (sizes == 0).to(buffer.device).view(-1, *([1] * (buffer.dim() - 1)))
        return torch.where(fresh, repeated, shifted)

    def _initial_window(self, new):
        """Create a window by repeating ``new`` ``buffer_size`` times along a new dim 1."""
        repeats = (1, self.buffer_size) + (1,) * (new.dim() - 1)
        return new.unsqueeze(1).repeat(*repeats).to(self.device)

    def add_image(self, new_images):
        """Append one image per row (``new_images`` is indexed by row on dim 0)."""
        if self.image_buffer is None:
            self.image_buffer = self._initial_window(new_images)
        else:
            self.image_buffer = self._push_window(
                self.image_buffer, self.image_buffers_sizes, new_images
            )
        self.image_buffers_sizes = torch.clamp(
            self.image_buffers_sizes + 1, max=self.buffer_size
        )

    def add_goal(self, new_goals):
        """Append one goal vector per row (``new_goals`` is indexed by row on dim 0)."""
        if self.goal_buffer is None:
            self.goal_buffer = self._initial_window(new_goals)
        else:
            self.goal_buffer = self._push_window(
                self.goal_buffer, self.goal_buffer_size, new_goals
            )
        self.goal_buffer_size = torch.clamp(
            self.goal_buffer_size + 1, max=self.buffer_size
        )

    def reset(self, batch_indices):
        """Zero the history and counters of the rows selected by ``batch_indices``."""
        self.image_buffers_sizes[batch_indices] = 0
        self.goal_buffer_size[batch_indices] = 0
        self.action_buffers_sizes[batch_indices] = 0
        if self.image_buffer is not None:
            self.image_buffer[batch_indices] = torch.zeros_like(
                self.image_buffer[batch_indices]
            )
        if self.goal_buffer is not None:
            self.goal_buffer[batch_indices] = torch.zeros_like(
                self.goal_buffer[batch_indices]
            )
        if self.action_buffer is not None:
            self.action_buffer[batch_indices] = torch.zeros_like(
                self.action_buffer[batch_indices]
            )

    def add_action(self, new_actions):
        """Append one action per row to the ``buffer_size - 1`` step action history."""
        B = new_actions.shape[0]
        if self.action_buffer is None:
            self.action_buffer = torch.zeros(
                B, self.buffer_size - 1, self.act_dim, device=self.device
            )
            self.action_buffers_sizes = torch.zeros(B, device=self.device)
        if self.buffer_size > 1:
            # Fresh rows are all zeros (new buffer or reset), so rolling them changes
            # nothing and every row can be shifted uniformly.
            shifted = torch.roll(self.action_buffer, shifts=-1, dims=1)
            shifted[:, -1] = new_actions.to(device=shifted.device, dtype=shifted.dtype)
            self.action_buffer = shifted
        # With buffer_size == 1 there is no action history to store.
        self.action_buffers_sizes = torch.clamp(
            self.action_buffers_sizes + 1, max=self.buffer_size
        )

    def get_input_sequence(self):
        """Return ``(image_window, goal_window_or_None, action_window)``.

        The action window is the stored history plus one trailing all-zero step, so
        it spans ``buffer_size`` steps.  The goal window is a float32 copy.

        Raises:
            RuntimeError: if no image has been added yet.
        """
        if self.image_buffer is None:
            raise RuntimeError("add_image() must be called before get_input_sequence()")
        B = self.image_buffer.shape[0]
        if self.action_buffer is None:
            past_actions = torch.zeros(
                B, self.buffer_size - 1, self.act_dim, device=self.device
            )
        else:
            past_actions = self.action_buffer
        base_act = torch.zeros(B, 1, self.act_dim, device=self.device)
        act_seq = torch.cat([past_actions, base_act], dim=1)
        if self.goal_buffer is None:
            goal_seq = None
        else:
            goal_seq = self.goal_buffer.to(dtype=torch.float32, copy=True)
        return self.image_buffer, goal_seq, act_seq
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def unwrap_model(model):
    """Strip a DataParallel / DistributedDataParallel wrapper, if present.

    Returns ``model.module`` for wrapped modules and ``model`` unchanged otherwise.
    """
    parallel_wrappers = (
        torch.nn.DataParallel,
        torch.nn.parallel.DistributedDataParallel,
    )
    return model.module if isinstance(model, parallel_wrappers) else model
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class Policy:
    """Batched VQ-BeT inference policy with a rolling observation/action history.

    Either load everything from ``model_path`` (a checkpoint that stores its
    training ``cfg``) or pass an already-built ``model`` and ``loss_fn``.  When the
    loss function's ``goal_dim`` is non-zero the policy is goal-conditioned.
    ``"2d"`` reads ``object_2d_position`` and ``"3d"`` reads
    ``object_3d_position`` from the observation dict.

    Args:
        model_path: checkpoint path, used when ``model`` or ``loss_fn`` is missing.
        device: requested device.  When loading from ``model_path`` it falls back
            to ``"cpu"`` if CUDA is unavailable, matching ``init_model_loss_fn``.
        model: pre-built model (may be wrapped in DataParallel/DDP).
        loss_fn: pre-built loss function exposing ``step``, ``goal_dim`` and
            ``_vqbet.obs_window_size``.

    Raises:
        ValueError: if neither a checkpoint nor both modules are given, or if the
            goal dimension is not 0, 2, 3 or 4.
    """

    def __init__(self, model_path=None, device="cpu", model=None, loss_fn=None):
        if model_path is None and (model is None or loss_fn is None):
            raise ValueError(
                "Either model_path or both model and loss_fn must be provided."
            )

        if model is None or loss_fn is None:
            # init_model_loss_fn() places the model on CPU when CUDA is unavailable.
            # Resolve the device identically so the buffers and tensors below match it.
            device = device if torch.cuda.is_available() else "cpu"
            model, loss_fn = self._load_from_checkpoint(model_path, device)

        self.to_tensor = T.ToTensor()
        self.model = unwrap_model(model)
        self.loss_fn = unwrap_model(loss_fn)
        self.buffer_size = self.loss_fn._vqbet.obs_window_size
        self.device = device

        goal_dim = self.loss_fn.goal_dim
        self.condition = None if goal_dim == 0 else f"{goal_dim}d"

        valid_conditions = ("4d", "3d", "2d")
        if self.condition is not None and self.condition not in valid_conditions:
            raise ValueError(
                f"'condition' must be one of {valid_conditions}, got '{self.condition}'"
            )

        self.model.eval()
        self.loss_fn.eval()

        self.vectorized_buffer = None
        self.act_dim = 7

        # Fixed frame change built from 90-degree rotations about y, x and z
        # combined with the axis permutation P.
        self.rot_yx_90 = (
            R.from_euler("y", 90, degrees=True).as_matrix()
            @ R.from_euler("x", 90, degrees=True).as_matrix()
        )
        self.Tyx = np.eye(4, dtype=np.float32)
        self.Tyx[:3, :3] = self.rot_yx_90

        Tz = np.eye(4, dtype=np.float32)
        Tz[:3, :3] = R.from_euler("z", 90, degrees=True).as_matrix()

        M = self.Tyx @ Tz @ P.T
        # M is a product of orthogonal matrices, so its inverse is its transpose.
        M_inv = M.T

        self.M_t = torch.from_numpy(M).to(device)
        self.M_inv_t = torch.from_numpy(M_inv).to(device)

    @staticmethod
    def _load_from_checkpoint(model_path, device):
        """Compose ``run_vqbet``, patch it with the checkpoint's architecture, and load."""
        with initialize(config_path="configs", version_base=None):
            cfg = compose(config_name="run_vqbet")
        # Only the stored cfg is needed here; the full checkpoint is loaded again later.
        ckpt_cfg = torch.load(model_path, map_location="cpu", weights_only=False)["cfg"]

        gpt_config = ckpt_cfg["loss_fn"]["gpt_model"]["config"]
        cfg.goal_dim = ckpt_cfg["loss_fn"]["goal_dim"]

        cfg.gpt_input_dim = gpt_config["input_dim"]
        cfg.loss_fn.gpt_model.config.n_layer = gpt_config["n_layer"]
        cfg.loss_fn.gpt_model.config.n_head = gpt_config["n_head"]
        cfg.loss_fn.gpt_model.config.n_embd = gpt_config["n_embd"]

        cfg.vqvae_n_embed = ckpt_cfg["vqvae_n_embed"]
        cfg.model_weight_pth = model_path

        cfg.device = device
        return init_model_loss_fn(cfg)

    def reset(self, indicies=None):
        """Forget history for every environment (``None``) or only the given rows.

        Resetting specific rows before the first :meth:`infer` call is a no-op.
        """
        if indicies is None:
            self.vectorized_buffer = None
        elif self.vectorized_buffer is not None:
            self.vectorized_buffer.reset(indicies)

    def process_image(self, img):
        """Convert one image to a CHW float tensor with ``ToTensor``.

        NumPy inputs are treated as BGR and converted to RGB first.
        NOTE(review): process_image_batch() does not apply this conversion to NumPy
        batches, so the two paths disagree on channel order — confirm which is intended.
        """
        if isinstance(img, np.ndarray):
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(img)
        return self.to_tensor(img)

    def process_image_batch(self, imgs):
        """Convert a batch of HWC images to a ``(B, C, H, W)`` float tensor on ``self.device``.

        ``uint8`` arrays are scaled to [0, 1]; other NumPy dtypes are only cast to
        float.  Non-array inputs go through :meth:`process_image` one image at a time.
        """
        if isinstance(imgs, np.ndarray):
            imgs_tensor = torch.from_numpy(imgs).permute(0, 3, 1, 2).float()
            if imgs.dtype == np.uint8:
                imgs_tensor = imgs_tensor / 255.0
            return imgs_tensor.to(self.device)
        return torch.stack([self.process_image(img) for img in imgs]).to(self.device)

    def _select_goal(self, observations):
        """Return the goal observation for ``self.condition``, or ``None`` if unconditioned.

        Raises:
            NotImplementedError: for ``"4d"``, which has no observation mapping here.
        """
        if self.condition is None:
            return None
        if self.condition == "2d":
            return observations["object_2d_position"]
        if self.condition == "3d":
            return observations["object_3d_position"]
        raise NotImplementedError(
            f"Goal conditioning '{self.condition}' is not supported by infer()"
        )

    def infer(self, observations):
        """Predict one action per environment from the current observations.

        Args:
            observations: dict with ``"rgb_ego"`` (one HWC image or a batch of them)
                and, when goal-conditioned, the matching goal entry.

        Returns:
            NumPy array of the loss function's actions after ``squeeze(1)``,
            presumably ``(B, act_dim)`` — confirm against ``loss_fn.step``.
        """
        obs = observations["rgb_ego"]
        goal = self._select_goal(observations)

        if len(obs.shape) == 3:
            # Single un-batched observation: add a batch dimension.
            obs = np.expand_dims(obs, axis=0)
            if goal is not None:
                goal = np.expand_dims(goal, axis=0)

        B = obs.shape[0]
        processed_images = self.process_image_batch(obs)

        if self.vectorized_buffer is None:
            self.vectorized_buffer = VectorizedBuffer(
                batch_size=B,
                buffer_size=self.buffer_size,
                act_dim=self.act_dim,
                device=self.device,
            )

        self.vectorized_buffer.add_image(processed_images)

        if self.condition is not None:
            processed_goals = torch.as_tensor(
                np.asarray(goal), dtype=torch.float32, device=self.device
            ).reshape(B, -1)
            self.vectorized_buffer.add_goal(processed_goals)

        img_seq, goal_seq, act_seq = self.vectorized_buffer.get_input_sequence()

        with torch.no_grad():
            model_input = (img_seq, goal_seq, act_seq)
            model_output = self.model(model_input)
            action_tensors, _ = self.loss_fn.step(
                model_input, model_output, return_all=True
            )

        action_tensors = action_tensors.squeeze(1).to(self.device)
        # Feed the predicted actions back in as history for the next step.
        self.vectorized_buffer.add_action(action_tensors)
        return action_tensors.cpu().numpy()
|