egogym 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. baselines/pi_policy.py +110 -0
  2. baselines/rum/__init__.py +1 -0
  3. baselines/rum/loss_fns/__init__.py +37 -0
  4. baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
  5. baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
  6. baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
  7. baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
  8. baselines/rum/models/__init__.py +1 -0
  9. baselines/rum/models/bet/__init__.py +3 -0
  10. baselines/rum/models/bet/bet.py +347 -0
  11. baselines/rum/models/bet/gpt.py +277 -0
  12. baselines/rum/models/bet/tokenized_bet.py +454 -0
  13. baselines/rum/models/bet/utils.py +124 -0
  14. baselines/rum/models/bet/vqbet.py +410 -0
  15. baselines/rum/models/bet/vqvae/__init__.py +3 -0
  16. baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
  17. baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
  18. baselines/rum/models/bet/vqvae/vqvae.py +313 -0
  19. baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
  20. baselines/rum/models/custom.py +33 -0
  21. baselines/rum/models/encoders/__init__.py +0 -0
  22. baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
  23. baselines/rum/models/encoders/identity.py +45 -0
  24. baselines/rum/models/encoders/timm_encoders.py +82 -0
  25. baselines/rum/models/policies/diffusion_policy.py +881 -0
  26. baselines/rum/models/policies/open_loop.py +122 -0
  27. baselines/rum/models/policies/simple_open_loop.py +108 -0
  28. baselines/rum/molmo/server.py +144 -0
  29. baselines/rum/policy.py +293 -0
  30. baselines/rum/utils/__init__.py +212 -0
  31. baselines/rum/utils/action_transforms.py +22 -0
  32. baselines/rum/utils/decord_transforms.py +135 -0
  33. baselines/rum/utils/rpc.py +249 -0
  34. baselines/rum/utils/schedulers.py +71 -0
  35. baselines/rum/utils/trajectory_vis.py +128 -0
  36. baselines/rum/utils/zmq_utils.py +281 -0
  37. baselines/rum_policy.py +108 -0
  38. egogym/__init__.py +8 -0
  39. egogym/assets/constants.py +1804 -0
  40. egogym/components/__init__.py +1 -0
  41. egogym/components/object.py +94 -0
  42. egogym/egogym.py +106 -0
  43. egogym/embodiments/__init__.py +10 -0
  44. egogym/embodiments/arms/__init__.py +4 -0
  45. egogym/embodiments/arms/arm.py +65 -0
  46. egogym/embodiments/arms/droid.py +49 -0
  47. egogym/embodiments/grippers/__init__.py +4 -0
  48. egogym/embodiments/grippers/floating_gripper.py +58 -0
  49. egogym/embodiments/grippers/rum.py +6 -0
  50. egogym/embodiments/robot.py +95 -0
  51. egogym/evaluate.py +216 -0
  52. egogym/managers/__init__.py +2 -0
  53. egogym/managers/objects_managers.py +30 -0
  54. egogym/managers/textures_manager.py +21 -0
  55. egogym/misc/molmo_client.py +49 -0
  56. egogym/misc/molmo_server.py +197 -0
  57. egogym/policies/__init__.py +1 -0
  58. egogym/policies/base_policy.py +13 -0
  59. egogym/scripts/analayze.py +834 -0
  60. egogym/scripts/plot.py +87 -0
  61. egogym/scripts/plot_correlation.py +392 -0
  62. egogym/scripts/plot_correlation_hardcoded.py +338 -0
  63. egogym/scripts/plot_failure.py +248 -0
  64. egogym/scripts/plot_failure_hardcoded.py +195 -0
  65. egogym/scripts/plot_failure_vlm.py +257 -0
  66. egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
  67. egogym/scripts/plot_line.py +303 -0
  68. egogym/scripts/plot_line_hardcoded.py +285 -0
  69. egogym/scripts/plot_pi0_bars.py +169 -0
  70. egogym/tasks/close.py +84 -0
  71. egogym/tasks/open.py +85 -0
  72. egogym/tasks/pick.py +121 -0
  73. egogym/utils.py +969 -0
  74. egogym/wrappers/__init__.py +20 -0
  75. egogym/wrappers/episode_monitor.py +282 -0
  76. egogym/wrappers/unprivileged_chatgpt.py +163 -0
  77. egogym/wrappers/unprivileged_gemini.py +157 -0
  78. egogym/wrappers/unprivileged_molmo.py +88 -0
  79. egogym/wrappers/unprivileged_moondream.py +121 -0
  80. egogym-0.1.0.dist-info/METADATA +52 -0
  81. egogym-0.1.0.dist-info/RECORD +83 -0
  82. egogym-0.1.0.dist-info/WHEEL +5 -0
  83. egogym-0.1.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,122 @@
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms as T
+ from tqdm import tqdm
+
+
+ class OpenLoopReplay(nn.Module):
+     def __init__(self, encoder, k=5, enc_weight_pth=None, use_vinn=False, cfg=None):
+         super().__init__()
+         self.encoder = encoder
+         self.cfg = cfg
+         self.use_vinn = use_vinn
+
+         self.k = k
+
+         if enc_weight_pth is not None:
+             self.encoder.load_state_dict(
+                 torch.load(enc_weight_pth, map_location="cpu")["model"]
+             )
+
+         self.representations = None
+         self.actions = None
+         self.imgs = None
+         softmax = nn.Softmax(dim=1)
+         self.dist_scale_func = lambda x: (softmax(-x))
+         self.encoder.eval()
+         self.device = "cpu"
+         self.encoder.to(self.device)
+         self.img_transform = T.Resize((256, 256), antialias=True)
+
+         self.open_loop = False
+         self.idx = 0
+
+     def to(self, device):
+         self.device = device
+         self.encoder.to(device)
+
+         return super().to(device)
+
+     def set_dataset(self, dataloader):
+         self.train_dataset = dataloader.dataset
+         if self.use_vinn:
+             for i, (image, label) in tqdm(enumerate(dataloader)):
+                 image = image.float() / 255.0
+                 image = image.to(self.device)
+                 label = torch.Tensor(label).to("cpu").detach().squeeze()
+
+                 x = (image, label)
+                 representation = self.encoder(x).to("cpu").detach().squeeze(dim=1)
+                 if self.representations is None:
+                     self.representations = representation
+                     self.actions = label
+                     image = image.to("cpu").detach().numpy()
+                     self.imgs = list(image)
+                 else:
+                     self.representations = torch.cat(
+                         (self.representations, representation), 0
+                     )
+                     self.actions = torch.cat((self.actions, label), 0)
+                     image = image.to("cpu").detach().numpy()
+                     self.imgs.extend(list(image))
+
+     def step(self, img, **kwargs):
+         logs = {}
+         print(self.idx)
+         print(len(self.train_dataset))
+         if self.use_vinn:
+             normalized_image = self.img_transform(img[0].squeeze(0))
+             if not self.open_loop:
+                 self.encoder.eval()
+                 with torch.no_grad():
+                     act, indices = self(img, return_indices=True)
+                     act = act.squeeze().detach()
+                     act[:-1] = 0
+                     act[-1] = 1
+
+                 self.neighbor_1_idx = indices[0][0]
+
+                 action_tensor = torch.zeros(7)
+                 action_tensor[-1] = 1
+                 self.open_loop = True
+                 return action_tensor, logs
+             else:
+                 _, action = self.train_dataset[self.neighbor_1_idx + self.idx]
+                 action_tensor = torch.tensor(action).squeeze()
+                 self.idx += 1
+                 return action_tensor, logs
+         else:
+             _, action = self.train_dataset[self.idx]
+             action_tensor = torch.tensor(action).squeeze()
+             self.idx += 1
+             return action_tensor, logs
+
+     def __call__(self, batch_images, k=None, return_indices=False):
+         if k is None:
+             k = self.k
+
+         all_distances = torch.zeros(
+             (batch_images[0].shape[0], self.representations.shape[0])
+         )
+
+         batch_rep = self.encoder(batch_images).squeeze(dim=1).detach().to(self.device)
+         dat_rep = self.representations.to(self.device)
+         all_distances = torch.cdist(batch_rep, dat_rep).to("cpu")
+
+         top_k_distances, indices = torch.topk(all_distances, k, dim=1, largest=False)
+         top_k_actions = self.actions[indices].to(self.device)
+
+         weights = self.dist_scale_func(top_k_distances).to(self.device)
+
+         pred = torch.sum(
+             top_k_actions * weights.unsqueeze(-1), dim=1
+         )  # weighted average
+
+         if return_indices:
+             return pred, indices
+
+     def reset(self):
+         self.open_loop = False
+         self.idx = 0
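
For orientation, a minimal usage sketch of the OpenLoopReplay policy above. This is illustrative and not part of the wheel: the import path is inferred from the file list (the +122 line count matches baselines/rum/models/policies/open_loop.py), and the identity encoder, random tensors, and dataset shapes are placeholder assumptions. With use_vinn=False the policy simply replays the dataset's actions in order.

    # Hypothetical usage sketch -- module path inferred from the file list above.
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    from baselines.rum.models.policies.open_loop import OpenLoopReplay

    images = torch.randint(0, 256, (16, 3, 256, 256), dtype=torch.uint8)  # placeholder frames
    actions = torch.randn(16, 7)                                          # placeholder 7-DoF actions
    loader = DataLoader(TensorDataset(images, actions), batch_size=4)

    policy = OpenLoopReplay(encoder=nn.Identity(), k=5, use_vinn=False)
    policy.set_dataset(loader)   # with use_vinn=False this only stores loader.dataset

    policy.reset()
    for t in range(3):
        # In pure replay mode the image argument is ignored and the t-th dataset action is returned.
        action, logs = policy.step(images[t:t + 1])
        print(action.shape)      # torch.Size([7])

With use_vinn=True, set_dataset would instead embed every frame with the encoder, and the first step would pick the nearest-neighbour frame in that embedding space and then replay the dataset trajectory from that index.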
@@ -0,0 +1,108 @@
+ import os
+ import numpy as np
+ from scipy.spatial.transform import Rotation as R
+ from quaternion import (
+     as_rotation_matrix,
+     quaternion,
+ )
+ import torch
+ P = np.array([[-1, 0, 0, 0], [0, 0, -1, 0], [0, -1, 0, 0], [0, 0, 0, 1]])
+
+ def apply_permutation_transform(matrix):
+     return P @ matrix @ P.T
+
+ class SimpleReplay:
+     def __init__(self, poses_file_path, timeskip=8):
+         self.poses_file_path = poses_file_path
+         self.timeskip = timeskip
+         self.idx = 0
+         self.transforms = None
+
+         self.process_poses()
+
+     def to(self, device):
+         pass
+
+     def eval(self):
+         pass
+
+     def get_poses(self):
+         with open(self.poses_file_path, "r") as f:
+             lines = f.readlines()
+         poses, timestamps = [], []
+         for line in lines:
+             line_list = eval(line)
+             ts = int(line_list[0].split("<")[1].split(">")[0])
+             pose = np.array(line_list[1:])
+             timestamps.append(ts)
+             poses.append(pose)
+         timestamps = np.array(timestamps)
+         poses = np.array(poses)
+
+         return poses, timestamps
+
+     def process_poses(self):
+         quaternions = []
+         translations = []
+         init_pose = None
+
+         poses, timestamps = self.get_poses()
+         for pose in poses:
+             qx, qy, qz, qw, tx, ty, tz = pose
+             ext_matrix = np.eye(4)
+             ext_matrix[:3, :3] = as_rotation_matrix(quaternion(qw, qx, qy, qz))
+             ext_matrix[:3, 3] = tx, ty, tz
+
+             if init_pose is None:
+                 init_pose = np.copy(ext_matrix)
+             relative_pose = np.linalg.inv(init_pose) @ ext_matrix
+
+             relative_pose = apply_permutation_transform(relative_pose)
+             translations.append(relative_pose[:3, -1])
+             quaternions.append(
+                 R.from_matrix(relative_pose[:3, :3]).as_quat()
+             )
+         quats = np.array(quaternions)
+         translations = np.array(translations)
+         transforms = np.concatenate([translations, quats], axis=1)
+
+         self.transforms = transforms
+
+     def get_action(self, idx):
+         prior_translations, prior_rotations = self.transforms[idx, :3], self.transforms[idx, 3:]
+         next_translations, next_rotations = self.transforms[idx + self.timeskip, :3], self.transforms[idx + self.timeskip, 3:]
+         # Now, create the matrices.
+         prior_rot_matrices, next_rot_matrices = (
+             R.from_quat(prior_rotations).as_matrix(),
+             R.from_quat(next_rotations).as_matrix(),
+         )
+         # Now, compute the relative matrices.
+         prior_matrices = np.eye(4)
+         prior_matrices[:3, :3] = prior_rot_matrices
+         prior_matrices[:3, 3] = prior_translations
+
+         next_matrices = np.eye(4)
+         next_matrices[:3, :3] = next_rot_matrices
+         next_matrices[:3, 3] = next_translations
+
+         relative_transforms = np.matmul(np.linalg.inv(prior_matrices), next_matrices)
+         relative_translations = relative_transforms[:3, 3]
+         relative_rotations = R.from_matrix(relative_transforms[:3, :3]).as_rotvec()
+
+         gripper = 1.0
+
+         return np.concatenate([relative_translations, relative_rotations, [gripper]], dtype=np.float32)
+
+     def step(self, img, step_no):
+         start_idx = self.timeskip * step_no
+         if start_idx + self.timeskip >= len(self.transforms):
+             print("INDEX OUT OF BOUNDS")
+             action_tensor = torch.zeros(7)
+             action_tensor[-1] = 1
+             return action_tensor, {}
+
+         action_tensor = self.get_action(start_idx)
+         return torch.tensor(action_tensor), {}
+
+     def reset(self):
+         pass
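
A minimal sketch of how SimpleReplay could be exercised, assuming the class above is in scope. This is illustrative only; the poses-file format is inferred from get_poses (each line a Python literal list whose first element embeds the timestamp in angle brackets, followed by qx, qy, qz, qw, tx, ty, tz), and the values written here are placeholders.

    # Hypothetical usage sketch -- assumes SimpleReplay from the hunk above is importable/in scope.
    import numpy as np

    with open("/tmp/poses.txt", "w") as f:
        for i in range(40):
            # identity rotation, slowly translating along x; timestamp wrapped in <...> as parsed above
            f.write(str([f"<{i * 100}>", 0.0, 0.0, 0.0, 1.0, 0.01 * i, 0.0, 0.0]) + "\n")

    replay = SimpleReplay("/tmp/poses.txt", timeskip=8)
    for step_no in range(3):
        # img is unused; the returned action is a relative [translation, rotation-vector, gripper] chunk
        action, logs = replay.step(img=None, step_no=step_no)
        print(action)  # 7-dim torch tensor

Each step skips `timeskip` recorded poses, so the replayed actions are the relative transforms between pose idx and idx + timeskip, with the gripper held open (1.0).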
@@ -0,0 +1,144 @@
+ import argparse
+ import asyncio
+ import json
+ import logging
+ import re
+ import numpy as np
+ import torch
+ import websockets
+ from PIL import Image
+ from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
+
+ logging.basicConfig(level=logging.INFO)
+ log = logging.getLogger(__name__)
+
+
+ def extract_molmo_points(molmo_output: str):
+     points = []
+     for match in re.finditer(
+         r'x\d*="\s*([0-9]+(?:\.[0-9]+)?)"\s+y\d*="\s*([0-9]+(?:\.[0-9]+)?)"',
+         molmo_output,
+     ):
+         try:
+             p = np.array([float(match.group(1)), float(match.group(2))], dtype=np.float32)
+         except ValueError:
+             continue
+
+         if np.max(p) > 100:
+             continue
+
+         points.append(p / 100.0)
+
+     return points
+
+
+ class Molmo:
+     def __init__(self, model_name="allenai/Molmo-7B-D-0924"):
+         self.processor = AutoProcessor.from_pretrained(
+             model_name,
+             trust_remote_code=True,
+             torch_dtype="auto",
+             device_map="auto",
+         )
+         self.model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             trust_remote_code=True,
+             torch_dtype="auto",
+             device_map="auto",
+         )
+
+     def infer(self, rgb: np.ndarray, prompt: str) -> str:
+         image = Image.fromarray(rgb)
+         inputs = self.processor.process(images=[image], text=prompt)
+         inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}
+
+         with torch.inference_mode():
+             output = self.model.generate_from_batch(
+                 inputs,
+                 GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
+                 tokenizer=self.processor.tokenizer,
+             )
+
+         gen_tokens = output[0, inputs["input_ids"].size(1):]
+         return self.processor.tokenizer.decode(gen_tokens, skip_special_tokens=True)
+
+     def infer_point(self, rgb: np.ndarray, prompt: str) -> np.ndarray:
+         text = self.infer(rgb, prompt)
+         log.info(f"Molmo output: {text}")
+         points = extract_molmo_points(text)
+         return points[0] if points else np.array([0.5, 0.5], dtype=np.float32)
+
+
+ class MolmoWebSocketServer:
+     def __init__(self, host="0.0.0.0", port=8765):
+         self.host = host
+         self.port = port
+         self.molmo = Molmo()
+
+     async def handle_client(self, websocket):
+         client_id = id(websocket)
+         log.info(f"Client connected: {client_id}")
+
+         try:
+             async for msg in websocket:
+                 try:
+                     req = json.loads(msg)
+                     action = req.get("action")
+
+                     if action != "infer_point":
+                         raise ValueError("Only 'infer_point' action is supported")
+
+                     rgb = np.array(req["rgb"], dtype=np.uint8)
+
+                     if "object_name" in req:
+                         prompt = f"Point to the center of the {req['object_name']}."
+                         label = req["object_name"].replace(" ", "_")
+                     elif "prompt" in req:
+                         prompt = req["prompt"]
+                         label = "custom_prompt"
+                     else:
+                         raise ValueError("Provide either 'object_name' or 'prompt'")
+
+                     point = self.molmo.infer_point(rgb, prompt)
+
+                     resp = {
+                         "status": "ok",
+                         "action": "infer_point",
+                         "point": point.tolist(),
+                     }
+
+                 except Exception as e:
+                     log.exception("Request error")
+                     resp = {"status": "error", "message": str(e)}
+
+                 await websocket.send(json.dumps(resp))
+
+         except websockets.exceptions.ConnectionClosed:
+             log.info(f"Client disconnected: {client_id}")
+
+     async def start(self):
+         log.info(f"Starting Molmo server on ws://{self.host}:{self.port}")
+         async with websockets.serve(
+             self.handle_client,
+             self.host,
+             self.port,
+             max_size=50 * 1024 * 1024,
+         ):
+             await asyncio.Future()
+
+
+ def main():
+     parser = argparse.ArgumentParser("Molmo Pointing Server")
+     parser.add_argument("--host", default="0.0.0.0")
+     parser.add_argument("--port", type=int, default=8765)
+     args = parser.parse_args()
+
+     server = MolmoWebSocketServer(
+         host=args.host,
+         port=args.port,
+     )
+     asyncio.run(server.start())
+
+
+ if __name__ == "__main__":
+     main()
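
The request/response protocol implied by handle_client above can be exercised with a small WebSocket client. The package ships its own client (egogym/misc/molmo_client.py in the file list); the sketch below is illustrative only, and the URI and placeholder image are assumptions.

    # Hypothetical client sketch for the pointing server above.
    # Wire format: send {"action": "infer_point", "rgb": <uint8 HxWx3 nested list>, "object_name": ...}
    # and read back {"status": "ok", "point": [x, y]} with x, y normalized to [0, 1].
    import asyncio
    import json
    import numpy as np
    import websockets

    async def request_point(uri="ws://localhost:8765"):
        rgb = np.zeros((224, 224, 3), dtype=np.uint8)  # placeholder image
        async with websockets.connect(uri, max_size=50 * 1024 * 1024) as ws:
            await ws.send(json.dumps({
                "action": "infer_point",
                "rgb": rgb.tolist(),
                "object_name": "red mug",
            }))
            resp = json.loads(await ws.recv())
        if resp["status"] != "ok":
            raise RuntimeError(resp.get("message", "inference failed"))
        return np.array(resp["point"], dtype=np.float32)

    if __name__ == "__main__":
        print(asyncio.run(request_point()))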
@@ -0,0 +1,293 @@
+ import torch
+ import numpy as np
+ import cv2
+ from PIL import Image
+ from scipy.spatial.transform import Rotation as R
+ import torchvision.transforms as T
+ from hydra import initialize, compose
+ import hydra
+
+ P = np.array(
+     [[-1, 0, 0, 0], [0, 0, -1, 0], [0, -1, 0, 0], [0, 0, 0, 1]], dtype=np.float32
+ )
+
+ def init_model_loss_fn(cfg):
+     device = cfg.device if torch.cuda.is_available() else "cpu"
+     model = hydra.utils.instantiate(cfg.model).to(device)
+     model_weight_pth = cfg.get("model_weight_pth")
+
+     if model_weight_pth is None:
+         raise ValueError("Model weight path is not specified in the config.")
+
+     checkpoint = torch.load(
+         model_weight_pth, map_location=device, weights_only=False
+     )
+
+     try:
+         model.load_state_dict(checkpoint["model"])
+     except RuntimeError:
+         checkpoint["model"] = {
+             k.replace("_orig_mod.", ""): v for k, v in checkpoint["model"].items()
+         }
+         model.load_state_dict(checkpoint["model"])
+     loss_fn = hydra.utils.instantiate(cfg.loss_fn, model=model)
+     loss_fn.load_state_dict(checkpoint["loss_fn"])
+     loss_fn = loss_fn.to(device)
+
+     return model, loss_fn
+
+ class VectorizedBuffer:
+     def __init__(self, batch_size, buffer_size, act_dim, device):
+         self.batch_size = batch_size
+         self.buffer_size = buffer_size
+         self.act_dim = act_dim
+         self.device = device
+         self.image_buffer = None
+         self.goal_buffer = None
+         self.action_buffer = None
+         self.image_buffers_sizes = torch.zeros(batch_size, device=device)
+         self.goal_buffer_size = torch.zeros(batch_size, device=device)
+         self.action_buffers_sizes = torch.zeros(batch_size, device=device)
+
+     def add_image(self, new_images):
+         if self.image_buffer is None:
+             self.image_buffer = (
+                 new_images.unsqueeze(1)
+                 .repeat(1, self.buffer_size, 1, 1, 1)
+                 .to(self.device)
+             )
+         else:
+             for b in range(self.batch_size):
+                 if self.image_buffers_sizes[b] == 0:
+                     self.image_buffer[b] = (
+                         new_images[b].unsqueeze(0).repeat(self.buffer_size, 1, 1, 1)
+                     )
+                 else:
+                     self.image_buffer[b] = torch.roll(
+                         self.image_buffer[b], shifts=-1, dims=0
+                     )
+                     self.image_buffer[b, -1] = new_images[b]
+         self.image_buffers_sizes += 1
+         self.image_buffers_sizes = torch.clamp(
+             self.image_buffers_sizes, max=self.buffer_size
+         )
+
+     def add_goal(self, new_goals):
+         if self.goal_buffer is None:
+             self.goal_buffer = (
+                 new_goals.unsqueeze(1).repeat(1, self.buffer_size, 1).to(self.device)
+             )
+         else:
+             for b in range(self.batch_size):
+                 if self.goal_buffer_size[b] == 0:
+                     self.goal_buffer[b] = (
+                         new_goals[b].unsqueeze(0).repeat(self.buffer_size, 1)
+                     )
+                 else:
+                     self.goal_buffer[b] = torch.roll(
+                         self.goal_buffer[b], shifts=-1, dims=0
+                     )
+                     self.goal_buffer[b, -1] = new_goals[b]
+         self.goal_buffer_size += 1
+         self.goal_buffer_size = torch.clamp(self.goal_buffer_size, max=self.buffer_size)
+
+     def reset(self, batch_indices):
+         self.image_buffers_sizes[batch_indices] = 0
+         self.goal_buffer_size[batch_indices] = 0
+         self.action_buffers_sizes[batch_indices] = 0
+         if self.image_buffer is not None:
+             self.image_buffer[batch_indices] = torch.zeros_like(
+                 self.image_buffer[batch_indices]
+             )
+         if self.goal_buffer is not None:
+             self.goal_buffer[batch_indices] = torch.zeros_like(
+                 self.goal_buffer[batch_indices]
+             )
+         if self.action_buffer is not None:
+             self.action_buffer[batch_indices] = torch.zeros_like(
+                 self.action_buffer[batch_indices]
+             )
+
+     def add_action(self, new_actions):
+         B = new_actions.shape[0]
+         if self.action_buffer is None:
+             self.action_buffer = torch.zeros(
+                 B, self.buffer_size - 1, self.act_dim, device=self.device
+             )
+             self.action_buffers_sizes = torch.zeros(B, device=self.device)
+         for b in range(B):
+             if self.action_buffers_sizes[b] != 0:
+                 self.action_buffer[b] = torch.roll(
+                     self.action_buffer[b], shifts=-1, dims=0
+                 )
+             self.action_buffer[b, -1] = new_actions[b]
+         self.action_buffers_sizes += 1
+
+     def get_input_sequence(self):
+         B = self.image_buffer.shape[0]
+         if self.action_buffer is None:
+             action_buffer = torch.zeros(
+                 B, self.buffer_size - 1, self.act_dim, device=self.device
+             )
+         else:
+             action_buffer = self.action_buffer
+         base_act = torch.zeros(B, 1, self.act_dim, device=self.device)
+         act_seq = torch.cat([action_buffer, base_act], dim=1)
+         if self.goal_buffer is None:
+             goal_seq = None
+         else:
+             goal_seq = torch.stack([goal for goal in self.goal_buffer]).to(
+                 dtype=torch.float32
+             )
+         return self.image_buffer, goal_seq, act_seq
+
+
+ def unwrap_model(model):
+     if isinstance(
+         model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
+     ):
+         return model.module
+     return model
+
+
+ class Policy:
+     def __init__(self, model_path=None, device="cpu", model=None, loss_fn=None):
+         if model_path is None and (model is None or loss_fn is None):
+             raise ValueError(
+                 "Either model_path or both model and loss_fn must be provided."
+             )
+
+         if model is None or loss_fn is None:
+             with initialize(config_path="configs", version_base=None):
+                 cfg = compose(config_name="run_vqbet")
+             checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
+             ckpt_cfg = checkpoint["cfg"]
+             del checkpoint
+
+             cfg.goal_dim = ckpt_cfg["loss_fn"]["goal_dim"]
+
+             cfg.gpt_input_dim = ckpt_cfg["loss_fn"]["gpt_model"]["config"]["input_dim"]
+             cfg.loss_fn.gpt_model.config.n_layer = ckpt_cfg["loss_fn"]["gpt_model"][
+                 "config"
+             ]["n_layer"]
+             cfg.loss_fn.gpt_model.config.n_head = ckpt_cfg["loss_fn"]["gpt_model"][
+                 "config"
+             ]["n_head"]
+             cfg.loss_fn.gpt_model.config.n_embd = ckpt_cfg["loss_fn"]["gpt_model"][
+                 "config"
+             ]["n_embd"]
+
+             cfg.vqvae_n_embed = ckpt_cfg["vqvae_n_embed"]
+             cfg.model_weight_pth = model_path
+
+             cfg.device = device
+             model, loss_fn = init_model_loss_fn(cfg)
+
+         self.to_tensor = T.ToTensor()
+         self.model = unwrap_model(model)
+         self.loss_fn = unwrap_model(loss_fn)
+         self.buffer_size = self.loss_fn._vqbet.obs_window_size
+         self.device = device
+
+         goal_dim = self.loss_fn.goal_dim
+         self.condition = f"{goal_dim}d"
+         if goal_dim == 0:
+             self.condition = None
+
+         valid_conditions = ("4d", "3d", "2d")
+         if self.condition is not None and self.condition not in valid_conditions:
+             raise ValueError(
+                 f"'condition' must be one of {valid_conditions}, got '{self.condition}'"
+             )
+
+         self.model.eval()
+         self.loss_fn.eval()
+
+         self.vectorized_buffer = None
+         self.act_dim = 7
+
+         self.rot_yx_90 = (
+             R.from_euler("y", 90, degrees=True).as_matrix()
+             @ R.from_euler("x", 90, degrees=True).as_matrix()
+         )
+         self.Tyx = np.eye(4, dtype=np.float32)
+         self.Tyx[:3, :3] = self.rot_yx_90
+
+         rot_z_90 = R.from_euler("z", 90, degrees=True).as_matrix()
+         Tz = np.eye(4, dtype=np.float32)
+         Tz[:3, :3] = rot_z_90
+
+         M = self.Tyx @ Tz @ P.T
+         M_inv = M.T
+
+         self.M_t = torch.from_numpy(M).to(device)
+         self.M_inv_t = torch.from_numpy(M_inv).to(device)
+
+     def reset(self, indicies=None):
+         if indicies is None:
+             self.vectorized_buffer = None
+         else:
+             self.vectorized_buffer.reset(indicies)
+
+     def process_image(self, img):
+         if isinstance(img, np.ndarray):
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+             img = Image.fromarray(img)
+         return self.to_tensor(img)
+
+     def process_image_batch(self, imgs):
+         if isinstance(imgs, np.ndarray):
+             if imgs.dtype == np.uint8:
+                 imgs_tensor = torch.from_numpy(imgs).permute(0, 3, 1, 2).float() / 255.0
+             else:
+                 imgs_tensor = torch.from_numpy(imgs).permute(0, 3, 1, 2).float()
+             return imgs_tensor.to(self.device)
+         else:
+             return torch.stack([self.process_image(imgs[i]) for i in range(len(imgs))]).to(self.device)
+
+     def infer(self, observations):
+         obs = observations["rgb_ego"]
+         if self.condition == "2d":
+             goal = observations["object_2d_position"]
+         elif self.condition == "3d":
+             goal = observations["object_3d_position"]
+         else:
+             goal = None
+
+         if len(obs.shape) == 3:
+             obs = np.expand_dims(obs, axis=0)
+             if goal is not None:
+                 goal = np.expand_dims(goal, axis=0)
+
+         B = obs.shape[0]
+         processed_images = self.process_image_batch(obs)
+
+         if self.vectorized_buffer is None:
+             image_shape = processed_images.shape[1:]
+             self.vectorized_buffer = VectorizedBuffer(
+                 batch_size=B,
+                 buffer_size=self.buffer_size,
+                 act_dim=self.act_dim,
+                 device=self.device,
+             )
+
+         self.vectorized_buffer.add_image(processed_images)
+
+         if self.condition is not None:
+             processed_goals = torch.from_numpy(goal).to(self.device, dtype=torch.float32).view(B, -1)
+             self.vectorized_buffer.add_goal(processed_goals)
+
+         img_seq, goal_seq, act_seq = self.vectorized_buffer.get_input_sequence()
+
+         with torch.no_grad():
+             model_input = (img_seq, goal_seq, act_seq)
+             model_output = self.model(model_input)
+             action_tensors, logs = self.loss_fn.step(
+                 model_input, model_output, return_all=True
+             )
+
+         action_tensors = action_tensors.squeeze(1).to(self.device)
+         self.vectorized_buffer.add_action(action_tensors)
+         action_tensors = action_tensors.cpu().numpy()
+
+         return action_tensors
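
A sketch of how the Policy class above is intended to be called. This is illustrative only: the import path is inferred from the file list (the +293 line count matches baselines/rum/policy.py), the checkpoint path is hypothetical, and a CUDA device is assumed. Observation keys follow infer(): "rgb_ego" always, plus "object_2d_position" or "object_3d_position" when the checkpoint was trained with a 2-D or 3-D goal condition.

    # Hypothetical usage sketch -- module path inferred from the file list above.
    import numpy as np
    from baselines.rum.policy import Policy

    policy = Policy(model_path="/path/to/vqbet_checkpoint.pt", device="cuda")
    policy.reset()

    obs = {
        "rgb_ego": np.zeros((224, 224, 3), dtype=np.uint8),              # single un-batched frame
        "object_2d_position": np.array([0.5, 0.5], dtype=np.float32),    # goal, if the checkpoint is 2d-conditioned
    }
    actions = policy.infer(obs)   # numpy array of per-env 7-dim actions (act_dim = 7 above)
    print(actions.shape)

Each call rolls the new frame (and goal) into the VectorizedBuffer, feeds the buffered image/goal/action sequence through the VQ-BeT model, and appends the predicted action back into the buffer for the next step.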