plancraft 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,187 @@
1
+ import glob
2
+ import json
3
+ import random
4
+
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+
8
# Per-model-family role headers. Each entry maps a chat role to the literal
# text that opens that role's turn in the rendered conversation;
# track_assistant_response tokenizes these headers to find assistant turns.
TEMPLATES = dict(
    idefics2=dict(
        assistant="\nAssistant:",
        user="\nUser:",
    ),
    llama3=dict(
        assistant="<|start_header_id|>assistant<|end_header_id|>\n\n",
        user="<|start_header_id|>user<|end_header_id|>\n\n",
    ),
)
18
+
19
+
20
class PlancraftDialogueDataset(Dataset):
    """Dialogue traces loaded from ``{dataset_dir}/{split}/{trace_mode}/*.json``.

    Each JSON file must contain a ``"messages"`` list. Each item of the dataset
    is that message list. Dialogues longer than ``max_message_window`` are cut
    to a randomly placed window, optionally keeping ``messages[0]`` in front.
    Presumably ``messages[0]`` is the system prompt (TODO confirm schema).
    """

    def __init__(
        self,
        dataset_dir: str = "data/oracle",
        use_images=False,
        trace_mode="oa",
        split="train",
        max_message_window=30,
        add_system_message: bool = True,
    ):
        """
        Args:
            dataset_dir: root directory holding ``{split}/{trace_mode}`` folders.
            use_images: stored on the instance; not used by the loading logic here.
            trace_mode: ``"oa"`` or ``"ota"``; selects the sub-folder to load.
            split: split sub-folder name, e.g. ``"train"`` or ``"val"``.
            max_message_window: dialogues longer than this are windowed in
                ``__getitem__``.
            add_system_message: when windowing, prepend ``messages[0]`` to the
                sampled window (defaults to the previous hard-coded ``True``).
        """
        super().__init__()
        self.split = split
        self.use_images = use_images
        self.trace_mode = trace_mode
        self.add_system_message = add_system_message

        assert trace_mode in ["oa", "ota"], f"Invalid trace mode {trace_mode}"

        print("Loading dialogue dataset")
        data = []
        for example_path in sorted(
            glob.glob(f"{dataset_dir}/{split}/{trace_mode}/*.json")
        ):
            # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
            with open(example_path, encoding="utf-8") as f:
                data.append(json.load(f))
        self.dataset = data
        self.max_message_window = max_message_window

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> list:
        """Return the message list for example ``idx``, windowed if too long.

        Short dialogues (length <= ``max_message_window``) are returned as-is.
        Longer ones yield ``max_message_window`` consecutive messages whose last
        index is drawn at random from ``max_message_window, +2, +4, ...``. When
        ``add_system_message`` is set, ``messages[0]`` is prepended.
        """
        example = self.dataset[idx]
        messages = example["messages"]
        window = self.max_message_window
        if len(messages) <= window:
            return messages

        prefix = [messages[0]] if self.add_system_message else []

        # NOTE(review): assuming index 0 is the system prompt followed by strictly
        # alternating user/assistant turns, an even window makes every sample start
        # on a user turn and end on an assistant turn; an odd window flips this.
        end = random.choice(list(range(window, len(messages), 2)))
        start = end - window + 1
        assert start != 0
        return prefix + messages[start : end + 1]
73
+
74
+
75
def track_assistant_response(
    batch,
    tokenizer,
    template_name: str = "llama3",
):
    """
    Mask that returns 1 for tokens in the assistant response and 0 otherwise.

    Each row of ``batch["labels"]`` is scanned for the token ids of the
    template's assistant header, which opens a response, and its user header,
    which closes it. The header tokens themselves are always 0. Tokens after an
    assistant header, up to the next user header, are 1.

    Args:
        batch: mapping with a ``"labels"`` tensor indexed as ``[row, position]``.
        tokenizer: object whose ``encode(text, add_special_tokens=False,
            return_tensors="pt")`` returns a tensor whose first row is used.
        template_name: key into ``TEMPLATES``.

    Returns:
        A tensor created with ``torch.zeros_like(batch["labels"])``, filled with
        0/1 values.

    NOTE(review): the headers are tokenized in isolation. This assumes they
    tokenize identically inside the full conversation, which holds for
    special-token headers such as llama3's; confirm for other templates.
    """
    assistant_template = TEMPLATES[template_name]["assistant"]
    user_template = TEMPLATES[template_name]["user"]
    start_seq = tokenizer.encode(
        assistant_template,
        add_special_tokens=False,
        return_tensors="pt",
    )[0]
    end_seq = tokenizer.encode(
        user_template,
        add_special_tokens=False,
        return_tensors="pt",
    )[0]

    def matches_at(seq, i, pattern) -> bool:
        # `<=` so a header ending exactly on the last token still matches;
        # an empty pattern never matches (it would otherwise loop forever).
        n = len(pattern)
        return n > 0 and i + n <= len(seq) and bool(torch.all(seq[i : i + n].eq(pattern)))

    encoded_label_ids = batch["labels"]
    mask = torch.zeros_like(encoded_label_ids)
    for seq_idx, seq in enumerate(encoded_label_ids):
        in_masked_response = False
        i = 0
        while i < len(seq):
            if matches_at(seq, i, start_seq):
                in_masked_response = True
                i += len(start_seq)
            elif matches_at(seq, i, end_seq):
                in_masked_response = False
                i += len(end_seq)
            else:
                # mask starts as zeros, so only response tokens need writing
                if in_masked_response:
                    mask[seq_idx, i] = 1
                i += 1
    return mask
119
+
120
+
121
def get_collate_fn(
    tokenizer,
    max_length=8142,
    only_assistant=False,
    template_name: str = "llama3",
):
    """Build a collate function that turns a batch of message lists into model inputs.

    The returned ``collate_fn(batch)`` renders each message list with
    ``tokenizer.apply_chat_template`` and strips the ``<|begin_of_text|>`` text.
    It then tokenizes the batch, truncating to ``max_length``, and adds
    ``labels`` as a copy of ``input_ids``. Label positions are set to -100 for
    non-assistant tokens when ``only_assistant`` is set, and for padding
    positions.

    Args:
        tokenizer: tokenizer used for both chat rendering and encoding.
        max_length: truncation length. NOTE(review): 8142 may be a typo for
            8192; confirm.
        only_assistant: restrict the loss to assistant responses.
        template_name: key into ``TEMPLATES`` used to locate assistant turns.
    """
    assert template_name in TEMPLATES

    def collate_fn(batch):
        messages_batch = []
        for messages in batch:
            text = tokenizer.apply_chat_template(
                messages, add_generation_prompt=False, tokenize=False
            )
            # remove BOS token since it will later be added again by the tokenizer
            messages_batch.append(text.replace("<|begin_of_text|>", ""))

        # Rows of different lengths cannot be stacked into one tensor without
        # padding. Pad only when it is needed (more than one row) and possible
        # (the tokenizer defines a pad token); otherwise keep the unpadded call.
        padding = len(messages_batch) > 1 and getattr(tokenizer, "pad_token", None) is not None
        encoded = tokenizer(
            messages_batch,
            truncation=True,
            max_length=max_length,
            padding=padding,
            return_tensors="pt",
        )
        labels = encoded["input_ids"].clone()
        encoded["labels"] = labels

        # add mask for assistant response
        if only_assistant:
            mask = track_assistant_response(
                encoded, tokenizer, template_name=template_name
            )
            labels[mask == 0] = -100

        # Never compute the loss on padding positions.
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            labels[attention_mask == 0] = -100

        return encoded

    return collate_fn
158
+
159
+
160
def get_dataset_and_collate(
    tokenizer,
    template_name: str = "llama3",
    max_length: int = 8142,
    max_message_window: int = 30,
    trace_mode="oa",
    only_assistant=False,
    dataset_dir: str = "data/oracle",
):
    """Build the train/val dialogue datasets and a matching collate function.

    Args:
        tokenizer: tokenizer passed to ``get_collate_fn``.
        template_name: chat template. Only ``"llama3"`` (text-only) is supported.
        max_length: truncation length for the collate function.
        max_message_window: dialogue window size for both datasets.
        trace_mode: ``"oa"`` or ``"ota"`` trace folder.
        only_assistant: restrict the loss to assistant tokens.
        dataset_dir: root directory of the dialogue data.

    Returns:
        ``(train_dataset, val_dataset, collate_fn)``.

    Raises:
        NotImplementedError: for any template other than ``"llama3"``. Without
            this guard the datasets were never created and the return statement
            failed with an UnboundLocalError.
    """
    if template_name != "llama3":
        raise NotImplementedError(
            f"Dialogue datasets are only implemented for the 'llama3' template, got {template_name!r}"
        )

    train_dataset = PlancraftDialogueDataset(
        dataset_dir=dataset_dir,
        use_images=False,
        max_message_window=max_message_window,
        split="train",
        trace_mode=trace_mode,
    )
    val_dataset = PlancraftDialogueDataset(
        dataset_dir=dataset_dir,
        use_images=False,
        max_message_window=max_message_window,
        split="val",
        trace_mode=trace_mode,
    )
    collate_fn = get_collate_fn(
        tokenizer=tokenizer,
        only_assistant=only_assistant,
        template_name=template_name,
        max_length=max_length,
    )
    return train_dataset, val_dataset, collate_fn
plancraft/utils.py ADDED
@@ -0,0 +1,84 @@
1
+ import glob
2
+ import pathlib
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from PIL import Image
9
+
10
+
11
def get_downloaded_models(models_dir: str = "/nfs/public/hf/models") -> dict:
    """
    Map downloaded model ids to their paths on the NFS partition (EIDF).

    Entries two levels below ``models_dir`` are treated as models. The result
    maps ``"<dir>/<subdir>"`` (the path relative to ``models_dir``) to the
    full path. It is empty when ``models_dir`` does not exist or has no such
    entries.

    Args:
        models_dir: root directory to scan; defaults to the shared NFS model store.
    """
    prefix = models_dir.rstrip("/") + "/"
    # glob on a missing directory simply returns [], so no existence check is needed
    local_models = glob.glob(f"{prefix}*/*")
    # strip only the leading root (str.replace would also hit later occurrences)
    return {model[len(prefix):]: model for model in local_models}
23
+
24
+
25
def get_torch_device() -> torch.device:
    """Pick the compute device: first CUDA GPU, else Apple MPS, else CPU.

    Prints a notice instead of selecting MPS when MPS reports as available but
    this PyTorch build lacks MPS support.
    """
    if torch.cuda.is_available():
        return torch.device("cuda", 0)
    if torch.backends.mps.is_available():
        if torch.backends.mps.is_built():
            return torch.device("mps")
        print(
            "MPS not available because the current PyTorch install was not built with MPS enabled."
        )
    return torch.device("cpu")
37
+
38
+
39
def resize_image(img, target_resolution=(128, 128)):
    """Resize a NumPy array or torch tensor to ``target_resolution`` (bilinear).

    Args:
        img: ``np.ndarray`` (resized with ``cv2.resize``) or ``torch.Tensor``
            (resized with ``F.interpolate``). For tensors, bilinear mode needs a
            4-D ``(N, C, H, W)`` input. Subclasses such as
            ``torch.nn.Parameter`` are accepted.
        target_resolution: output size, passed straight to the backend.

    Returns:
        The resized image, of the same library type as ``img``.

    Raises:
        ValueError: if ``img`` is neither an ``np.ndarray`` nor a ``torch.Tensor``.

    NOTE(review): the backends read ``target_resolution`` differently.
    ``cv2.resize`` takes ``(width, height)``, while ``F.interpolate`` takes
    ``(height, width)``. Non-square targets therefore give transposed shapes
    between the two paths. Left unchanged to preserve existing callers.
    """
    if isinstance(img, np.ndarray):
        return cv2.resize(img, target_resolution, interpolation=cv2.INTER_LINEAR)
    if isinstance(img, torch.Tensor):
        return F.interpolate(img, size=target_resolution, mode="bilinear")
    raise ValueError("Unsupported type for img")
47
+
48
+
49
def save_frames_to_video(
    frames: list,
    out_path: str,
    frame_stride: int = 3,
    duration: int = 150,
):
    """Write annotated frames to a looping multi-frame image at ``out_path``.

    Every ``frame_stride``-th frame is used, starting with the first. Each one
    is resized via ``resize_image`` with target ``(320, 240)`` and cast to
    ``uint8``. It is then stamped with its original index ("FID") and goal
    text, and saved with Pillow's ``save(save_all=True, loop=0)``.

    Args:
        frames: sequence of ``(frame, goal)`` pairs. ``frame`` is presumably an
            ``(H, W, 3)`` ``np.ndarray`` (TODO confirm). The ``.astype`` call
            means torch tensors are not supported.
        out_path: destination path; Pillow picks the format from its extension.
        frame_stride: keep one frame out of every ``frame_stride`` (must be >= 1).
        duration: display time per written frame, passed to Pillow.

    Raises:
        ValueError: if ``frames`` is empty or ``frame_stride`` < 1.
    """
    if not frames:
        raise ValueError("save_frames_to_video requires at least one frame")
    if frame_stride < 1:
        raise ValueError(f"frame_stride must be >= 1, got {frame_stride}")

    imgs = []
    for frame_idx, (frame, goal) in enumerate(frames):
        # Skip frames that would be dropped anyway (same selection as imgs[::stride]).
        if frame_idx % frame_stride:
            continue
        frame = resize_image(frame, (320, 240)).astype("uint8")
        cv2.putText(
            frame,
            f"FID: {frame_idx}",
            (10, 25),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.8,
            (255, 255, 255),
            2,
        )
        cv2.putText(
            frame,
            f"Goal: {goal}",
            (10, 55),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.8,
            (255, 0, 0),
            2,
        )
        imgs.append(Image.fromarray(frame))

    imgs[0].save(
        out_path,
        save_all=True,
        append_images=imgs[1:],
        optimize=False,
        quality=0,
        duration=duration,
        loop=0,
    )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: plancraft
3
- Version: 0.1.2
3
+ Version: 0.1.3
4
4
  Summary: Plancraft: an evaluation dataset for planning with LLM agents
5
5
  Requires-Python: >=3.9
6
6
  Description-Content-Type: text/markdown
@@ -0,0 +1,30 @@
1
+ plancraft/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ plancraft/config.py,sha256=k7Ac6Zh6g-Q-fuazPxVdXnSL_PMd3jsIqyiImNaJW5A,4472
3
+ plancraft/evaluator.py,sha256=zomShgwyUOHNVZbTERpimgEYiyD4JXzHzPi4tg45lsM,9926
4
+ plancraft/utils.py,sha256=J9K81zBXWCcyIBjwu_eHznnzWamXU-pQrz7BdNTVbLU,2289
5
+ plancraft/environments/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ plancraft/environments/actions.py,sha256=SeeC9l1cJBs9pdba6BefQ_iQNfFf6FVTWm7HWkacbsY,6262
7
+ plancraft/environments/env_real.py,sha256=oETMvdq8-TspPoIRxioTDj9EEmuZRuuMxd9c3momvPk,10877
8
+ plancraft/environments/env_symbolic.py,sha256=ot4IStZ3oT7CYPqdqHsGl2BopLyZUwSF671SrHIiMLk,7777
9
+ plancraft/environments/items.py,sha256=1R56LyK6tqIssQJMqHst6A9DeEfOX5DN-OBAkumGncw,217
10
+ plancraft/environments/planner.py,sha256=2B-0aunllmTuiHE7Jn5jHzCg6mMgxZjisiTDdpSKupk,3954
11
+ plancraft/environments/recipes.py,sha256=nXvOLCRljiZ5IgeevXuosU9IgSs7oQWQJFiuyRNSVFs,19571
12
+ plancraft/environments/sampler.py,sha256=ZBYoENKdQQ7wbAyVk-c9UNRWKPE0omv9he8c8QZ6wXg,7625
13
+ plancraft/models/__init__.py,sha256=PasK3jpbhpD0kxF4iHukcccZqvZg6lL240zie3DfLDY,622
14
+ plancraft/models/act.py,sha256=jdZunT7FcbHvcaJZ_wUDLSuObHjU6JZybghR_B0QJ8Q,6548
15
+ plancraft/models/base.py,sha256=fFM2BV9PqvIFtUlTz8iz5HPemYRy3S0EituM1XdJJSQ,4927
16
+ plancraft/models/bbox_model.py,sha256=CoX-odH59S-djkPOH2ViEmbYWo1sefmHiOcBlFWiAkg,16814
17
+ plancraft/models/dummy.py,sha256=QjxTIiKsWSmhUMAuw7Yy-OKKCLi_x3rwll4hH7ZNXso,1732
18
+ plancraft/models/generators.py,sha256=c2avYUoXAAWOHRFq24ZYqHqNgoPs_L_XWJrR_TKFg9E,17358
19
+ plancraft/models/oam.py,sha256=_TpxsaCnGvQ7YJC6KWcpl2HzgBiQaHn7U1zRAXfdjOo,9943
20
+ plancraft/models/oracle.py,sha256=xRplR2cCW_39UGYtSjDLRDmp0UILhtIXYjuRJ-jokJ8,9598
21
+ plancraft/models/prompts.py,sha256=XwoRqd_5_VfCUXb10dCRFYXgw70mO2VoQocn3Z2zgs0,6165
22
+ plancraft/models/react.py,sha256=5yM4tv7tDfm_-5yUSAcw5C4wioWijF9fLEODbFqFDvg,3346
23
+ plancraft/models/utils.py,sha256=osKX0_uux9wzqYzq1ST0Cu5idrAnyfNvXrj0uO1eKo0,9424
24
+ plancraft/models/few_shot_images/__init__.py,sha256=nIkyB6w3ok-h4lfJKsrcMQzQF624Y9uxYV1FqCu3Lx0,351
25
+ plancraft/train/dataset.py,sha256=NrZjbIkosui1kaq7AIWSrYvvzrDxu_njH7FmGKY3xnI,5434
26
+ plancraft-0.1.3.dist-info/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
27
+ plancraft-0.1.3.dist-info/METADATA,sha256=XLudjSvyfZ-uYJZ1K1RAJUVOQhozFD3FLzwUQeikhK4,2631
28
+ plancraft-0.1.3.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
29
+ plancraft-0.1.3.dist-info/top_level.txt,sha256=yGGA9HtPKH2-pIJFrbyYhj2JF9Xj-4m0fxMPKT9FNzg,10
30
+ plancraft-0.1.3.dist-info/RECORD,,
@@ -0,0 +1 @@
1
+ plancraft
@@ -1,5 +0,0 @@
1
- plancraft-0.1.2.dist-info/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
2
- plancraft-0.1.2.dist-info/METADATA,sha256=IDsYHFCvjx4Nhe4Jgyj0x7YsIVAR95ZyrSA-LIQBgeI,2631
3
- plancraft-0.1.2.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
4
- plancraft-0.1.2.dist-info/top_level.txt,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
5
- plancraft-0.1.2.dist-info/RECORD,,
@@ -1 +0,0 @@
1
-