plancraft-0.1.2-py3-none-any.whl → plancraft-0.1.3-py3-none-any.whl
- plancraft/__init__.py +0 -0
- plancraft/config.py +155 -0
- plancraft/environments/__init__.py +0 -0
- plancraft/environments/actions.py +218 -0
- plancraft/environments/env_real.py +316 -0
- plancraft/environments/env_symbolic.py +212 -0
- plancraft/environments/items.py +10 -0
- plancraft/environments/planner.py +109 -0
- plancraft/environments/recipes.py +542 -0
- plancraft/environments/sampler.py +224 -0
- plancraft/evaluator.py +273 -0
- plancraft/models/__init__.py +21 -0
- plancraft/models/act.py +184 -0
- plancraft/models/base.py +152 -0
- plancraft/models/bbox_model.py +492 -0
- plancraft/models/dummy.py +54 -0
- plancraft/models/few_shot_images/__init__.py +16 -0
- plancraft/models/generators.py +480 -0
- plancraft/models/oam.py +283 -0
- plancraft/models/oracle.py +265 -0
- plancraft/models/prompts.py +158 -0
- plancraft/models/react.py +93 -0
- plancraft/models/utils.py +289 -0
- plancraft/train/dataset.py +187 -0
- plancraft/utils.py +84 -0
- {plancraft-0.1.2.dist-info → plancraft-0.1.3.dist-info}/METADATA +1 -1
- plancraft-0.1.3.dist-info/RECORD +30 -0
- plancraft-0.1.3.dist-info/top_level.txt +1 -0
- plancraft-0.1.2.dist-info/RECORD +0 -5
- plancraft-0.1.2.dist-info/top_level.txt +0 -1
- {plancraft-0.1.2.dist-info → plancraft-0.1.3.dist-info}/LICENSE +0 -0
- {plancraft-0.1.2.dist-info → plancraft-0.1.3.dist-info}/WHEEL +0 -0
plancraft/train/dataset.py
ADDED
@@ -0,0 +1,187 @@
+import glob
+import json
+import random
+
+import torch
+from torch.utils.data import Dataset
+
+TEMPLATES = {
+    "idefics2": {
+        "assistant": "\nAssistant:",
+        "user": "\nUser:",
+    },
+    "llama3": {
+        "assistant": "<|start_header_id|>assistant<|end_header_id|>\n\n",
+        "user": "<|start_header_id|>user<|end_header_id|>\n\n",
+    },
+}
+
+
+class PlancraftDialogueDataset(Dataset):
+    def __init__(
+        self,
+        dataset_dir: str = "data/oracle",
+        use_images=False,
+        trace_mode="oa",
+        split="train",
+        max_message_window=30,
+    ):
+        super().__init__()
+        self.split = split
+        self.use_images = use_images
+        self.trace_mode = trace_mode
+        self.add_system_message = True
+
+        assert trace_mode in ["oa", "ota"], f"Invalid trace mode {trace_mode}"
+
+        print("Loading dialogue dataset")
+        data = []
+        for example_path in sorted(
+            glob.glob(f"{dataset_dir}/{split}/{trace_mode}/*.json")
+        ):
+            with open(example_path) as f:
+                example = json.load(f)
+                data.append(example)
+        self.dataset = data
+        self.max_message_window = max_message_window
+
+    def __len__(self) -> int:
+        return len(self.dataset)
+
+    def __getitem__(self, idx: int) -> tuple[dict, list]:
+        example = self.dataset[idx]
+        if len(example["messages"]) > self.max_message_window:
+            # add system message
+            if self.add_system_message:
+                messages = [example["messages"][0]]
+            else:
+                messages = []
+
+            # sample window
+            user_messages_idxs = list(
+                range(self.max_message_window, len(example["messages"]), 2)
+            )
+            end = random.choice(user_messages_idxs)
+            start = end - self.max_message_window + 1
+            assert start != 0
+            # add window
+            messages = messages + example["messages"][start : end + 1]
+
+        else:
+            messages = example["messages"]
+        return messages
+
+
+def track_assistant_response(
+    batch,
+    tokenizer,
+    template_name: str = "llama3",
+):
+    """
+    Mask that returns 1 for tokens in the assistant response and 0 otherwise.
+    """
+    assistant_template = TEMPLATES[template_name]["assistant"]
+    user_template = TEMPLATES[template_name]["user"]
+    start_seq = tokenizer.encode(
+        assistant_template,
+        add_special_tokens=False,
+        return_tensors="pt",
+    )[0]
+    end_seq = tokenizer.encode(
+        user_template,
+        add_special_tokens=False,
+        return_tensors="pt",
+    )[0]
+    encoded_label_ids = batch["labels"]
+    mask = torch.zeros_like(encoded_label_ids)
+    for seq_idx, seq in enumerate(encoded_label_ids):
+        in_masked_response = False
+        i = 0
+        while i < len(seq):
+            if i + len(start_seq) < len(seq) and torch.all(
+                seq[i : i + len(start_seq)].eq(start_seq)
+            ):
+                in_masked_response = True
+                i += len(start_seq)
+                continue
+            if i + len(end_seq) < len(seq) and torch.all(
+                seq[i : i + len(end_seq)].eq(end_seq)
+            ):
+                in_masked_response = False
+                i += len(end_seq)
+                continue
+            if in_masked_response:
+                mask[seq_idx, i] = 1
+            else:
+                mask[seq_idx, i] = 0
+            i += 1
+    return mask
+
+
+def get_collate_fn(
+    tokenizer,
+    max_length=8142,
+    only_assistant=False,
+    template_name: str = "llama3",
+):
+    assert template_name in TEMPLATES
+
+    def collate_fn(batch):
+        messages_batch = []
+        for messages in batch:
+            text = tokenizer.apply_chat_template(
+                messages, add_generation_prompt=False, tokenize=False
+            )
+            # remove BOS token since it will later be added again by the tokenizer
+            text = text.replace("<|begin_of_text|>", "")
+            messages_batch.append(text)
+
+        batch = tokenizer(
+            messages_batch,
+            truncation=True,
+            max_length=max_length,
+            return_tensors="pt",
+        )
+        labels = batch["input_ids"].clone()
+        batch["labels"] = labels
+
+        # add mask for assistant response
+        if only_assistant:
+            mask = track_assistant_response(
+                batch, tokenizer, template_name=template_name
+            )
+            labels[mask == 0] = -100
+
+        return batch
+
+    return collate_fn
+
+
+def get_dataset_and_collate(
+    tokenizer,
+    template_name: str = "llama3",
+    max_length: int = 8142,
+    max_message_window: int = 30,
+    trace_mode="oa",
+    only_assistant=False,
+):
+    if template_name == "llama3":
+        train_dataset = PlancraftDialogueDataset(
+            use_images=False,
+            max_message_window=max_message_window,
+            split="train",
+            trace_mode=trace_mode,
+        )
+        val_dataset = PlancraftDialogueDataset(
+            use_images=False,
+            max_message_window=max_message_window,
+            split="val",
+            trace_mode=trace_mode,
+        )
+        collate_fn = get_collate_fn(
+            tokenizer=tokenizer,
+            only_assistant=only_assistant,
+            template_name=template_name,
+            max_length=max_length,
+        )
+        return train_dataset, val_dataset, collate_fn
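For orientation, the new plancraft/train/dataset.py module exposes get_dataset_and_collate as its main entry point. Below is a minimal usage sketch, not shipped in the wheel; it assumes a Hugging Face chat tokenizer (the model id is only illustrative) and that oracle traces exist under data/oracle/train/oa and data/oracle/val/oa.

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from plancraft.train.dataset import get_dataset_and_collate

# assumption: any Llama 3 style chat tokenizer; the model id is illustrative
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

train_dataset, val_dataset, collate_fn = get_dataset_and_collate(
    tokenizer=tokenizer,
    template_name="llama3",
    max_length=8142,
    max_message_window=30,
    trace_mode="oa",
    only_assistant=True,  # non-assistant label tokens are set to -100
)
# batch_size=1 keeps the sketch independent of any padding configuration
train_loader = DataLoader(
    train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn
)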
plancraft/utils.py
ADDED
@@ -0,0 +1,84 @@
+import glob
+import pathlib
+
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+
+
+def get_downloaded_models() -> dict:
+    """
+    Get the list of downloaded models on the NFS partition (EIDF).
+    """
+    downloaded_models = {}
+    # known models on NFS partition
+    if pathlib.Path("/nfs").exists():
+        local_models = glob.glob("/nfs/public/hf/models/*/*")
+        downloaded_models = {
+            model.replace("/nfs/public/hf/models/", ""): model for model in local_models
+        }
+    return downloaded_models
+
+
+def get_torch_device() -> torch.device:
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+    elif torch.backends.mps.is_available():
+        if not torch.backends.mps.is_built():
+            print(
+                "MPS not available because the current PyTorch install was not built with MPS enabled."
+            )
+        else:
+            device = torch.device("mps")
+    return device
+
+
+def resize_image(img, target_resolution=(128, 128)):
+    if type(img) == np.ndarray:
+        img = cv2.resize(img, target_resolution, interpolation=cv2.INTER_LINEAR)
+    elif type(img) == torch.Tensor:
+        img = F.interpolate(img, size=target_resolution, mode="bilinear")
+    else:
+        raise ValueError("Unsupported type for img")
+    return img
+
+
+def save_frames_to_video(frames: list, out_path: str):
+    imgs = []
+    for id, (frame, goal) in enumerate(frames):
+        # if torch.is_tensor(frame):
+        #     frame = frame.permute(0, 2, 3, 1).cpu().numpy()
+
+        frame = resize_image(frame, (320, 240)).astype("uint8")
+        cv2.putText(
+            frame,
+            f"FID: {id}",
+            (10, 25),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.8,
+            (255, 255, 255),
+            2,
+        )
+        cv2.putText(
+            frame,
+            f"Goal: {goal}",
+            (10, 55),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.8,
+            (255, 0, 0),
+            2,
+        )
+        imgs.append(Image.fromarray(frame))
+    imgs = imgs[::3]
+    imgs[0].save(
+        out_path,
+        save_all=True,
+        append_images=imgs[1:],
+        optimize=False,
+        quality=0,
+        duration=150,
+        loop=0,
+    )
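The helpers added in plancraft/utils.py can be exercised on their own. The snippet below is a hypothetical example, not part of the package: it assumes frames are (H, W, 3) uint8 arrays paired with a goal string, and writes a short annotated GIF from dummy data.

import numpy as np

from plancraft.utils import get_torch_device, save_frames_to_video

device = get_torch_device()  # cuda:0, mps, or cpu depending on the host

# dummy black frames paired with an illustrative goal string
frames = [
    (np.zeros((240, 320, 3), dtype=np.uint8), "craft oak_planks") for _ in range(9)
]
save_frames_to_video(frames, "episode.gif")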
plancraft-0.1.3.dist-info/RECORD
ADDED
@@ -0,0 +1,30 @@
+plancraft/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+plancraft/config.py,sha256=k7Ac6Zh6g-Q-fuazPxVdXnSL_PMd3jsIqyiImNaJW5A,4472
+plancraft/evaluator.py,sha256=zomShgwyUOHNVZbTERpimgEYiyD4JXzHzPi4tg45lsM,9926
+plancraft/utils.py,sha256=J9K81zBXWCcyIBjwu_eHznnzWamXU-pQrz7BdNTVbLU,2289
+plancraft/environments/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+plancraft/environments/actions.py,sha256=SeeC9l1cJBs9pdba6BefQ_iQNfFf6FVTWm7HWkacbsY,6262
+plancraft/environments/env_real.py,sha256=oETMvdq8-TspPoIRxioTDj9EEmuZRuuMxd9c3momvPk,10877
+plancraft/environments/env_symbolic.py,sha256=ot4IStZ3oT7CYPqdqHsGl2BopLyZUwSF671SrHIiMLk,7777
+plancraft/environments/items.py,sha256=1R56LyK6tqIssQJMqHst6A9DeEfOX5DN-OBAkumGncw,217
+plancraft/environments/planner.py,sha256=2B-0aunllmTuiHE7Jn5jHzCg6mMgxZjisiTDdpSKupk,3954
+plancraft/environments/recipes.py,sha256=nXvOLCRljiZ5IgeevXuosU9IgSs7oQWQJFiuyRNSVFs,19571
+plancraft/environments/sampler.py,sha256=ZBYoENKdQQ7wbAyVk-c9UNRWKPE0omv9he8c8QZ6wXg,7625
+plancraft/models/__init__.py,sha256=PasK3jpbhpD0kxF4iHukcccZqvZg6lL240zie3DfLDY,622
+plancraft/models/act.py,sha256=jdZunT7FcbHvcaJZ_wUDLSuObHjU6JZybghR_B0QJ8Q,6548
+plancraft/models/base.py,sha256=fFM2BV9PqvIFtUlTz8iz5HPemYRy3S0EituM1XdJJSQ,4927
+plancraft/models/bbox_model.py,sha256=CoX-odH59S-djkPOH2ViEmbYWo1sefmHiOcBlFWiAkg,16814
+plancraft/models/dummy.py,sha256=QjxTIiKsWSmhUMAuw7Yy-OKKCLi_x3rwll4hH7ZNXso,1732
+plancraft/models/generators.py,sha256=c2avYUoXAAWOHRFq24ZYqHqNgoPs_L_XWJrR_TKFg9E,17358
+plancraft/models/oam.py,sha256=_TpxsaCnGvQ7YJC6KWcpl2HzgBiQaHn7U1zRAXfdjOo,9943
+plancraft/models/oracle.py,sha256=xRplR2cCW_39UGYtSjDLRDmp0UILhtIXYjuRJ-jokJ8,9598
+plancraft/models/prompts.py,sha256=XwoRqd_5_VfCUXb10dCRFYXgw70mO2VoQocn3Z2zgs0,6165
+plancraft/models/react.py,sha256=5yM4tv7tDfm_-5yUSAcw5C4wioWijF9fLEODbFqFDvg,3346
+plancraft/models/utils.py,sha256=osKX0_uux9wzqYzq1ST0Cu5idrAnyfNvXrj0uO1eKo0,9424
+plancraft/models/few_shot_images/__init__.py,sha256=nIkyB6w3ok-h4lfJKsrcMQzQF624Y9uxYV1FqCu3Lx0,351
+plancraft/train/dataset.py,sha256=NrZjbIkosui1kaq7AIWSrYvvzrDxu_njH7FmGKY3xnI,5434
+plancraft-0.1.3.dist-info/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
+plancraft-0.1.3.dist-info/METADATA,sha256=XLudjSvyfZ-uYJZ1K1RAJUVOQhozFD3FLzwUQeikhK4,2631
+plancraft-0.1.3.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+plancraft-0.1.3.dist-info/top_level.txt,sha256=yGGA9HtPKH2-pIJFrbyYhj2JF9Xj-4m0fxMPKT9FNzg,10
+plancraft-0.1.3.dist-info/RECORD,,
plancraft-0.1.3.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
+plancraft
plancraft-0.1.2.dist-info/RECORD
DELETED
@@ -1,5 +0,0 @@
-plancraft-0.1.2.dist-info/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
-plancraft-0.1.2.dist-info/METADATA,sha256=IDsYHFCvjx4Nhe4Jgyj0x7YsIVAR95ZyrSA-LIQBgeI,2631
-plancraft-0.1.2.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
-plancraft-0.1.2.dist-info/top_level.txt,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
-plancraft-0.1.2.dist-info/RECORD,,
plancraft-0.1.2.dist-info/top_level.txt
DELETED
@@ -1 +0,0 @@
-
{plancraft-0.1.2.dist-info → plancraft-0.1.3.dist-info}/LICENSE
File without changes
{plancraft-0.1.2.dist-info → plancraft-0.1.3.dist-info}/WHEEL
File without changes