plancraft-0.1.0-py3-none-any.whl

train/dataset.py ADDED
import glob
import json
import random

import torch
from torch.utils.data import Dataset

TEMPLATES = {
    "idefics2": {
        "assistant": "\nAssistant:",
        "user": "\nUser:",
    },
    "llama3": {
        "assistant": "<|start_header_id|>assistant<|end_header_id|>\n\n",
        "user": "<|start_header_id|>user<|end_header_id|>\n\n",
    },
}
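
# NOTE: these marker strings must match the tokenizer's chat template
# verbatim; track_assistant_response below searches for their token ids
# inside the encoded labels to find assistant spans.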


class PlancraftDialogueDataset(Dataset):
    def __init__(
        self,
        dataset_dir: str = "data/oracle",
        use_images: bool = False,
        trace_mode: str = "oa",
        split: str = "train",
        max_message_window: int = 30,
    ):
        super().__init__()
        self.split = split
        self.use_images = use_images
        self.trace_mode = trace_mode
        self.add_system_message = True

        assert trace_mode in ["oa", "ota"], f"Invalid trace mode {trace_mode}"

        print("Loading dialogue dataset")
        data = []
        for example_path in sorted(
            glob.glob(f"{dataset_dir}/{split}/{trace_mode}/*.json")
        ):
            with open(example_path) as f:
                example = json.load(f)
            data.append(example)
        self.dataset = data
        self.max_message_window = max_message_window

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> list:
        example = self.dataset[idx]
        if len(example["messages"]) > self.max_message_window:
            # keep the system message at the start of the window
            if self.add_system_message:
                messages = [example["messages"][0]]
            else:
                messages = []

            # sample a random window of max_message_window messages
            user_messages_idxs = list(
                range(self.max_message_window, len(example["messages"]), 2)
            )
            end = random.choice(user_messages_idxs)
            start = end - self.max_message_window + 1
            # the sampled window must never overlap the system message
            assert start != 0
            # append the window after the (optional) system message
            messages = messages + example["messages"][start : end + 1]
        else:
            messages = example["messages"]
        return messages
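
# NOTE: each __getitem__ item is a plain list of chat messages in the
# {"role": ..., "content": ...} format that tokenizer.apply_chat_template
# expects, e.g.
# [{"role": "system", "content": "..."},
#  {"role": "user", "content": "..."},
#  {"role": "assistant", "content": "..."}]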


def track_assistant_response(
    batch,
    tokenizer,
    template_name: str = "llama3",
):
    """
    Return a mask over batch["labels"] that is 1 for tokens inside an
    assistant response and 0 otherwise.
    """
    assistant_template = TEMPLATES[template_name]["assistant"]
    user_template = TEMPLATES[template_name]["user"]
    # token id sequences that open and close an assistant turn
    start_seq = tokenizer.encode(
        assistant_template,
        add_special_tokens=False,
        return_tensors="pt",
    )[0]
    end_seq = tokenizer.encode(
        user_template,
        add_special_tokens=False,
        return_tensors="pt",
    )[0]
    encoded_label_ids = batch["labels"]
    mask = torch.zeros_like(encoded_label_ids)
    for seq_idx, seq in enumerate(encoded_label_ids):
        in_assistant_response = False
        i = 0
        while i < len(seq):
            # an assistant header opens a span to keep
            if i + len(start_seq) <= len(seq) and torch.all(
                seq[i : i + len(start_seq)].eq(start_seq)
            ):
                in_assistant_response = True
                i += len(start_seq)
                continue
            # a user header closes the span
            if i + len(end_seq) <= len(seq) and torch.all(
                seq[i : i + len(end_seq)].eq(end_seq)
            ):
                in_assistant_response = False
                i += len(end_seq)
                continue
            if in_assistant_response:
                mask[seq_idx, i] = 1
            i += 1
    return mask
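
# Illustration (assuming the llama3 template): in a label sequence decoding to
# "<|start_header_id|>assistant<|end_header_id|>\n\n<response><|eot_id|><|start_header_id|>user<|end_header_id|>\n\n..."
# the mask is 1 over the <response> tokens (and the <|eot_id|> that closes the
# turn) and 0 everywhere else, so only the assistant's tokens are kept as
# training targets.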


def get_collate_fn(
    tokenizer,
    max_length=8142,
    only_assistant=False,
    template_name: str = "llama3",
):
    assert template_name in TEMPLATES

    def collate_fn(batch):
        messages_batch = []
        for messages in batch:
            text = tokenizer.apply_chat_template(
                messages, add_generation_prompt=False, tokenize=False
            )
            # remove the BOS token since the tokenizer adds it again below
            text = text.replace("<|begin_of_text|>", "")
            messages_batch.append(text)

        # pad to the longest sequence so ragged examples can be stacked into
        # one tensor (the tokenizer must have a pad token set)
        batch = tokenizer(
            messages_batch,
            truncation=True,
            max_length=max_length,
            padding=True,
            return_tensors="pt",
        )
        labels = batch["input_ids"].clone()
        # -100 is the ignore_index of torch cross-entropy, so padding
        # positions contribute no loss
        labels[batch["attention_mask"] == 0] = -100
        batch["labels"] = labels

        # optionally compute the loss only on assistant responses
        if only_assistant:
            mask = track_assistant_response(
                batch, tokenizer, template_name=template_name
            )
            labels[mask == 0] = -100

        return batch

    return collate_fn
158
+
159
+
160
+ def get_dataset_and_collate(
161
+ tokenizer,
162
+ template_name: str = "llama3",
163
+ max_length: int = 8142,
164
+ max_message_window: int = 30,
165
+ trace_mode="oa",
166
+ only_assistant=False,
167
+ ):
168
+ if template_name == "llama3":
169
+ train_dataset = PlancraftDialogueDataset(
170
+ use_images=False,
171
+ max_message_window=max_message_window,
172
+ split="train",
173
+ trace_mode=trace_mode,
174
+ )
175
+ val_dataset = PlancraftDialogueDataset(
176
+ use_images=False,
177
+ max_message_window=max_message_window,
178
+ split="val",
179
+ trace_mode=trace_mode,
180
+ )
181
+ collate_fn = get_collate_fn(
182
+ tokenizer=tokenizer,
183
+ only_assistant=only_assistant,
184
+ template_name=template_name,
185
+ max_length=max_length,
186
+ )
187
+ return train_dataset, val_dataset, collate_fn
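
A minimal usage sketch (not part of the file above): wiring the dataset and collate function into a DataLoader. The checkpoint name is an assumption; any tokenizer whose chat template uses the llama3 header tokens should work.

    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    from train.dataset import get_dataset_and_collate

    # assumed checkpoint; llama3 tokenizers ship without a pad token, so reuse EOS
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    tokenizer.pad_token = tokenizer.eos_token

    train_dataset, val_dataset, collate_fn = get_dataset_and_collate(
        tokenizer=tokenizer,
        template_name="llama3",
        only_assistant=True,
    )
    train_loader = DataLoader(
        train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn
    )
    batch = next(iter(train_loader))
    # batch["labels"] is -100 everywhere except inside assistant responses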