plancraft 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plancraft-0.1.2.dist-info/METADATA +74 -0
- plancraft-0.1.2.dist-info/RECORD +5 -0
- {plancraft-0.1.0.dist-info → plancraft-0.1.2.dist-info}/WHEEL +1 -1
- plancraft-0.1.2.dist-info/top_level.txt +1 -0
- environments/__init__.py +0 -0
- environments/actions.py +0 -218
- environments/env_real.py +0 -315
- environments/env_symbolic.py +0 -215
- environments/items.py +0 -10
- environments/planner.py +0 -109
- environments/recipes.py +0 -542
- environments/sampler.py +0 -224
- models/__init__.py +0 -21
- models/act.py +0 -184
- models/base.py +0 -152
- models/bbox_model.py +0 -492
- models/dummy.py +0 -54
- models/few_shot_images/__init__.py +0 -16
- models/generators.py +0 -483
- models/oam.py +0 -284
- models/oracle.py +0 -268
- models/prompts.py +0 -158
- models/react.py +0 -98
- models/utils.py +0 -289
- plancraft-0.1.0.dist-info/METADATA +0 -53
- plancraft-0.1.0.dist-info/RECORD +0 -26
- plancraft-0.1.0.dist-info/top_level.txt +0 -3
- train/dataset.py +0 -187
- {plancraft-0.1.0.dist-info → plancraft-0.1.2.dist-info}/LICENSE +0 -0
train/dataset.py
DELETED
@@ -1,187 +0,0 @@
|
|
1
|
-
import glob
|
2
|
-
import json
|
3
|
-
import random
|
4
|
-
|
5
|
-
import torch
|
6
|
-
from torch.utils.data import Dataset
|
7
|
-
|
8
|
-
# Role-marker strings for each supported chat template, keyed by template name.
# `track_assistant_response` tokenizes the "assistant" marker (start of a span to
# train on) and the "user" marker (end of that span) to locate assistant turns.
TEMPLATES = dict(
    idefics2=dict(
        assistant="\nAssistant:",
        user="\nUser:",
    ),
    llama3=dict(
        assistant="<|start_header_id|>assistant<|end_header_id|>\n\n",
        user="<|start_header_id|>user<|end_header_id|>\n\n",
    ),
)
|
18
|
-
|
19
|
-
|
20
|
-
class PlancraftDialogueDataset(Dataset):
    """Map-style dataset of Plancraft dialogue traces stored as JSON files.

    Every ``*.json`` file under ``{dataset_dir}/{split}/{trace_mode}/`` is
    loaded eagerly, in sorted path order, when the dataset is constructed.
    Each file must hold an object with a ``"messages"`` list. Items of the
    dataset are those message lists, optionally cut down to a random window
    (see ``__getitem__``).
    """

    def __init__(
        self,
        dataset_dir: str = "data/oracle",
        use_images: bool = False,
        trace_mode: str = "oa",
        split: str = "train",
        max_message_window: int = 30,
    ):
        """Load every trace for ``split`` / ``trace_mode`` from ``dataset_dir``.

        Args:
            dataset_dir: root directory containing ``{split}/{trace_mode}/``.
            use_images: stored on the instance; not read by this class.
            trace_mode: ``"oa"`` or ``"ota"``; selects the sub-directory.
            split: split sub-directory name, e.g. ``"train"`` or ``"val"``.
            max_message_window: dialogues with more messages than this are
                windowed in ``__getitem__``.

        Raises:
            ValueError: if ``trace_mode`` is not ``"oa"`` or ``"ota"``.
        """
        super().__init__()
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently load an empty dataset instead.
        if trace_mode not in ("oa", "ota"):
            raise ValueError(f"Invalid trace mode {trace_mode}")
        self.split = split
        self.use_images = use_images
        self.trace_mode = trace_mode
        # When windowing, keep the first message (the system prompt) in front.
        self.add_system_message = True
        self.max_message_window = max_message_window

        print("Loading dialogue dataset")
        pattern = f"{dataset_dir}/{split}/{trace_mode}/*.json"
        data = []
        for example_path in sorted(glob.glob(pattern)):
            # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
            with open(example_path, encoding="utf-8") as f:
                data.append(json.load(f))
        self.dataset = data

    def __len__(self) -> int:
        """Return the number of loaded dialogue files."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> list[dict]:
        """Return the message list of example ``idx``, windowed if too long.

        Dialogues with at most ``max_message_window`` messages are returned
        unchanged (the very same list object). Longer ones are cut as follows.
        An end index ``end`` is drawn uniformly from ``max_message_window,
        max_message_window + 2, ...`` below the message count. The result is
        the ``max_message_window`` consecutive messages ending at ``end``. If
        ``add_system_message`` is set, the first message is prepended.

        NOTE(review): assuming messages alternate system/user/assistant, an
        even window ends on an assistant turn; an odd one ends on a user turn.
        """
        all_messages = self.dataset[idx]["messages"]
        window = self.max_message_window
        if len(all_messages) <= window:
            return all_messages

        messages = [all_messages[0]] if self.add_system_message else []

        # Candidate window ends step by 2 so every window has the same parity.
        end = random.choice(range(window, len(all_messages), 2))
        start = end - window + 1
        # end >= window, so start >= 1: the system message is never duplicated.
        assert start != 0
        return messages + all_messages[start : end + 1]
|
73
|
-
|
74
|
-
|
75
|
-
def track_assistant_response(
    batch,
    tokenizer,
    template_name: str = "llama3",
):
    """Mask that returns 1 for tokens in the assistant response and 0 otherwise.

    Each row of ``batch["labels"]`` is scanned left to right. The tokenized
    assistant marker of ``template_name`` switches masking on, and the
    tokenized user marker switches it off. Marker tokens themselves are
    always 0. Tokens after the last assistant marker stay 1 if no user marker
    follows.

    Args:
        batch: mapping with a ``"labels"`` tensor of token ids. Presumably
            (batch, seq_len) as built by ``get_collate_fn``; TODO confirm.
        tokenizer: object exposing ``encode(text, add_special_tokens=False,
            return_tensors="pt")``.
        template_name: key into ``TEMPLATES``.

    Returns:
        Tensor shaped and typed like ``batch["labels"]`` holding 0/1 values.

    Raises:
        KeyError: if ``template_name`` is not a key of ``TEMPLATES``.
        ValueError: if a marker tokenizes to an empty sequence (the scan
            could never advance past it).
    """
    encoded_label_ids = batch["labels"]
    device = encoded_label_ids.device

    def _encode_marker(text):
        # Token ids of one role marker, on the same device as the labels.
        return tokenizer.encode(
            text,
            add_special_tokens=False,
            return_tensors="pt",
        )[0].to(device)

    start_seq = _encode_marker(TEMPLATES[template_name]["assistant"])
    end_seq = _encode_marker(TEMPLATES[template_name]["user"])
    if len(start_seq) == 0 or len(end_seq) == 0:
        raise ValueError(
            f"Role markers of template {template_name!r} tokenize to an empty sequence"
        )
    start_len, end_len = len(start_seq), len(end_seq)

    mask = torch.zeros_like(encoded_label_ids)
    for seq_idx, seq in enumerate(encoded_label_ids):
        seq_len = len(seq)
        in_masked_response = False
        i = 0
        while i < seq_len:
            # `<=` (not `<`): a marker ending on the final token must still match.
            if i + start_len <= seq_len and torch.all(seq[i : i + start_len].eq(start_seq)):
                in_masked_response = True
                i += start_len
                continue
            if i + end_len <= seq_len and torch.all(seq[i : i + end_len].eq(end_seq)):
                in_masked_response = False
                i += end_len
                continue
            if in_masked_response:
                mask[seq_idx, i] = 1  # mask is zero-initialised elsewhere
            i += 1
    return mask
|
119
|
-
|
120
|
-
|
121
|
-
def get_collate_fn(
    tokenizer,
    max_length=8142,
    only_assistant=False,
    template_name: str = "llama3",
):
    """Build a collate function turning lists of chat messages into a token batch.

    Args:
        tokenizer: tokenizer exposing ``apply_chat_template`` and ``__call__``
            (transformers-style API).
        max_length: truncation length passed to the tokenizer.
        only_assistant: if True, labels outside assistant responses (as found by
            ``track_assistant_response``) are set to -100.
        template_name: key into ``TEMPLATES``.

    Returns:
        ``collate_fn(batch)``. It takes a list of message lists and returns the
        tokenizer's batch with an added ``"labels"`` tensor. Labels are a copy
        of ``input_ids`` with -100 at ignored positions.
    """
    assert template_name in TEMPLATES
    # Without padding, sequences of different lengths cannot be stacked into one
    # tensor. Only pad when the tokenizer defines a pad token: transformers
    # raises if asked to pad without one. When no pad token is defined, the
    # behaviour is exactly as before (no padding).
    padding = getattr(tokenizer, "pad_token", None) is not None

    def collate_fn(batch):
        messages_batch = []
        for messages in batch:
            text = tokenizer.apply_chat_template(
                messages, add_generation_prompt=False, tokenize=False
            )
            # remove BOS token since it will later be added again by the tokenizer
            text = text.replace("<|begin_of_text|>", "")
            messages_batch.append(text)

        batch = tokenizer(
            messages_batch,
            truncation=True,
            max_length=max_length,
            padding=padding,
            return_tensors="pt",
        )
        labels = batch["input_ids"].clone()
        batch["labels"] = labels

        # add mask for assistant response
        if only_assistant:
            mask = track_assistant_response(
                batch, tokenizer, template_name=template_name
            )
            labels[mask == 0] = -100

        # Never train on padding: trailing pads after the last assistant turn
        # would otherwise be inside the assistant mask.
        if "attention_mask" in batch:
            labels[batch["attention_mask"] == 0] = -100

        return batch

    return collate_fn
|
158
|
-
|
159
|
-
|
160
|
-
def get_dataset_and_collate(
    tokenizer,
    template_name: str = "llama3",
    max_length: int = 8142,
    max_message_window: int = 30,
    trace_mode="oa",
    only_assistant=False,
    dataset_dir: str = "data/oracle",
):
    """Build the train/val dialogue datasets and a matching collate function.

    Args:
        tokenizer: tokenizer handed to ``get_collate_fn``.
        template_name: chat template; only ``"llama3"`` is supported here.
        max_length: tokenizer truncation length.
        max_message_window: dialogue window size for both datasets.
        trace_mode: ``"oa"`` or ``"ota"`` trace sub-directory.
        only_assistant: restrict the loss to assistant responses.
        dataset_dir: root directory holding ``train/`` and ``val/`` splits.

    Returns:
        ``(train_dataset, val_dataset, collate_fn)``.

    Raises:
        ValueError: if ``template_name`` is not ``"llama3"``. Previously this
            fell through without building anything.
    """
    if template_name != "llama3":
        raise ValueError(
            f"Unsupported template {template_name!r}; only 'llama3' is supported"
        )

    train_dataset = PlancraftDialogueDataset(
        dataset_dir=dataset_dir,
        use_images=False,
        max_message_window=max_message_window,
        split="train",
        trace_mode=trace_mode,
    )
    val_dataset = PlancraftDialogueDataset(
        dataset_dir=dataset_dir,
        use_images=False,
        max_message_window=max_message_window,
        split="val",
        trace_mode=trace_mode,
    )
    collate_fn = get_collate_fn(
        tokenizer=tokenizer,
        only_assistant=only_assistant,
        template_name=template_name,
        max_length=max_length,
    )
    return train_dataset, val_dataset, collate_fn
|
File without changes
|