plancraft 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
train/dataset.py DELETED
@@ -1,187 +0,0 @@
1
- import glob
2
- import json
3
- import random
4
-
5
- import torch
6
- from torch.utils.data import Dataset
7
-
8
# Chat-template role headers per supported model family. These literal token
# sequences are searched for in tokenized labels to locate assistant turns,
# so they must match the tokenizer's chat template byte-for-byte.
TEMPLATES = {
    "idefics2": {
        "assistant": "\nAssistant:",
        "user": "\nUser:",
    },
    "llama3": {
        "assistant": "<|start_header_id|>assistant<|end_header_id|>\n\n",
        "user": "<|start_header_id|>user<|end_header_id|>\n\n",
    },
}
18
-
19
-
20
class PlancraftDialogueDataset(Dataset):
    """Dialogue dataset of Plancraft oracle traces loaded from JSON files.

    Each example is a list of chat messages. Dialogues longer than
    ``max_message_window`` are truncated to a randomly sampled window that
    ends on a user message, optionally prefixed with the first (system)
    message.
    """

    def __init__(
        self,
        dataset_dir: str = "data/oracle",
        use_images: bool = False,
        trace_mode: str = "oa",
        split: str = "train",
        max_message_window: int = 30,
    ):
        """
        Args:
            dataset_dir: root directory containing ``<split>/<trace_mode>/*.json``.
            use_images: stored as a flag only; no images are loaded here.
            trace_mode: trace format, either "oa" or "ota".
            split: dataset split sub-directory (e.g. "train", "val").
            max_message_window: maximum number of messages returned per item.

        Raises:
            ValueError: if ``trace_mode`` is not one of "oa" / "ota".
        """
        super().__init__()
        self.split = split
        self.use_images = use_images
        self.trace_mode = trace_mode
        # Always keep the leading system message when windowing long dialogues.
        self.add_system_message = True

        # Raise (not assert) so validation survives `python -O`.
        if trace_mode not in ("oa", "ota"):
            raise ValueError(f"Invalid trace mode {trace_mode}")

        print("Loading dialogue dataset")
        data = []
        # sorted() keeps example ordering deterministic across filesystems.
        for example_path in sorted(
            glob.glob(f"{dataset_dir}/{split}/{trace_mode}/*.json")
        ):
            with open(example_path) as f:
                data.append(json.load(f))
        self.dataset = data
        self.max_message_window = max_message_window

    def __len__(self) -> int:
        """Number of loaded dialogue examples."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> list:
        """Return the (possibly windowed) message list for example ``idx``.

        Note: the original annotation claimed ``tuple[dict, list]`` but the
        method has always returned just the message list; the annotation is
        corrected here.
        """
        example = self.dataset[idx]
        messages = example["messages"]
        if len(messages) <= self.max_message_window:
            return messages

        # Optionally keep the system message (index 0) in front of the window.
        prefix = [messages[0]] if self.add_system_message else []

        # Candidate window *end* indices, stepped by 2 so the window always
        # ends on the same role; sample one uniformly at random.
        user_messages_idxs = list(
            range(self.max_message_window, len(messages), 2)
        )
        end = random.choice(user_messages_idxs)
        start = end - self.max_message_window + 1
        # start >= 1 by construction; position 0 is reserved for the prefix.
        assert start != 0
        return prefix + messages[start : end + 1]
73
-
74
-
75
def track_assistant_response(
    batch,
    tokenizer,
    template_name: str = "llama3",
):
    """
    Mask that returns 1 for tokens in the assistant response and 0 otherwise.

    Args:
        batch: mapping with a "labels" tensor of token ids, shape (B, T).
        tokenizer: tokenizer used to encode the role-header template strings.
        template_name: key into TEMPLATES selecting the chat format.

    Returns:
        Tensor of the same shape as ``batch["labels"]``: 1 for tokens that
        follow an assistant header (up to the next user header), 0 elsewhere
        — the header tokens themselves are never marked.
    """
    assistant_template = TEMPLATES[template_name]["assistant"]
    user_template = TEMPLATES[template_name]["user"]
    start_seq = tokenizer.encode(
        assistant_template,
        add_special_tokens=False,
        return_tensors="pt",
    )[0]
    end_seq = tokenizer.encode(
        user_template,
        add_special_tokens=False,
        return_tensors="pt",
    )[0]
    encoded_label_ids = batch["labels"]
    mask = torch.zeros_like(encoded_label_ids)
    for seq_idx, seq in enumerate(encoded_label_ids):
        in_masked_response = False
        i = 0
        while i < len(seq):
            # A k-token header starting at i is in-bounds when i + k <= len(seq).
            # (The original used `<`, an off-by-one that skipped a header match
            # ending exactly at the sequence end.)
            if i + len(start_seq) <= len(seq) and torch.all(
                seq[i : i + len(start_seq)].eq(start_seq)
            ):
                in_masked_response = True
                i += len(start_seq)
                continue
            if i + len(end_seq) <= len(seq) and torch.all(
                seq[i : i + len(end_seq)].eq(end_seq)
            ):
                in_masked_response = False
                i += len(end_seq)
                continue
            if in_masked_response:
                mask[seq_idx, i] = 1
            else:
                # mask is pre-zeroed; kept explicit for readability.
                mask[seq_idx, i] = 0
            i += 1
    return mask
119
-
120
-
121
def get_collate_fn(
    tokenizer,
    max_length=8142,
    only_assistant=False,
    template_name: str = "llama3",
):
    """Build a collate function that tokenizes chat dialogues for training.

    Args:
        tokenizer: tokenizer providing ``apply_chat_template`` and ``__call__``.
        max_length: truncation length for the tokenized batch.
        only_assistant: when True, label tokens outside assistant responses
            are set to -100 so they are ignored by the loss.
        template_name: key into TEMPLATES selecting the chat format.

    Returns:
        A ``collate_fn(batch) -> dict`` suitable for a PyTorch DataLoader.
    """
    assert template_name in TEMPLATES

    def collate_fn(batch):
        # Render each dialogue to a single string via the chat template.
        # remove BOS token since it will later be added again by the tokenizer
        texts = [
            tokenizer.apply_chat_template(
                messages, add_generation_prompt=False, tokenize=False
            ).replace("<|begin_of_text|>", "")
            for messages in batch
        ]

        encoded = tokenizer(
            texts,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        labels = encoded["input_ids"].clone()
        encoded["labels"] = labels

        # add mask for assistant response
        if only_assistant:
            mask = track_assistant_response(
                encoded, tokenizer, template_name=template_name
            )
            # In-place edit also updates encoded["labels"] (same tensor).
            labels[mask == 0] = -100

        return encoded

    return collate_fn
158
-
159
-
160
def get_dataset_and_collate(
    tokenizer,
    template_name: str = "llama3",
    max_length: int = 8142,
    max_message_window: int = 30,
    trace_mode="oa",
    only_assistant=False,
):
    """Construct train/val dialogue datasets and a matching collate function.

    Args:
        tokenizer: tokenizer passed through to the collate function.
        template_name: chat template family; only "llama3" is supported.
        max_length: truncation length for tokenized batches.
        max_message_window: maximum messages per dialogue example.
        trace_mode: trace format forwarded to the datasets ("oa" or "ota").
        only_assistant: mask non-assistant label tokens with -100.

    Returns:
        (train_dataset, val_dataset, collate_fn).

    Raises:
        ValueError: for an unsupported ``template_name``. The original code
        silently fell through and implicitly returned None, which surfaced
        later as an unrelated unpacking TypeError at the call site.
    """
    if template_name != "llama3":
        raise ValueError(f"Unsupported template_name: {template_name!r}")

    train_dataset = PlancraftDialogueDataset(
        use_images=False,
        max_message_window=max_message_window,
        split="train",
        trace_mode=trace_mode,
    )
    val_dataset = PlancraftDialogueDataset(
        use_images=False,
        max_message_window=max_message_window,
        split="val",
        trace_mode=trace_mode,
    )
    collate_fn = get_collate_fn(
        tokenizer=tokenizer,
        only_assistant=only_assistant,
        template_name=template_name,
        max_length=max_length,
    )
    return train_dataset, val_dataset, collate_fn