cehrgpt 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +99 -88
- cehrgpt/data/sample_packing_sampler.py +151 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/models/config.py +10 -0
- cehrgpt/models/hf_cehrgpt.py +243 -73
- cehrgpt/models/tokenization_hf_cehrgpt.py +4 -0
- cehrgpt/runners/data_utils.py +243 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +152 -279
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +229 -105
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +42 -0
- cehrgpt/runners/hyperparameter_search_util.py +4 -1
- cehrgpt/runners/sample_packing_trainer.py +168 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +7 -5
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +28 -26
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
cehrgpt/data/hf_cehrgpt_dataset.py (+24 -4)

```diff
@@ -1,9 +1,10 @@
-from typing import Union
+from typing import Optional, Union
 
 from cehrbert.data_generators.hf_data_generator.hf_dataset import (
     FINETUNING_COLUMNS,
     apply_cehrbert_dataset_mapping,
 )
+from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
 from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
 from datasets import Dataset, DatasetDict
 
@@ -31,16 +32,25 @@ def create_cehrgpt_pretraining_dataset(
     dataset: Union[Dataset, DatasetDict],
     cehrgpt_tokenizer: CehrGptTokenizer,
     data_args: DataTrainingArguments,
-) -> Dataset:
+    cache_file_collector: Optional[CacheFileCollector] = None,
+) -> Union[Dataset, DatasetDict]:
     required_columns = TRANSFORMER_COLUMNS + CEHRGPT_COLUMNS
+    # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
+    if not data_args.streaming:
+        if isinstance(dataset, DatasetDict):
+            all_columns = dataset["train"].column_names
+        else:
+            all_columns = dataset.column_names
+        if "visit_concept_ids" in all_columns:
+            dataset.remove_columns(["visit_concept_ids"])
     dataset = apply_cehrbert_dataset_mapping(
         dataset,
         HFCehrGptTokenizationMapping(cehrgpt_tokenizer),
         num_proc=data_args.preprocessing_num_workers,
         batch_size=data_args.preprocessing_batch_size,
         streaming=data_args.streaming,
+        cache_file_collector=cache_file_collector,
     )
-
     if not data_args.streaming:
         if isinstance(dataset, DatasetDict):
             all_columns = dataset["train"].column_names
@@ -56,8 +66,17 @@ def create_cehrgpt_finetuning_dataset(
     dataset: Union[Dataset, DatasetDict],
     cehrgpt_tokenizer: CehrGptTokenizer,
     data_args: DataTrainingArguments,
-) -> Dataset:
+    cache_file_collector: Optional[CacheFileCollector] = None,
+) -> Union[Dataset, DatasetDict]:
     required_columns = TRANSFORMER_COLUMNS + CEHRGPT_COLUMNS + FINETUNING_COLUMNS
+    # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
+    if not data_args.streaming:
+        if isinstance(dataset, DatasetDict):
+            all_columns = dataset["train"].column_names
+        else:
+            all_columns = dataset.column_names
+        if "visit_concept_ids" in all_columns:
+            dataset.remove_columns(["visit_concept_ids"])
     mapping_functions = [
         HFFineTuningMapping(cehrgpt_tokenizer),
     ]
@@ -68,6 +87,7 @@ def create_cehrgpt_finetuning_dataset(
         num_proc=data_args.preprocessing_num_workers,
         batch_size=data_args.preprocessing_batch_size,
         streaming=data_args.streaming,
+        cache_file_collector=cache_file_collector,
     )
 
     if not data_args.streaming:
```
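
Both dataset factories now accept an optional `CacheFileCollector` and hand it through to `apply_cehrbert_dataset_mapping`, so the caller can track the intermediate cache files that non-streaming mapping creates. A minimal sketch of the new call shape; the paths are placeholders, `data_args` is assumed to be the `DataTrainingArguments` instance you already use, and the collector's cleanup API is not part of this diff:

```python
from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
from datasets import load_from_disk

from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_pretraining_dataset
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer

dataset = load_from_disk("/path/to/prepared_dataset")  # placeholder path
tokenizer = CehrGptTokenizer.from_pretrained("/path/to/tokenizer")  # placeholder path
cache_collector = CacheFileCollector()

tokenized = create_cehrgpt_pretraining_dataset(
    dataset=dataset,
    cehrgpt_tokenizer=tokenizer,
    data_args=data_args,  # your existing cehrbert DataTrainingArguments
    cache_file_collector=cache_collector,  # optional; defaults to None
)
# The collector now knows about the cache files produced during mapping, so the
# calling runner can remove them once the transformed dataset has been saved
# (the cleanup call itself is not shown in this diff).
```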

cehrgpt/data/hf_cehrgpt_dataset_collator.py (+260 -84)

```diff
@@ -1,5 +1,5 @@
 import random
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import numpy as np
 import torch
@@ -105,9 +105,12 @@ class CehrGptDataCollator:
             self._try_reverse_tensor(self._convert_to_tensor(example["input_ids"]))
             for example in examples
         ]
+
         batch_attention_mask = [
             self._try_reverse_tensor(
-                torch.ones_like(
+                self._convert_to_tensor(example["attention_mask"]).to(torch.float)
+                if "attention_mask" in example
+                else torch.ones_like(
                     self._convert_to_tensor(example["input_ids"]), dtype=torch.float
                 )
             )
@@ -128,16 +131,40 @@ class CehrGptDataCollator:
         )
         assert batch["input_ids"].shape[1] <= self.max_length
         assert batch["attention_mask"].shape[1] <= self.max_length
+        assert batch["attention_mask"].shape[1] == batch["input_ids"].shape[1], (
+            f'batch["attention_mask"].shape[1]: {batch["attention_mask"].shape[1]}, '
+            f'batch["input_ids"].shape[1]: {batch["input_ids"].shape[1]}'
+        )
+        assert batch["input_ids"].max() < self.tokenizer.vocab_size, (
+            f"batch['input_ids'].max(): {batch['input_ids'].max()} must be smaller than "
+            f"self.tokenizer.vocab_size: {self.tokenizer.vocab_size}. "
+            f"batch['input_ids']: {batch['input_ids']} "
+        )
 
-        if self.pretraining:
-            batch["labels"] = self._try_reverse_tensor(
+        if "position_ids" in examples[0]:
+            batch_position_ids = [
+                self._try_reverse_tensor(
+                    self._convert_to_tensor(example["position_ids"])
+                )
+                for example in examples
+            ]
+            # Pad sequences to the max length in the batch
+            batch["position_ids"] = self._try_reverse_tensor(
                 pad_sequence(
-                    batch_input_ids,
+                    batch_position_ids,
                     batch_first=True,
-                    padding_value=-100,
+                    padding_value=self.max_length,
                 ).to(torch.int64)
             )
 
+        if self.pretraining:
+            batch["labels"] = torch.where(
+                (batch["input_ids"] != self.tokenizer.pad_token_id)
+                & batch["attention_mask"].to(torch.bool),
+                batch["input_ids"],
+                -100,
+            )
+
         if self.use_sub_time_tokenization:
             time_token_indicators = torch.isin(batch["input_ids"], self.time_tokens)
             masked_tokens = batch["input_ids"].clone()
```
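
Note the change in how pretraining labels are built: instead of padding a copy of the input ids with `-100`, the collator now masks pad and non-attended positions directly with `torch.where`, which also works when several packed sequences share one row. A standalone sketch of the rule with toy values (a pad token id of 0 is assumed here for illustration):

```python
import torch

pad_token_id = 0  # assumed pad token id for this toy example
input_ids = torch.tensor([[5, 7, 9, 0], [4, 0, 0, 0]])
attention_mask = (input_ids != pad_token_id).to(torch.float)

# Positions that are padding or not attended to receive the ignore index -100,
# so the cross-entropy loss skips them.
labels = torch.where(
    (input_ids != pad_token_id) & attention_mask.to(torch.bool),
    input_ids,
    -100,
)
print(labels)
# tensor([[   5,    7,    9, -100],
#         [   4, -100, -100, -100]])
```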
```diff
@@ -170,7 +197,7 @@ class CehrGptDataCollator:
         if self.include_values:
             batch_value_indicators = [
                 self._try_reverse_tensor(
-                    self._convert_to_tensor(example["value_indicators"])
+                    self._convert_to_tensor(example["value_indicators"]).to(torch.bool)
                 )
                 for example in examples
             ]
@@ -178,7 +205,6 @@ class CehrGptDataCollator:
                 self._try_reverse_tensor(self._convert_to_tensor(example["values"]))
                 for example in examples
             ]
-
             batch["value_indicators"] = self._try_reverse_tensor(
                 pad_sequence(
                     batch_value_indicators, batch_first=True, padding_value=False
@@ -200,41 +226,58 @@ class CehrGptDataCollator:
                 batch["value_indicators"], batch["values"].clone(), -100
             )
 
+        bz = len(examples)
         if "person_id" in examples[0]:
-            batch["person_id"] = torch.cat(
-                [
-                    self._convert_to_tensor(example["person_id"]).reshape(-1, 1)
-                    for example in examples
-                ],
-                dim=0,
-            ).to(torch.int32)
+            batch["person_id"] = (
+                torch.cat(
+                    [
+                        self._convert_to_tensor(example["person_id"]).reshape(-1, 1)
+                        for example in examples
+                    ],
+                    dim=0,
+                )
+                .to(torch.int32)
+                .reshape(bz, -1)
+            )
 
         if "index_date" in examples[0]:
             batch["index_date"] = torch.cat(
                 [
-                    self._convert_to_tensor(example["index_date"]).reshape(-1, 1)
+                    torch.tensor(example["index_date"], dtype=torch.float64).reshape(
+                        -1, 1
+                    )
                     for example in examples
                 ],
                 dim=0,
-            ).to(torch.float64)
+            ).reshape(bz, -1)
 
         if "age_at_index" in examples[0]:
-            batch["age_at_index"] = torch.cat(
-                [
-                    self._convert_to_tensor(example["age_at_index"]).reshape(-1, 1)
-                    for example in examples
-                ],
-                dim=0,
-            ).to(torch.float32)
+            batch["age_at_index"] = (
+                torch.cat(
+                    [
+                        self._convert_to_tensor(example["age_at_index"]).reshape(-1, 1)
+                        for example in examples
+                    ],
+                    dim=0,
+                )
+                .to(torch.float32)
+                .reshape(bz, -1)
+            )
 
         if "classifier_label" in examples[0]:
-            batch["classifier_label"] = torch.cat(
-                [
-                    self._convert_to_tensor(example["classifier_label"]).reshape(-1, 1)
-                    for example in examples
-                ],
-                dim=0,
-            ).to(torch.float32)
+            batch["classifier_label"] = (
+                torch.cat(
+                    [
+                        self._convert_to_tensor(example["classifier_label"]).reshape(
+                            -1, 1
+                        )
+                        for example in examples
+                    ],
+                    dim=0,
+                )
+                .to(torch.float32)
+                .reshape(bz, -1)
+            )
 
         return batch
 
```
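
The scalar per-example fields (`person_id`, `index_date`, `age_at_index`, `classifier_label`) are now reshaped against an explicit `bz = len(examples)`, which guarantees a 2-D `(batch_size, n)` tensor even when every example carries a single value. A toy illustration of the pattern:

```python
import torch

examples = [{"age_at_index": 63.0}, {"age_at_index": 41.0}, {"age_at_index": 77.0}]
bz = len(examples)

# Stack one column per example, then force the (batch_size, n) shape.
age_at_index = (
    torch.cat(
        [torch.tensor(e["age_at_index"]).reshape(-1, 1) for e in examples], dim=0
    )
    .to(torch.float32)
    .reshape(bz, -1)
)
print(age_at_index.shape)  # torch.Size([3, 1])
```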
```diff
@@ -273,53 +316,69 @@ class CehrGptDataCollator:
             record["input_ids"] = self._convert_to_tensor(sorted_input_ids)
             return record
 
-    def generate_start_end_index(self, record: Dict[str, Any]) -> Dict[str, Any]:
+    def generate_start_end_index(
+        self, record: Dict[str, Any], max_length_allowed: Optional[int] = None
+    ) -> Dict[str, Any]:
         """Adding the start and end indices to extract a portion of the patient sequence."""
         # concept_ids will be used to for time to event predictions and identifying the visit starts
+        max_length_allowed = (
+            self.max_length if max_length_allowed is None else max_length_allowed
+        )
+        sample_packing = getattr(self, "sample_packing", False)
         input_ids = record["input_ids"]
         if isinstance(input_ids, torch.Tensor):
             input_ids = input_ids.detach().tolist()
         concept_ids = self.tokenizer.decode(input_ids, skip_special_tokens=False)
         seq_length = len(record["input_ids"])
-        new_max_length = self.max_length - 1  # Subtract one for the [END] token
+
+        # Subtract one for the [END] token when sample_packing is not enabled
+        new_max_length = (
+            max_length_allowed if sample_packing else max_length_allowed - 1
+        )
+
+        if self.include_ttv_prediction:
+            record["time_to_visits"] = torch.concat(
+                [self._convert_to_tensor(self._convert_time_to_event(concept_ids))]
+            )
 
         # Return the record directly if the actual sequence length is less than the max sequence
         if seq_length <= new_max_length:
-            record["input_ids"] = torch.concat(
-                [
-                    self._convert_to_tensor(record["input_ids"]),
-                    self._convert_to_tensor([self.tokenizer.end_token_id]),
-                ]
-            )
-            if self.include_values:
-                record["value_indicators"] = torch.concat(
-                    [
-                        self._convert_to_tensor(record["value_indicators"]),
-                        self._convert_to_tensor([False]),
-                    ]
-                ).to(torch.bool)
-                record["values"] = torch.concat(
-                    [
-                        self._convert_to_tensor(record["values"]),
-                        self._convert_to_tensor([self.tokenizer.pad_value_token_id]),
-                    ]
-                )
-            if self.include_ttv_prediction:
-                record["time_to_visits"] = torch.concat(
+            if not sample_packing:
+                record["input_ids"] = torch.concat(
                     [
-                        self._convert_to_tensor(
-                            self._convert_time_to_event(concept_ids)
-                        ),
-                        self._convert_to_tensor([-100.0]),
+                        self._convert_to_tensor(record["input_ids"]),
+                        self._convert_to_tensor([self.tokenizer.end_token_id]),
                     ]
                 )
+                if self.include_values:
+                    record["value_indicators"] = torch.concat(
+                        [
+                            self._convert_to_tensor(record["value_indicators"]),
+                            self._convert_to_tensor([False]),
+                        ]
+                    ).to(torch.bool)
+                    record["values"] = torch.concat(
+                        [
+                            self._convert_to_tensor(record["values"]),
+                            self._convert_to_tensor(
+                                [self.tokenizer.pad_value_token_id]
+                            ),
+                        ]
+                    )
+                if self.include_ttv_prediction:
+                    record["time_to_visits"] = torch.concat(
+                        [
+                            record["time_to_visits"],
+                            self._convert_to_tensor([-100.0]),
+                        ]
+                    )
 
             return record
 
         if self.pretraining:
             # There is a 50% chance we randomly slice out a portion of the patient history and update the demographic
             # prompt depending on the new starting point
-            if random.random() < 0.5:
+            if random.random() < 0.5 and not sample_packing:
                 start_index, end_index, demographic_tokens = random_slice_gpt_sequence(
                     concept_ids, new_max_length
                 )
@@ -351,6 +410,11 @@ class CehrGptDataCollator:
                         break
 
                 record["input_ids"] = record["input_ids"][0:end_index]
+
+                # We want to make sure we take the subset of attention_mask in sample packing if this field is available
+                if sample_packing and "attention_mask" in record:
+                    record["attention_mask"] = record["attention_mask"][0:end_index]
+
                 if self.include_values:
                     record["value_indicators"] = self._convert_to_tensor(
                         record["value_indicators"][0:end_index]
@@ -364,7 +428,7 @@ class CehrGptDataCollator:
                 )
             return record
         else:
-            if self.include_demographics:
+            if self.include_demographics and not sample_packing:
                 # We employ a left truncation strategy, where the most recent patient history is reserved for fine-tuning
                 demographic_prompts_at_visits = collect_demographic_prompts_at_visits(
                     concept_ids
@@ -427,6 +491,10 @@ class CehrGptDataCollator:
                     current_token = record["input_ids"][i]
                     if current_token == self.vs_token_id:
                         record["input_ids"] = record["input_ids"][i:end_index]
+                        if sample_packing and "attention_mask" in record:
+                            record["attention_mask"] = record["attention_mask"][
+                                i:end_index
+                            ]
                         if self.include_values:
                             record["value_indicators"] = record["value_indicators"][
                                 i:end_index
@@ -442,6 +510,10 @@ class CehrGptDataCollator:
                 # We simply take the last new_max_length number of tokens from the patient sequence
                 if len(record["input_ids"]) > new_max_length:
                     record["input_ids"] = record["input_ids"][-new_max_length:]
+                    if sample_packing and "attention_mask" in record:
+                        record["attention_mask"] = record["attention_mask"][
+                            -new_max_length:
+                        ]
                     if self.include_values:
                         record["value_indicators"] = record["value_indicators"][
                             -new_max_length:
@@ -452,31 +524,135 @@ class CehrGptDataCollator:
                             -new_max_length:
                         ]
 
-            # Finally we add the end token to the end of the sequence
-            record["input_ids"] = torch.concat(
-                [
-                    self._convert_to_tensor(record["input_ids"]),
-                    self._convert_to_tensor([self.tokenizer.end_token_id]),
-                ]
-            )
-            if self.include_values:
-                record["value_indicators"] = torch.concat(
-                    [
-                        self._convert_to_tensor(record["value_indicators"]),
-                        self._convert_to_tensor([False]),
-                    ]
-                ).to(torch.bool)
-                record["values"] = torch.concat(
-                    [
-                        self._convert_to_tensor(record["values"]),
-                        self._convert_to_tensor([self.tokenizer.pad_value_token_id]),
-                    ]
-                )
-            if self.include_ttv_prediction:
-                record["time_to_visits"] = torch.concat(
+            if not sample_packing:
+                # Finally we add the end token to the end of the sequence
+                record["input_ids"] = torch.concat(
                     [
-                        record["time_to_visits"],
-                        self._convert_to_tensor([-100.0]),
+                        self._convert_to_tensor(record["input_ids"]),
+                        self._convert_to_tensor([self.tokenizer.end_token_id]),
                     ]
                 )
+                if self.include_values:
+                    record["value_indicators"] = torch.concat(
+                        [
+                            self._convert_to_tensor(record["value_indicators"]),
+                            self._convert_to_tensor([False]),
+                        ]
+                    ).to(torch.bool)
+                    record["values"] = torch.concat(
+                        [
+                            self._convert_to_tensor(record["values"]),
+                            self._convert_to_tensor(
+                                [self.tokenizer.pad_value_token_id]
+                            ),
+                        ]
+                    )
+                if self.include_ttv_prediction:
+                    record["time_to_visits"] = torch.concat(
+                        [
+                            record["time_to_visits"],
+                            self._convert_to_tensor([-100.0]),
+                        ]
+                    )
             return record
+
+
+class SamplePackingCehrGptDataCollator(CehrGptDataCollator):
+    def __init__(self, max_tokens, max_position_embeddings, *args, **kwargs):
+        self.max_tokens_per_batch = max_tokens
+        self.max_position_embeddings = max_position_embeddings
+        self.sample_packing = True
+        self.add_end_token_in_sample_packing = kwargs.pop(
+            "add_end_token_in_sample_packing", False
+        )
+        super(SamplePackingCehrGptDataCollator, self).__init__(*args, **kwargs)
+
+    def __call__(self, examples):
+        current_input_ids = []
+        current_attention_mask = []
+        current_position_ids = []
+        current_value_indicators = []
+        current_values = []
+
+        # Demographics
+        current_person_ids = []
+        current_index_dates = []
+
+        # Binary classification inputs
+        current_ages = []
+        current_labels = []
+
+        for idx, example in enumerate(examples):
+
+            # If the sample length exceeds the model's capacity, truncate this example
+            add_end_token = (
+                len(example["input_ids"]) <= self.max_position_embeddings
+                and self.add_end_token_in_sample_packing
+            )
+
+            if len(example["input_ids"]) > self.max_position_embeddings:
+                example = self.generate_start_end_index(
+                    example, self.max_position_embeddings
+                )
+
+            input_ids = example["input_ids"]
+            # We add [END] [PAD], we want to attend to [END], adding [END] is important for sequence generation.
+            # If the sequence length of the sequence is less than the context window, we add both [END][PAD], otherwise
+            # we only add [PAD] token to the end of the sequence because it's not finished
+            current_input_ids.extend(
+                list(input_ids)
+                + (
+                    [self.tokenizer.end_token_id, self.tokenizer.pad_token_id]
+                    if add_end_token
+                    else [self.tokenizer.pad_token_id]
+                )
+            )
+            current_attention_mask.extend(
+                np.ones_like(input_ids).tolist() + ([1, 0] if add_end_token else [0])
+            )
+            num_tokens_to_pad = 1 + int(add_end_token)
+            current_position_ids.extend(list(range(len(input_ids) + num_tokens_to_pad)))
+            if self.include_values:
+                current_value_indicators.extend(
+                    list(example["value_indicators"]) + [False] * num_tokens_to_pad
+                )
+                current_values.extend(
+                    list(example["values"])
+                    + [self.tokenizer.pad_value_token_id] * num_tokens_to_pad
+                )
+
+            if "person_id" in example:
+                current_person_ids.append(example["person_id"])
+
+            if "index_date" in example:
+                current_index_dates.append(example["index_date"])
+
+            if "age_at_index" in example:
+                current_ages.append(example["age_at_index"])
+
+            if "classifier_label" in example:
+                current_labels.append(example["classifier_label"])
+
+        assert (
+            len(current_input_ids) <= self.max_tokens_per_batch
+        ), f"the total number of tokens in the packed sequence should be less than { self.max_tokens_per_batch}"
+        packed_example = {
+            "input_ids": current_input_ids,
+            "attention_mask": current_attention_mask,
+            "position_ids": current_position_ids,
+        }
+        if self.include_values:
+            packed_example.update({"value_indicators": current_value_indicators})
+            packed_example.update({"values": current_values})
+
+        if current_labels:
+            packed_example.update(
+                {
+                    "person_id": current_person_ids,
+                    "index_date": current_index_dates,
+                    "age_at_index": current_ages,
+                    "classifier_label": current_labels,
+                }
+            )
+
+        return super().__call__([packed_example])
```
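
The new `SamplePackingCehrGptDataCollator` concatenates every example in the incoming list into a single packed row, separating patients with `[END]`/`[PAD]` tokens and restarting `position_ids` at 0 for each one, before delegating to the parent collator. A construction sketch; the numeric values are illustrative, and `tokenizer`/`max_length` are assumed to be existing `CehrGptDataCollator` keyword arguments:

```python
from cehrgpt.data.hf_cehrgpt_dataset_collator import SamplePackingCehrGptDataCollator
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer

tokenizer = CehrGptTokenizer.from_pretrained("/path/to/tokenizer")  # placeholder path

collator = SamplePackingCehrGptDataCollator(
    max_tokens=16384,              # token budget for one packed row
    max_position_embeddings=2048,  # single examples longer than this are truncated
    tokenizer=tokenizer,           # assumed parent-collator keyword
    max_length=16384,              # assumed parent-collator keyword
    add_end_token_in_sample_packing=True,
)
```

Batches are expected to be pre-grouped upstream (see the new `cehrgpt/data/sample_packing_sampler.py` and `cehrgpt/runners/sample_packing_trainer.py` in this release) so that each list of examples fits inside `max_tokens` after the extra `[END]`/`[PAD]` tokens are appended; the collator asserts this invariant before handing the packed example to the parent `__call__`.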