plancraft 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

plancraft/train/dataset.py ADDED
@@ -0,0 +1,187 @@
+ import glob
+ import json
+ import random
+
+ import torch
+ from torch.utils.data import Dataset
+
+ TEMPLATES = {
+     "idefics2": {
+         "assistant": "\nAssistant:",
+         "user": "\nUser:",
+     },
+     "llama3": {
+         "assistant": "<|start_header_id|>assistant<|end_header_id|>\n\n",
+         "user": "<|start_header_id|>user<|end_header_id|>\n\n",
+     },
+ }
+
+
+ class PlancraftDialogueDataset(Dataset):
+     def __init__(
+         self,
+         dataset_dir: str = "data/oracle",
+         use_images=False,
+         trace_mode="oa",
+         split="train",
+         max_message_window=30,
+     ):
+         super().__init__()
+         self.split = split
+         self.use_images = use_images
+         self.trace_mode = trace_mode
+         self.add_system_message = True
+
+         assert trace_mode in ["oa", "ota"], f"Invalid trace mode {trace_mode}"
+
+         print("Loading dialogue dataset")
+         data = []
+         for example_path in sorted(
+             glob.glob(f"{dataset_dir}/{split}/{trace_mode}/*.json")
+         ):
+             with open(example_path) as f:
+                 example = json.load(f)
+                 data.append(example)
+         self.dataset = data
+         self.max_message_window = max_message_window
+
+     def __len__(self) -> int:
+         return len(self.dataset)
+
+     def __getitem__(self, idx: int) -> tuple[dict, list]:
+         example = self.dataset[idx]
+         if len(example["messages"]) > self.max_message_window:
+             # add system message
+             if self.add_system_message:
+                 messages = [example["messages"][0]]
+             else:
+                 messages = []
+
+             # sample window
+             user_messages_idxs = list(
+                 range(self.max_message_window, len(example["messages"]), 2)
+             )
+             end = random.choice(user_messages_idxs)
+             start = end - self.max_message_window + 1
+             assert start != 0
+             # add window
+             messages = messages + example["messages"][start : end + 1]
+
+         else:
+             messages = example["messages"]
+         return messages
+
+
+ def track_assistant_response(
+     batch,
+     tokenizer,
+     template_name: str = "llama3",
+ ):
+     """
+     Mask that returns 1 for tokens in the assistant response and 0 otherwise.
+     """
+     assistant_template = TEMPLATES[template_name]["assistant"]
+     user_template = TEMPLATES[template_name]["user"]
+     start_seq = tokenizer.encode(
+         assistant_template,
+         add_special_tokens=False,
+         return_tensors="pt",
+     )[0]
+     end_seq = tokenizer.encode(
+         user_template,
+         add_special_tokens=False,
+         return_tensors="pt",
+     )[0]
+     encoded_label_ids = batch["labels"]
+     mask = torch.zeros_like(encoded_label_ids)
+     for seq_idx, seq in enumerate(encoded_label_ids):
+         in_masked_response = False
+         i = 0
+         while i < len(seq):
+             if i + len(start_seq) < len(seq) and torch.all(
+                 seq[i : i + len(start_seq)].eq(start_seq)
+             ):
+                 in_masked_response = True
+                 i += len(start_seq)
+                 continue
+             if i + len(end_seq) < len(seq) and torch.all(
+                 seq[i : i + len(end_seq)].eq(end_seq)
+             ):
+                 in_masked_response = False
+                 i += len(end_seq)
+                 continue
+             if in_masked_response:
+                 mask[seq_idx, i] = 1
+             else:
+                 mask[seq_idx, i] = 0
+             i += 1
+     return mask
+
+
+ def get_collate_fn(
+     tokenizer,
+     max_length=8142,
+     only_assistant=False,
+     template_name: str = "llama3",
+ ):
+     assert template_name in TEMPLATES
+
+     def collate_fn(batch):
+         messages_batch = []
+         for messages in batch:
+             text = tokenizer.apply_chat_template(
+                 messages, add_generation_prompt=False, tokenize=False
+             )
+             # remove BOS token since it will later be added again by the tokenizer
+             text = text.replace("<|begin_of_text|>", "")
+             messages_batch.append(text)
+
+         batch = tokenizer(
+             messages_batch,
+             truncation=True,
+             max_length=max_length,
+             return_tensors="pt",
+         )
+         labels = batch["input_ids"].clone()
+         batch["labels"] = labels
+
+         # add mask for assistant response
+         if only_assistant:
+             mask = track_assistant_response(
+                 batch, tokenizer, template_name=template_name
+             )
+             labels[mask == 0] = -100
+
+         return batch
+
+     return collate_fn
+
+
+ def get_dataset_and_collate(
+     tokenizer,
+     template_name: str = "llama3",
+     max_length: int = 8142,
+     max_message_window: int = 30,
+     trace_mode="oa",
+     only_assistant=False,
+ ):
+     if template_name == "llama3":
+         train_dataset = PlancraftDialogueDataset(
+             use_images=False,
+             max_message_window=max_message_window,
+             split="train",
+             trace_mode=trace_mode,
+         )
+         val_dataset = PlancraftDialogueDataset(
+             use_images=False,
+             max_message_window=max_message_window,
+             split="val",
+             trace_mode=trace_mode,
+         )
+     collate_fn = get_collate_fn(
+         tokenizer=tokenizer,
+         only_assistant=only_assistant,
+         template_name=template_name,
+         max_length=max_length,
+     )
+     return train_dataset, val_dataset, collate_fn
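
Not part of the wheel, but a minimal usage sketch of the module above: get_dataset_and_collate returns the train/val PlancraftDialogueDataset splits together with a collate_fn that applies the tokenizer's chat template and, when only_assistant=True, sets every non-assistant token in labels to -100 so the loss covers assistant turns only. The model name below is only an example of a llama3-style tokenizer, and because the collate function tokenizes without padding, batch_size=1 is the safe choice.

    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    # assumes the train module is importable from the installed package
    from plancraft.train.dataset import get_dataset_and_collate

    # example llama3-style tokenizer; any tokenizer with a llama3 chat template works
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

    train_dataset, val_dataset, collate_fn = get_dataset_and_collate(
        tokenizer,
        template_name="llama3",
        max_message_window=30,
        trace_mode="oa",
        only_assistant=True,  # mask system/user tokens out of the loss
    )
    # collate_fn does not pad, so keep batch_size=1 (or add padding to the tokenizer call)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
    batch = next(iter(train_loader))
    # batch["labels"] matches batch["input_ids"] except non-assistant positions, which are -100
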
plancraft/utils.py ADDED
@@ -0,0 +1,84 @@
+ import glob
+ import pathlib
+
+ import cv2
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+
+
+ def get_downloaded_models() -> dict:
+     """
+     Get the list of downloaded models on the NFS partition (EIDF).
+     """
+     downloaded_models = {}
+     # known models on NFS partition
+     if pathlib.Path("/nfs").exists():
+         local_models = glob.glob("/nfs/public/hf/models/*/*")
+         downloaded_models = {
+             model.replace("/nfs/public/hf/models/", ""): model for model in local_models
+         }
+     return downloaded_models
+
+
+ def get_torch_device() -> torch.device:
+     device = torch.device("cpu")
+     if torch.cuda.is_available():
+         device = torch.device("cuda", 0)
+     elif torch.backends.mps.is_available():
+         if not torch.backends.mps.is_built():
+             print(
+                 "MPS not available because the current PyTorch install was not built with MPS enabled."
+             )
+         else:
+             device = torch.device("mps")
+     return device
+
+
+ def resize_image(img, target_resolution=(128, 128)):
+     if type(img) == np.ndarray:
+         img = cv2.resize(img, target_resolution, interpolation=cv2.INTER_LINEAR)
+     elif type(img) == torch.Tensor:
+         img = F.interpolate(img, size=target_resolution, mode="bilinear")
+     else:
+         raise ValueError("Unsupported type for img")
+     return img
+
+
+ def save_frames_to_video(frames: list, out_path: str):
+     imgs = []
+     for id, (frame, goal) in enumerate(frames):
+         # if torch.is_tensor(frame):
+         #     frame = frame.permute(0, 2, 3, 1).cpu().numpy()
+
+         frame = resize_image(frame, (320, 240)).astype("uint8")
+         cv2.putText(
+             frame,
+             f"FID: {id}",
+             (10, 25),
+             cv2.FONT_HERSHEY_SIMPLEX,
+             0.8,
+             (255, 255, 255),
+             2,
+         )
+         cv2.putText(
+             frame,
+             f"Goal: {goal}",
+             (10, 55),
+             cv2.FONT_HERSHEY_SIMPLEX,
+             0.8,
+             (255, 0, 0),
+             2,
+         )
+         imgs.append(Image.fromarray(frame))
+     imgs = imgs[::3]
+     imgs[0].save(
+         out_path,
+         save_all=True,
+         append_images=imgs[1:],
+         optimize=False,
+         quality=0,
+         duration=150,
+         loop=0,
+     )
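
A short illustrative sketch of the helpers above (not part of the wheel): get_torch_device prefers CUDA, then MPS, then CPU; resize_image accepts either a NumPy array (resized with OpenCV) or a batched NCHW tensor (resized with F.interpolate); and save_frames_to_video annotates (frame, goal) pairs and saves them through PIL as an animated image such as a GIF rather than a true video file. The goal string and output path below are made up for illustration.

    import numpy as np
    import torch

    from plancraft.utils import get_torch_device, resize_image, save_frames_to_video

    device = get_torch_device()  # cuda:0 if available, else mps, else cpu

    # NumPy path: cv2.resize takes (width, height)
    frame = np.zeros((240, 320, 3), dtype=np.uint8)
    small = resize_image(frame, target_resolution=(128, 128))

    # Tensor path: F.interpolate expects a batched NCHW tensor
    batch = torch.zeros((1, 3, 240, 320))
    small_t = resize_image(batch, target_resolution=(128, 128))

    # every 3rd annotated frame is kept and written as an animated GIF (150 ms per frame)
    frames = [(np.zeros((240, 320, 3), dtype=np.uint8), "craft stick") for _ in range(6)]
    save_frames_to_video(frames, "episode.gif")
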
{plancraft-0.1.2.dist-info → plancraft-0.1.3.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: plancraft
- Version: 0.1.2
+ Version: 0.1.3
  Summary: Plancraft: an evaluation dataset for planning with LLM agents
  Requires-Python: >=3.9
  Description-Content-Type: text/markdown
plancraft-0.1.3.dist-info/RECORD ADDED
@@ -0,0 +1,30 @@
+ plancraft/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ plancraft/config.py,sha256=k7Ac6Zh6g-Q-fuazPxVdXnSL_PMd3jsIqyiImNaJW5A,4472
+ plancraft/evaluator.py,sha256=zomShgwyUOHNVZbTERpimgEYiyD4JXzHzPi4tg45lsM,9926
+ plancraft/utils.py,sha256=J9K81zBXWCcyIBjwu_eHznnzWamXU-pQrz7BdNTVbLU,2289
+ plancraft/environments/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ plancraft/environments/actions.py,sha256=SeeC9l1cJBs9pdba6BefQ_iQNfFf6FVTWm7HWkacbsY,6262
+ plancraft/environments/env_real.py,sha256=oETMvdq8-TspPoIRxioTDj9EEmuZRuuMxd9c3momvPk,10877
+ plancraft/environments/env_symbolic.py,sha256=ot4IStZ3oT7CYPqdqHsGl2BopLyZUwSF671SrHIiMLk,7777
+ plancraft/environments/items.py,sha256=1R56LyK6tqIssQJMqHst6A9DeEfOX5DN-OBAkumGncw,217
+ plancraft/environments/planner.py,sha256=2B-0aunllmTuiHE7Jn5jHzCg6mMgxZjisiTDdpSKupk,3954
+ plancraft/environments/recipes.py,sha256=nXvOLCRljiZ5IgeevXuosU9IgSs7oQWQJFiuyRNSVFs,19571
+ plancraft/environments/sampler.py,sha256=ZBYoENKdQQ7wbAyVk-c9UNRWKPE0omv9he8c8QZ6wXg,7625
+ plancraft/models/__init__.py,sha256=PasK3jpbhpD0kxF4iHukcccZqvZg6lL240zie3DfLDY,622
+ plancraft/models/act.py,sha256=jdZunT7FcbHvcaJZ_wUDLSuObHjU6JZybghR_B0QJ8Q,6548
+ plancraft/models/base.py,sha256=fFM2BV9PqvIFtUlTz8iz5HPemYRy3S0EituM1XdJJSQ,4927
+ plancraft/models/bbox_model.py,sha256=CoX-odH59S-djkPOH2ViEmbYWo1sefmHiOcBlFWiAkg,16814
+ plancraft/models/dummy.py,sha256=QjxTIiKsWSmhUMAuw7Yy-OKKCLi_x3rwll4hH7ZNXso,1732
+ plancraft/models/generators.py,sha256=c2avYUoXAAWOHRFq24ZYqHqNgoPs_L_XWJrR_TKFg9E,17358
+ plancraft/models/oam.py,sha256=_TpxsaCnGvQ7YJC6KWcpl2HzgBiQaHn7U1zRAXfdjOo,9943
+ plancraft/models/oracle.py,sha256=xRplR2cCW_39UGYtSjDLRDmp0UILhtIXYjuRJ-jokJ8,9598
+ plancraft/models/prompts.py,sha256=XwoRqd_5_VfCUXb10dCRFYXgw70mO2VoQocn3Z2zgs0,6165
+ plancraft/models/react.py,sha256=5yM4tv7tDfm_-5yUSAcw5C4wioWijF9fLEODbFqFDvg,3346
+ plancraft/models/utils.py,sha256=osKX0_uux9wzqYzq1ST0Cu5idrAnyfNvXrj0uO1eKo0,9424
+ plancraft/models/few_shot_images/__init__.py,sha256=nIkyB6w3ok-h4lfJKsrcMQzQF624Y9uxYV1FqCu3Lx0,351
+ plancraft/train/dataset.py,sha256=NrZjbIkosui1kaq7AIWSrYvvzrDxu_njH7FmGKY3xnI,5434
+ plancraft-0.1.3.dist-info/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
+ plancraft-0.1.3.dist-info/METADATA,sha256=XLudjSvyfZ-uYJZ1K1RAJUVOQhozFD3FLzwUQeikhK4,2631
+ plancraft-0.1.3.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+ plancraft-0.1.3.dist-info/top_level.txt,sha256=yGGA9HtPKH2-pIJFrbyYhj2JF9Xj-4m0fxMPKT9FNzg,10
+ plancraft-0.1.3.dist-info/RECORD,,
plancraft-0.1.3.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ plancraft
plancraft-0.1.2.dist-info/RECORD DELETED
@@ -1,5 +0,0 @@
- plancraft-0.1.2.dist-info/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
- plancraft-0.1.2.dist-info/METADATA,sha256=IDsYHFCvjx4Nhe4Jgyj0x7YsIVAR95ZyrSA-LIQBgeI,2631
- plancraft-0.1.2.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
- plancraft-0.1.2.dist-info/top_level.txt,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
- plancraft-0.1.2.dist-info/RECORD,,
plancraft-0.1.2.dist-info/top_level.txt DELETED
@@ -1 +0,0 @@
-