plancraft 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- plancraft/__init__.py +0 -0
- plancraft/config.py +155 -0
- plancraft/environments/__init__.py +0 -0
- plancraft/environments/actions.py +218 -0
- plancraft/environments/env_real.py +316 -0
- plancraft/environments/env_symbolic.py +212 -0
- plancraft/environments/items.py +10 -0
- plancraft/environments/planner.py +109 -0
- plancraft/environments/recipes.py +542 -0
- plancraft/environments/sampler.py +224 -0
- plancraft/evaluator.py +273 -0
- plancraft/models/__init__.py +21 -0
- plancraft/models/act.py +184 -0
- plancraft/models/base.py +152 -0
- plancraft/models/bbox_model.py +492 -0
- plancraft/models/dummy.py +54 -0
- plancraft/models/few_shot_images/__init__.py +16 -0
- plancraft/models/generators.py +480 -0
- plancraft/models/oam.py +283 -0
- plancraft/models/oracle.py +265 -0
- plancraft/models/prompts.py +158 -0
- plancraft/models/react.py +93 -0
- plancraft/models/utils.py +289 -0
- plancraft/train/dataset.py +187 -0
- plancraft/utils.py +84 -0
- {plancraft-0.1.2.dist-info → plancraft-0.1.3.dist-info}/METADATA +1 -1
- plancraft-0.1.3.dist-info/RECORD +30 -0
- plancraft-0.1.3.dist-info/top_level.txt +1 -0
- plancraft-0.1.2.dist-info/RECORD +0 -5
- plancraft-0.1.2.dist-info/top_level.txt +0 -1
- {plancraft-0.1.2.dist-info → plancraft-0.1.3.dist-info}/LICENSE +0 -0
- {plancraft-0.1.2.dist-info → plancraft-0.1.3.dist-info}/WHEEL +0 -0
plancraft/train/dataset.py
ADDED
@@ -0,0 +1,187 @@
+import glob
+import json
+import random
+
+import torch
+from torch.utils.data import Dataset
+
+TEMPLATES = {
+    "idefics2": {
+        "assistant": "\nAssistant:",
+        "user": "\nUser:",
+    },
+    "llama3": {
+        "assistant": "<|start_header_id|>assistant<|end_header_id|>\n\n",
+        "user": "<|start_header_id|>user<|end_header_id|>\n\n",
+    },
+}
+
+
+class PlancraftDialogueDataset(Dataset):
+    def __init__(
+        self,
+        dataset_dir: str = "data/oracle",
+        use_images=False,
+        trace_mode="oa",
+        split="train",
+        max_message_window=30,
+    ):
+        super().__init__()
+        self.split = split
+        self.use_images = use_images
+        self.trace_mode = trace_mode
+        self.add_system_message = True
+
+        assert trace_mode in ["oa", "ota"], f"Invalid trace mode {trace_mode}"
+
+        print("Loading dialogue dataset")
+        data = []
+        for example_path in sorted(
+            glob.glob(f"{dataset_dir}/{split}/{trace_mode}/*.json")
+        ):
+            with open(example_path) as f:
+                example = json.load(f)
+                data.append(example)
+        self.dataset = data
+        self.max_message_window = max_message_window
+
+    def __len__(self) -> int:
+        return len(self.dataset)
+
+    def __getitem__(self, idx: int) -> tuple[dict, list]:
+        example = self.dataset[idx]
+        if len(example["messages"]) > self.max_message_window:
+            # add system message
+            if self.add_system_message:
+                messages = [example["messages"][0]]
+            else:
+                messages = []
+
+            # sample window
+            user_messages_idxs = list(
+                range(self.max_message_window, len(example["messages"]), 2)
+            )
+            end = random.choice(user_messages_idxs)
+            start = end - self.max_message_window + 1
+            assert start != 0
+            # add window
+            messages = messages + example["messages"][start : end + 1]
+
+        else:
+            messages = example["messages"]
+        return messages
+
+
+def track_assistant_response(
+    batch,
+    tokenizer,
+    template_name: str = "llama3",
+):
+    """
+    Mask that returns 1 for tokens in the assistant response and 0 otherwise.
+    """
+    assistant_template = TEMPLATES[template_name]["assistant"]
+    user_template = TEMPLATES[template_name]["user"]
+    start_seq = tokenizer.encode(
+        assistant_template,
+        add_special_tokens=False,
+        return_tensors="pt",
+    )[0]
+    end_seq = tokenizer.encode(
+        user_template,
+        add_special_tokens=False,
+        return_tensors="pt",
+    )[0]
+    encoded_label_ids = batch["labels"]
+    mask = torch.zeros_like(encoded_label_ids)
+    for seq_idx, seq in enumerate(encoded_label_ids):
+        in_masked_response = False
+        i = 0
+        while i < len(seq):
+            if i + len(start_seq) < len(seq) and torch.all(
+                seq[i : i + len(start_seq)].eq(start_seq)
+            ):
+                in_masked_response = True
+                i += len(start_seq)
+                continue
+            if i + len(end_seq) < len(seq) and torch.all(
+                seq[i : i + len(end_seq)].eq(end_seq)
+            ):
+                in_masked_response = False
+                i += len(end_seq)
+                continue
+            if in_masked_response:
+                mask[seq_idx, i] = 1
+            else:
+                mask[seq_idx, i] = 0
+            i += 1
+    return mask
+
+
+def get_collate_fn(
+    tokenizer,
+    max_length=8142,
+    only_assistant=False,
+    template_name: str = "llama3",
+):
+    assert template_name in TEMPLATES
+
+    def collate_fn(batch):
+        messages_batch = []
+        for messages in batch:
+            text = tokenizer.apply_chat_template(
+                messages, add_generation_prompt=False, tokenize=False
+            )
+            # remove BOS token since it will later be added again by the tokenizer
+            text = text.replace("<|begin_of_text|>", "")
+            messages_batch.append(text)
+
+        batch = tokenizer(
+            messages_batch,
+            truncation=True,
+            max_length=max_length,
+            return_tensors="pt",
+        )
+        labels = batch["input_ids"].clone()
+        batch["labels"] = labels
+
+        # add mask for assistant response
+        if only_assistant:
+            mask = track_assistant_response(
+                batch, tokenizer, template_name=template_name
+            )
+            labels[mask == 0] = -100
+
+        return batch
+
+    return collate_fn
+
+
+def get_dataset_and_collate(
+    tokenizer,
+    template_name: str = "llama3",
+    max_length: int = 8142,
+    max_message_window: int = 30,
+    trace_mode="oa",
+    only_assistant=False,
+):
+    if template_name == "llama3":
+        train_dataset = PlancraftDialogueDataset(
+            use_images=False,
+            max_message_window=max_message_window,
+            split="train",
+            trace_mode=trace_mode,
+        )
+        val_dataset = PlancraftDialogueDataset(
+            use_images=False,
+            max_message_window=max_message_window,
+            split="val",
+            trace_mode=trace_mode,
+        )
+        collate_fn = get_collate_fn(
+            tokenizer=tokenizer,
+            only_assistant=only_assistant,
+            template_name=template_name,
+            max_length=max_length,
+        )
+    return train_dataset, val_dataset, collate_fn
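The new plancraft/train/dataset.py module is the part of this release most likely to be called directly by downstream training code, so a usage sketch may help. The snippet below is a minimal sketch, not part of the wheel: it wires get_dataset_and_collate into a standard PyTorch DataLoader. The tokenizer checkpoint and the data/oracle/{train,val}/oa/*.json trace layout are assumptions, and batch_size=1 is used because the collate function tokenizes without padding.

# Assumed usage sketch (not part of the package): load the oracle dialogue
# traces and build batches with the assistant-only loss mask.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from plancraft.train.dataset import get_dataset_and_collate

# Illustrative checkpoint; any tokenizer with a llama3-style chat template should work.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

train_dataset, val_dataset, collate_fn = get_dataset_and_collate(
    tokenizer,
    template_name="llama3",
    trace_mode="oa",      # observation/action traces under data/oracle/train/oa/ (assumed layout)
    only_assistant=True,  # non-assistant tokens get label -100 and are ignored by the loss
)

# batch_size=1 keeps tensors rectangular, since collate_fn tokenizes without padding.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
batch = next(iter(train_loader))  # dict with input_ids, attention_mask and labels

Note that with only_assistant=True the -100 values are written into the same tensor stored under batch["labels"], so the masking carries through to the returned batch.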
plancraft/utils.py
ADDED
@@ -0,0 +1,84 @@
+import glob
+import pathlib
+
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+
+
+def get_downloaded_models() -> dict:
+    """
+    Get the list of downloaded models on the NFS partition (EIDF).
+    """
+    downloaded_models = {}
+    # known models on NFS partition
+    if pathlib.Path("/nfs").exists():
+        local_models = glob.glob("/nfs/public/hf/models/*/*")
+        downloaded_models = {
+            model.replace("/nfs/public/hf/models/", ""): model for model in local_models
+        }
+    return downloaded_models
+
+
+def get_torch_device() -> torch.device:
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+    elif torch.backends.mps.is_available():
+        if not torch.backends.mps.is_built():
+            print(
+                "MPS not available because the current PyTorch install was not built with MPS enabled."
+            )
+        else:
+            device = torch.device("mps")
+    return device
+
+
+def resize_image(img, target_resolution=(128, 128)):
+    if type(img) == np.ndarray:
+        img = cv2.resize(img, target_resolution, interpolation=cv2.INTER_LINEAR)
+    elif type(img) == torch.Tensor:
+        img = F.interpolate(img, size=target_resolution, mode="bilinear")
+    else:
+        raise ValueError("Unsupported type for img")
+    return img
+
+
+def save_frames_to_video(frames: list, out_path: str):
+    imgs = []
+    for id, (frame, goal) in enumerate(frames):
+        # if torch.is_tensor(frame):
+        #     frame = frame.permute(0, 2, 3, 1).cpu().numpy()
+
+        frame = resize_image(frame, (320, 240)).astype("uint8")
+        cv2.putText(
+            frame,
+            f"FID: {id}",
+            (10, 25),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.8,
+            (255, 255, 255),
+            2,
+        )
+        cv2.putText(
+            frame,
+            f"Goal: {goal}",
+            (10, 55),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.8,
+            (255, 0, 0),
+            2,
+        )
+        imgs.append(Image.fromarray(frame))
+    imgs = imgs[::3]
+    imgs[0].save(
+        out_path,
+        save_all=True,
+        append_images=imgs[1:],
+        optimize=False,
+        quality=0,
+        duration=150,
+        loop=0,
+    )
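For context, a short sketch of how the helpers in the new plancraft/utils.py compose (assumed usage, not from the package): get_torch_device prefers cuda, then mps, then cpu, and save_frames_to_video annotates each (frame, goal) pair and writes an animated GIF. The synthetic frames, the goal string, and the output filename below are placeholders.

# Assumed usage sketch: synthetic frames stand in for environment screenshots.
import numpy as np

from plancraft.utils import get_torch_device, save_frames_to_video

device = get_torch_device()  # cuda:0 if available, else mps, else cpu

# (frame, goal) pairs; save_frames_to_video resizes each frame to 320x240,
# overlays the frame id and goal text, keeps every third frame, and saves
# the result as an animated GIF.
frames = [
    (np.zeros((240, 320, 3), dtype=np.uint8), "craft an item")  # placeholder goal text
    for _ in range(12)
]
save_frames_to_video(frames, "rollout.gif")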
plancraft-0.1.3.dist-info/RECORD
ADDED
@@ -0,0 +1,30 @@
+plancraft/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+plancraft/config.py,sha256=k7Ac6Zh6g-Q-fuazPxVdXnSL_PMd3jsIqyiImNaJW5A,4472
+plancraft/evaluator.py,sha256=zomShgwyUOHNVZbTERpimgEYiyD4JXzHzPi4tg45lsM,9926
+plancraft/utils.py,sha256=J9K81zBXWCcyIBjwu_eHznnzWamXU-pQrz7BdNTVbLU,2289
+plancraft/environments/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+plancraft/environments/actions.py,sha256=SeeC9l1cJBs9pdba6BefQ_iQNfFf6FVTWm7HWkacbsY,6262
+plancraft/environments/env_real.py,sha256=oETMvdq8-TspPoIRxioTDj9EEmuZRuuMxd9c3momvPk,10877
+plancraft/environments/env_symbolic.py,sha256=ot4IStZ3oT7CYPqdqHsGl2BopLyZUwSF671SrHIiMLk,7777
+plancraft/environments/items.py,sha256=1R56LyK6tqIssQJMqHst6A9DeEfOX5DN-OBAkumGncw,217
+plancraft/environments/planner.py,sha256=2B-0aunllmTuiHE7Jn5jHzCg6mMgxZjisiTDdpSKupk,3954
+plancraft/environments/recipes.py,sha256=nXvOLCRljiZ5IgeevXuosU9IgSs7oQWQJFiuyRNSVFs,19571
+plancraft/environments/sampler.py,sha256=ZBYoENKdQQ7wbAyVk-c9UNRWKPE0omv9he8c8QZ6wXg,7625
+plancraft/models/__init__.py,sha256=PasK3jpbhpD0kxF4iHukcccZqvZg6lL240zie3DfLDY,622
+plancraft/models/act.py,sha256=jdZunT7FcbHvcaJZ_wUDLSuObHjU6JZybghR_B0QJ8Q,6548
+plancraft/models/base.py,sha256=fFM2BV9PqvIFtUlTz8iz5HPemYRy3S0EituM1XdJJSQ,4927
+plancraft/models/bbox_model.py,sha256=CoX-odH59S-djkPOH2ViEmbYWo1sefmHiOcBlFWiAkg,16814
+plancraft/models/dummy.py,sha256=QjxTIiKsWSmhUMAuw7Yy-OKKCLi_x3rwll4hH7ZNXso,1732
+plancraft/models/generators.py,sha256=c2avYUoXAAWOHRFq24ZYqHqNgoPs_L_XWJrR_TKFg9E,17358
+plancraft/models/oam.py,sha256=_TpxsaCnGvQ7YJC6KWcpl2HzgBiQaHn7U1zRAXfdjOo,9943
+plancraft/models/oracle.py,sha256=xRplR2cCW_39UGYtSjDLRDmp0UILhtIXYjuRJ-jokJ8,9598
+plancraft/models/prompts.py,sha256=XwoRqd_5_VfCUXb10dCRFYXgw70mO2VoQocn3Z2zgs0,6165
+plancraft/models/react.py,sha256=5yM4tv7tDfm_-5yUSAcw5C4wioWijF9fLEODbFqFDvg,3346
+plancraft/models/utils.py,sha256=osKX0_uux9wzqYzq1ST0Cu5idrAnyfNvXrj0uO1eKo0,9424
+plancraft/models/few_shot_images/__init__.py,sha256=nIkyB6w3ok-h4lfJKsrcMQzQF624Y9uxYV1FqCu3Lx0,351
+plancraft/train/dataset.py,sha256=NrZjbIkosui1kaq7AIWSrYvvzrDxu_njH7FmGKY3xnI,5434
+plancraft-0.1.3.dist-info/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
+plancraft-0.1.3.dist-info/METADATA,sha256=XLudjSvyfZ-uYJZ1K1RAJUVOQhozFD3FLzwUQeikhK4,2631
+plancraft-0.1.3.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+plancraft-0.1.3.dist-info/top_level.txt,sha256=yGGA9HtPKH2-pIJFrbyYhj2JF9Xj-4m0fxMPKT9FNzg,10
+plancraft-0.1.3.dist-info/RECORD,,
plancraft-0.1.3.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
+plancraft
plancraft-0.1.2.dist-info/RECORD
DELETED
@@ -1,5 +0,0 @@
-plancraft-0.1.2.dist-info/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
-plancraft-0.1.2.dist-info/METADATA,sha256=IDsYHFCvjx4Nhe4Jgyj0x7YsIVAR95ZyrSA-LIQBgeI,2631
-plancraft-0.1.2.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
-plancraft-0.1.2.dist-info/top_level.txt,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
-plancraft-0.1.2.dist-info/RECORD,,
plancraft-0.1.2.dist-info/top_level.txt
DELETED
@@ -1 +0,0 @@
-
{plancraft-0.1.2.dist-info → plancraft-0.1.3.dist-info}/LICENSE
File without changes
{plancraft-0.1.2.dist-info → plancraft-0.1.3.dist-info}/WHEEL
File without changes