plancraft 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
train/dataset.py ADDED
@@ -0,0 +1,187 @@
1
+ import glob
2
+ import json
3
+ import random
4
+
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+
8
# Chat-template control strings per model family. "assistant" marks the start
# of an assistant turn and "user" the start of a user turn; these markers are
# used to locate assistant-response token spans when masking training labels.
TEMPLATES = {
    "idefics2": {
        "assistant": "\nAssistant:",
        "user": "\nUser:",
    },
    "llama3": {
        "assistant": "<|start_header_id|>assistant<|end_header_id|>\n\n",
        "user": "<|start_header_id|>user<|end_header_id|>\n\n",
    },
}
18
+
19
+
20
class PlancraftDialogueDataset(Dataset):
    """Plancraft oracle dialogues loaded from one JSON file per example.

    Each example is a dict with a "messages" list in chat format. Dialogues
    longer than ``max_message_window`` are truncated to a randomly sampled
    contiguous window (plus the first message when ``add_system_message`` is
    set — presumably the system prompt; messages are assumed to alternate
    user/assistant after it — TODO confirm against the data generator).
    """

    def __init__(
        self,
        dataset_dir: str = "data/oracle",
        use_images: bool = False,
        trace_mode: str = "oa",
        split: str = "train",
        max_message_window: int = 30,
    ):
        super().__init__()
        self.split = split
        # NOTE(review): use_images is stored but never read in this class.
        self.use_images = use_images
        self.trace_mode = trace_mode
        self.add_system_message = True

        # Raise instead of assert: asserts are stripped under `python -O`.
        if trace_mode not in ("oa", "ota"):
            raise ValueError(f"Invalid trace mode {trace_mode}")

        print("Loading dialogue dataset")
        data = []
        # sorted() makes example order deterministic across filesystems.
        for example_path in sorted(
            glob.glob(f"{dataset_dir}/{split}/{trace_mode}/*.json")
        ):
            with open(example_path) as f:
                data.append(json.load(f))
        self.dataset = data
        self.max_message_window = max_message_window

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> list[dict]:
        """Return the (possibly windowed) message list for example ``idx``.

        Fix: the original annotation claimed ``tuple[dict, list]`` but the
        method has always returned a list of message dicts.

        Note: windowing uses the global ``random`` state, so long examples
        yield different windows across accesses unless the caller seeds it.
        """
        example = self.dataset[idx]
        messages = example["messages"]
        if len(messages) <= self.max_message_window:
            return messages

        # Keep the leading (system) message so the model always sees it.
        window = [messages[0]] if self.add_system_message else []

        # Candidate window *end* indices: every other position starting at
        # max_message_window, so the window always ends on the same speaker
        # (relies on strict user/assistant alternation after the system turn).
        end = random.choice(range(self.max_message_window, len(messages), 2))
        start = end - self.max_message_window + 1
        # start >= 1 always holds, so the system message is never duplicated.
        assert start != 0
        return window + messages[start : end + 1]
73
+
74
+
75
def track_assistant_response(
    batch,
    tokenizer,
    template_name: str = "llama3",
):
    """Build a mask over ``batch["labels"]`` that is 1 for tokens inside an
    assistant response and 0 everywhere else (header tokens included).

    Args:
        batch: mapping with a "labels" tensor of shape (batch, seq_len).
        tokenizer: HF-style tokenizer used to encode the header templates.
        template_name: key into ``TEMPLATES`` selecting the chat template.

    Returns:
        torch.Tensor of zeros/ones with the same shape as ``batch["labels"]``.
    """
    assistant_template = TEMPLATES[template_name]["assistant"]
    user_template = TEMPLATES[template_name]["user"]
    start_seq = tokenizer.encode(
        assistant_template,
        add_special_tokens=False,
        return_tensors="pt",
    )[0]
    end_seq = tokenizer.encode(
        user_template,
        add_special_tokens=False,
        return_tensors="pt",
    )[0]
    encoded_label_ids = batch["labels"]
    mask = torch.zeros_like(encoded_label_ids)
    for seq_idx, seq in enumerate(encoded_label_ids):
        in_masked_response = False
        i = 0
        while i < len(seq):
            # Fix: use <= so a header ending exactly at the sequence end is
            # still recognized (previously its tokens leaked into the mask
            # when they fell inside an open assistant span).
            if i + len(start_seq) <= len(seq) and torch.all(
                seq[i : i + len(start_seq)].eq(start_seq)
            ):
                in_masked_response = True
                i += len(start_seq)
                continue
            if i + len(end_seq) <= len(seq) and torch.all(
                seq[i : i + len(end_seq)].eq(end_seq)
            ):
                in_masked_response = False
                i += len(end_seq)
                continue
            if in_masked_response:
                mask[seq_idx, i] = 1
            # mask is zero-initialized, so no explicit write is needed here.
            i += 1
    return mask
119
+
120
+
121
def get_collate_fn(
    tokenizer,
    max_length=8142,
    only_assistant=False,
    template_name: str = "llama3",
):
    """Build a collate function that tokenizes chat dialogues for causal LM
    training.

    The returned callable renders each message list with the tokenizer's chat
    template, strips the BOS marker (the tokenizer re-adds it), tokenizes the
    batch, and clones input ids into "labels". With ``only_assistant`` set,
    every non-assistant label is replaced with -100 so it is ignored by the
    loss.

    NOTE(review): tokenization uses truncation but no padding, so batches of
    more than one unequal-length dialogue would fail — confirm batch size
    with the training loop.
    """
    assert template_name in TEMPLATES

    def collate_fn(batch):
        # Render each dialogue to text, dropping the BOS token because the
        # tokenizer call below will add it again.
        rendered = [
            tokenizer.apply_chat_template(
                dialogue, add_generation_prompt=False, tokenize=False
            ).replace("<|begin_of_text|>", "")
            for dialogue in batch
        ]

        encoded = tokenizer(
            rendered,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        encoded["labels"] = encoded["input_ids"].clone()

        # Restrict the loss to assistant tokens when requested.
        if only_assistant:
            keep = track_assistant_response(
                encoded, tokenizer, template_name=template_name
            )
            encoded["labels"][keep == 0] = -100

        return encoded

    return collate_fn
158
+
159
+
160
def get_dataset_and_collate(
    tokenizer,
    template_name: str = "llama3",
    max_length: int = 8142,
    max_message_window: int = 30,
    trace_mode="oa",
    only_assistant=False,
):
    """Create train/val dialogue datasets and a matching collate function.

    Args:
        tokenizer: tokenizer passed through to the collate function.
        template_name: chat-template family; only "llama3" is supported.
        max_length: tokenizer truncation length.
        max_message_window: maximum number of dialogue messages per example.
        trace_mode: "oa" or "ota" trace variant to load.
        only_assistant: mask non-assistant tokens out of the labels.

    Returns:
        (train_dataset, val_dataset, collate_fn)

    Raises:
        ValueError: if ``template_name`` is not supported.
    """
    if template_name != "llama3":
        # Bug fix: any other template fell through the single `if` and
        # crashed with UnboundLocalError at the return; fail fast instead.
        raise ValueError(
            f"Unsupported template_name {template_name!r}; only 'llama3' is supported"
        )

    train_dataset = PlancraftDialogueDataset(
        use_images=False,
        max_message_window=max_message_window,
        split="train",
        trace_mode=trace_mode,
    )
    val_dataset = PlancraftDialogueDataset(
        use_images=False,
        max_message_window=max_message_window,
        split="val",
        trace_mode=trace_mode,
    )
    collate_fn = get_collate_fn(
        tokenizer=tokenizer,
        only_assistant=only_assistant,
        template_name=template_name,
        max_length=max_length,
    )
    return train_dataset, val_dataset, collate_fn