plancraft-0.1.0-py3-none-any.whl
- environments/__init__.py +0 -0
- environments/actions.py +218 -0
- environments/env_real.py +315 -0
- environments/env_symbolic.py +215 -0
- environments/items.py +10 -0
- environments/planner.py +109 -0
- environments/recipes.py +542 -0
- environments/sampler.py +224 -0
- models/__init__.py +21 -0
- models/act.py +184 -0
- models/base.py +152 -0
- models/bbox_model.py +492 -0
- models/dummy.py +54 -0
- models/few_shot_images/__init__.py +16 -0
- models/generators.py +483 -0
- models/oam.py +284 -0
- models/oracle.py +268 -0
- models/prompts.py +158 -0
- models/react.py +98 -0
- models/utils.py +289 -0
- plancraft-0.1.0.dist-info/LICENSE +21 -0
- plancraft-0.1.0.dist-info/METADATA +53 -0
- plancraft-0.1.0.dist-info/RECORD +26 -0
- plancraft-0.1.0.dist-info/WHEEL +5 -0
- plancraft-0.1.0.dist-info/top_level.txt +3 -0
- train/dataset.py +187 -0
train/dataset.py
ADDED
import glob
import json
import random

import torch
from torch.utils.data import Dataset

# Chat-template markers used to locate assistant turns when masking labels.
TEMPLATES = {
    "idefics2": {
        "assistant": "\nAssistant:",
        "user": "\nUser:",
    },
    "llama3": {
        "assistant": "<|start_header_id|>assistant<|end_header_id|>\n\n",
        "user": "<|start_header_id|>user<|end_header_id|>\n\n",
    },
}


class PlancraftDialogueDataset(Dataset):
    """Loads oracle dialogue traces stored as one JSON file per example."""

    def __init__(
        self,
        dataset_dir: str = "data/oracle",
        use_images=False,
        trace_mode="oa",
        split="train",
        max_message_window=30,
    ):
        super().__init__()
        self.split = split
        self.use_images = use_images
        self.trace_mode = trace_mode
        self.add_system_message = True

        assert trace_mode in ["oa", "ota"], f"Invalid trace mode {trace_mode}"

        print("Loading dialogue dataset")
        data = []
        for example_path in sorted(
            glob.glob(f"{dataset_dir}/{split}/{trace_mode}/*.json")
        ):
            with open(example_path) as f:
                example = json.load(f)
                data.append(example)
        self.dataset = data
        self.max_message_window = max_message_window

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> list:
        example = self.dataset[idx]
        if len(example["messages"]) > self.max_message_window:
            # add system message
            if self.add_system_message:
                messages = [example["messages"][0]]
            else:
                messages = []

            # sample a contiguous window that ends on a user message
            user_messages_idxs = list(
                range(self.max_message_window, len(example["messages"]), 2)
            )
            end = random.choice(user_messages_idxs)
            start = end - self.max_message_window + 1
            # window must never start at the system message (index 0)
            assert start != 0
            # add window
            messages = messages + example["messages"][start : end + 1]
        else:
            messages = example["messages"]
        return messages


def track_assistant_response(
    batch,
    tokenizer,
    template_name: str = "llama3",
):
    """
    Mask that returns 1 for tokens in the assistant response and 0 otherwise.
    """
    assistant_template = TEMPLATES[template_name]["assistant"]
    user_template = TEMPLATES[template_name]["user"]
    start_seq = tokenizer.encode(
        assistant_template,
        add_special_tokens=False,
        return_tensors="pt",
    )[0]
    end_seq = tokenizer.encode(
        user_template,
        add_special_tokens=False,
        return_tensors="pt",
    )[0]
    encoded_label_ids = batch["labels"]
    mask = torch.zeros_like(encoded_label_ids)
    for seq_idx, seq in enumerate(encoded_label_ids):
        in_masked_response = False
        i = 0
        while i < len(seq):
            # assistant header found: start masking after the header tokens
            if i + len(start_seq) < len(seq) and torch.all(
                seq[i : i + len(start_seq)].eq(start_seq)
            ):
                in_masked_response = True
                i += len(start_seq)
                continue
            # user header found: stop masking
            if i + len(end_seq) < len(seq) and torch.all(
                seq[i : i + len(end_seq)].eq(end_seq)
            ):
                in_masked_response = False
                i += len(end_seq)
                continue
            mask[seq_idx, i] = 1 if in_masked_response else 0
            i += 1
    return mask


def get_collate_fn(
    tokenizer,
    max_length=8142,
    only_assistant=False,
    template_name: str = "llama3",
):
    assert template_name in TEMPLATES

    def collate_fn(batch):
        messages_batch = []
        for messages in batch:
            text = tokenizer.apply_chat_template(
                messages, add_generation_prompt=False, tokenize=False
            )
            # remove BOS token since it will later be added again by the tokenizer
            text = text.replace("<|begin_of_text|>", "")
            messages_batch.append(text)

        batch = tokenizer(
            messages_batch,
            truncation=True,
            padding=True,  # pad to the longest sequence so the batch stacks into tensors
            max_length=max_length,
            return_tensors="pt",
        )
        labels = batch["input_ids"].clone()
        batch["labels"] = labels

        # mask everything except assistant responses with -100 so the loss ignores it
        if only_assistant:
            mask = track_assistant_response(
                batch, tokenizer, template_name=template_name
            )
            labels[mask == 0] = -100

        return batch

    return collate_fn


def get_dataset_and_collate(
    tokenizer,
    template_name: str = "llama3",
    max_length: int = 8142,
    max_message_window: int = 30,
    trace_mode="oa",
    only_assistant=False,
):
    if template_name == "llama3":
        train_dataset = PlancraftDialogueDataset(
            use_images=False,
            max_message_window=max_message_window,
            split="train",
            trace_mode=trace_mode,
        )
        val_dataset = PlancraftDialogueDataset(
            use_images=False,
            max_message_window=max_message_window,
            split="val",
            trace_mode=trace_mode,
        )
    else:
        # only the llama3 template is wired up here; fail fast instead of
        # returning unbound datasets
        raise ValueError(f"Unsupported template {template_name}")
    collate_fn = get_collate_fn(
        tokenizer=tokenizer,
        only_assistant=only_assistant,
        template_name=template_name,
        max_length=max_length,
    )
    return train_dataset, val_dataset, collate_fn
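
For reference, a minimal usage sketch (not part of the wheel) showing how the pieces above fit together. It assumes a Hugging Face tokenizer with a llama3-style chat template — the checkpoint name is a placeholder — and that oracle traces exist under data/oracle/train and data/oracle/val:

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from train.dataset import get_dataset_and_collate

# placeholder checkpoint; any tokenizer with a llama3-style chat template should work
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
tokenizer.pad_token = tokenizer.eos_token  # llama3 tokenizers ship without a pad token

train_dataset, val_dataset, collate_fn = get_dataset_and_collate(
    tokenizer,
    template_name="llama3",
    trace_mode="oa",
    only_assistant=True,  # compute loss only on assistant tokens
)

loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
batch = next(iter(loader))
print(batch["input_ids"].shape, (batch["labels"] != -100).sum())

Since the collate function pads dynamically, a pad token must be defined, hence the eos_token assignment above. With only_assistant=True, every token outside an assistant turn is labelled -100 and therefore ignored by the cross-entropy loss during fine-tuning.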