plancraft 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- plancraft-0.1.2.dist-info/METADATA +74 -0
- plancraft-0.1.2.dist-info/RECORD +5 -0
- {plancraft-0.1.1.dist-info → plancraft-0.1.2.dist-info}/WHEEL +1 -1
- plancraft-0.1.2.dist-info/top_level.txt +1 -0
- environments/__init__.py +0 -0
- environments/actions.py +0 -218
- environments/env_real.py +0 -316
- environments/env_symbolic.py +0 -212
- environments/items.py +0 -10
- environments/planner.py +0 -109
- environments/recipes.py +0 -542
- environments/sampler.py +0 -224
- models/__init__.py +0 -21
- models/act.py +0 -184
- models/base.py +0 -152
- models/bbox_model.py +0 -492
- models/dummy.py +0 -54
- models/few_shot_images/__init__.py +0 -16
- models/generators.py +0 -480
- models/oam.py +0 -283
- models/oracle.py +0 -265
- models/prompts.py +0 -158
- models/react.py +0 -93
- models/utils.py +0 -289
- plancraft-0.1.1.dist-info/METADATA +0 -74
- plancraft-0.1.1.dist-info/RECORD +0 -26
- plancraft-0.1.1.dist-info/top_level.txt +0 -3
- train/dataset.py +0 -187
- {plancraft-0.1.1.dist-info → plancraft-0.1.2.dist-info}/LICENSE +0 -0
train/dataset.py
DELETED
@@ -1,187 +0,0 @@
|
|
1
|
-
import glob
|
2
|
-
import json
|
3
|
-
import random
|
4
|
-
|
5
|
-
import torch
|
6
|
-
from torch.utils.data import Dataset
|
7
|
-
|
8
|
-
# Chat-template role markers, keyed by model family.
# Used to locate the start of assistant/user turns when masking labels so that
# only assistant-response tokens contribute to the training loss.
TEMPLATES = {
    "idefics2": {
        "assistant": "\nAssistant:",
        "user": "\nUser:",
    },
    "llama3": {
        "assistant": "<|start_header_id|>assistant<|end_header_id|>\n\n",
        "user": "<|start_header_id|>user<|end_header_id|>\n\n",
    },
}
|
18
|
-
|
19
|
-
|
20
|
-
class PlancraftDialogueDataset(Dataset):
    """Dataset of Plancraft dialogue traces stored as JSON files.

    Loads every ``*.json`` file under ``{dataset_dir}/{split}/{trace_mode}/``
    into memory. Each example is a dict with a ``"messages"`` list of chat
    turns. When a dialogue is longer than ``max_message_window``,
    ``__getitem__`` returns a randomly sampled contiguous window, prefixed
    with the first message (presumably the system message — confirm against
    the trace format) when ``add_system_message`` is set.
    """

    def __init__(
        self,
        dataset_dir: str = "data/oracle",
        use_images: bool = False,
        trace_mode: str = "oa",
        split: str = "train",
        max_message_window: int = 30,
    ):
        super().__init__()
        self.split = split
        self.use_images = use_images
        self.trace_mode = trace_mode
        # Always prepend the first message when windowing long dialogues.
        self.add_system_message = True

        if trace_mode not in ["oa", "ota"]:
            # Raise instead of assert so validation survives `python -O`.
            raise ValueError(f"Invalid trace mode {trace_mode}")

        print("Loading dialogue dataset")
        data = []
        for example_path in sorted(
            glob.glob(f"{dataset_dir}/{split}/{trace_mode}/*.json")
        ):
            with open(example_path) as f:
                example = json.load(f)
            data.append(example)
        self.dataset = data
        self.max_message_window = max_message_window

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> list:
        # NOTE: the original annotation claimed `tuple[dict, list]`, but the
        # method returns only the list of messages.
        example = self.dataset[idx]
        if len(example["messages"]) > self.max_message_window:
            # add system message
            if self.add_system_message:
                messages = [example["messages"][0]]
            else:
                messages = []

            # Sample a window end index; stepping by 2 from the window size
            # keeps the end aligned to the same turn parity.
            user_messages_idxs = list(
                range(self.max_message_window, len(example["messages"]), 2)
            )
            end = random.choice(user_messages_idxs)
            start = end - self.max_message_window + 1
            assert start != 0
            # add window
            messages = messages + example["messages"][start : end + 1]
        else:
            messages = example["messages"]
        return messages
|
73
|
-
|
74
|
-
|
75
|
-
def track_assistant_response(
    batch,
    tokenizer,
    template_name: str = "llama3",
):
    """
    Mask that returns 1 for tokens in the assistant response and 0 otherwise.

    Scans each label sequence in ``batch["labels"]`` for the tokenized
    assistant/user role markers from TEMPLATES and marks every token strictly
    between an assistant marker and the next user marker with 1. The marker
    tokens themselves stay 0.
    """
    assistant_template = TEMPLATES[template_name]["assistant"]
    user_template = TEMPLATES[template_name]["user"]
    start_seq = tokenizer.encode(
        assistant_template,
        add_special_tokens=False,
        return_tensors="pt",
    )[0]
    end_seq = tokenizer.encode(
        user_template,
        add_special_tokens=False,
        return_tensors="pt",
    )[0]
    encoded_label_ids = batch["labels"]
    mask = torch.zeros_like(encoded_label_ids)
    for seq_idx, seq in enumerate(encoded_label_ids):
        in_masked_response = False
        i = 0
        while i < len(seq):
            # Use <= so a marker ending exactly at the end of the sequence is
            # still recognized; the original strict `<` skipped that match and
            # mislabeled the trailing marker tokens.
            if i + len(start_seq) <= len(seq) and torch.all(
                seq[i : i + len(start_seq)].eq(start_seq)
            ):
                in_masked_response = True
                i += len(start_seq)
                continue
            if i + len(end_seq) <= len(seq) and torch.all(
                seq[i : i + len(end_seq)].eq(end_seq)
            ):
                in_masked_response = False
                i += len(end_seq)
                continue
            if in_masked_response:
                mask[seq_idx, i] = 1
            else:
                mask[seq_idx, i] = 0
            i += 1
    return mask
|
119
|
-
|
120
|
-
|
121
|
-
def get_collate_fn(
    tokenizer,
    max_length=8142,
    only_assistant=False,
    template_name: str = "llama3",
):
    """Build a collate function that renders dialogues with the tokenizer's
    chat template, tokenizes them into a batch, and attaches ``labels``.

    When ``only_assistant`` is set, non-assistant tokens in ``labels`` are
    replaced with -100 so the loss ignores them.
    """
    assert template_name in TEMPLATES

    def collate_fn(batch):
        # Render each dialogue to text, dropping the BOS marker because the
        # tokenizer call below will add it again.
        texts = [
            tokenizer.apply_chat_template(
                messages, add_generation_prompt=False, tokenize=False
            ).replace("<|begin_of_text|>", "")
            for messages in batch
        ]

        encoded = tokenizer(
            texts,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        encoded["labels"] = encoded["input_ids"].clone()

        # Restrict the loss to assistant-response tokens if requested.
        if only_assistant:
            response_mask = track_assistant_response(
                encoded, tokenizer, template_name=template_name
            )
            encoded["labels"][response_mask == 0] = -100

        return encoded

    return collate_fn
|
158
|
-
|
159
|
-
|
160
|
-
def get_dataset_and_collate(
    tokenizer,
    template_name: str = "llama3",
    max_length: int = 8142,
    max_message_window: int = 30,
    trace_mode: str = "oa",
    only_assistant: bool = False,
):
    """Construct the train/val dialogue datasets and a matching collate fn.

    Returns:
        (train_dataset, val_dataset, collate_fn)

    Raises:
        ValueError: if ``template_name`` is unsupported. The original code
        only handled "llama3" and fell through to an UnboundLocalError for
        anything else; fail loudly with a clear message instead.
    """
    if template_name != "llama3":
        raise ValueError(f"Unsupported template_name: {template_name}")

    train_dataset = PlancraftDialogueDataset(
        use_images=False,
        max_message_window=max_message_window,
        split="train",
        trace_mode=trace_mode,
    )
    val_dataset = PlancraftDialogueDataset(
        use_images=False,
        max_message_window=max_message_window,
        split="val",
        trace_mode=trace_mode,
    )
    collate_fn = get_collate_fn(
        tokenizer=tokenizer,
        only_assistant=only_assistant,
        template_name=template_name,
        max_length=max_length,
    )
    return train_dataset, val_dataset, collate_fn
|
File without changes
|