cehrgpt 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +1 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +454 -68
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +232 -17
- cehrgpt/data/sample_packing_sampler.py +36 -6
- cehrgpt/generation/cehrgpt_conditional_generation.py +314 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +15 -3
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +25 -0
- cehrgpt/models/hf_cehrgpt.py +244 -39
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +354 -71
- cehrgpt/runners/data_utils.py +131 -5
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +84 -51
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +59 -7
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +60 -0
- cehrgpt/runners/hyperparameter_search_util.py +6 -7
- cehrgpt/runners/sample_packing_trainer.py +17 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +80 -62
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/METADATA +102 -7
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/RECORD +29 -26
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/WHEEL +1 -1
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/top_level.txt +0 -0
```diff
--- a/cehrgpt/data/hf_cehrgpt_dataset_collator.py
+++ b/cehrgpt/data/hf_cehrgpt_dataset_collator.py
@@ -1,21 +1,28 @@
+import copy
 import random
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 import numpy as np
 import torch
 from torch.nn.utils.rnn import pad_sequence
+from transformers.utils import logging
 
 from cehrgpt.gpt_utils import (
     DEMOGRAPHIC_PROMPT_SIZE,
     collect_demographic_prompts_at_visits,
     extract_time_interval_in_days,
+    extract_time_interval_in_hours,
     is_att_token,
     is_inpatient_att_token,
+    is_inpatient_hour_token,
+    is_visit_end,
     random_slice_gpt_sequence,
 )
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
 
+TIME_TO_EVENT_MAX_TIME = 3650
 INPATIENT_STAY_DURATION_LIMIT = 30
+LOG = logging.get_logger("transformers")
 
 
 class CehrGptDataCollator:
@@ -27,20 +34,18 @@ class CehrGptDataCollator:
         include_values: bool = False,
         include_ttv_prediction: bool = False,
         use_sub_time_tokenization: bool = False,
+        include_motor_time_to_event: bool = False,
+        motor_tte_vocab_size: int = 0,
+        motor_num_time_pieces: int = 8,
         pretraining: bool = True,
         include_demographics: bool = False,
+        add_linear_prob_token: bool = False,
     ):
         self.tokenizer = tokenizer
         self.max_length = max_length
-
-
-
-        self.vs_token_id = tokenizer._convert_token_to_id("VS")
-        if self.vs_token_id == tokenizer._oov_token_id:
-            self.vs_token_id = tokenizer._convert_token_to_id("[VS]")
-        self.ve_token_id = tokenizer._convert_token_to_id("VE")
-        if self.ve_token_id == tokenizer._oov_token_id:
-            self.ve_token_id = tokenizer._convert_token_to_id("[VE]")
+
+        self.vs_token_id = tokenizer.vs_token_id
+        self.ve_token_id = tokenizer.ve_token_id
 
         self.shuffle_records = shuffle_records
         self.include_values = include_values
@@ -48,6 +53,20 @@
         self.use_sub_time_tokenization = use_sub_time_tokenization
         self.pretraining = pretraining
         self.include_demographics = include_demographics
+        self.add_linear_prob_token = add_linear_prob_token
+
+        # MOTOR TTE configuration
+        if include_motor_time_to_event:
+            assert motor_tte_vocab_size > 0, (
+                f"motor_tte_vocab_size must be greater than 0 "
+                f"when include_motor_time_to_event is set to True. "
+                f"But motor_tte_vocab_size: {motor_tte_vocab_size} is provided"
+            )
+
+        self.include_motor_time_to_event = include_motor_time_to_event
+        self.motor_tte_vocab_size = motor_tte_vocab_size
+        self.motor_num_time_pieces = motor_num_time_pieces
+        self.motor_time_interval = TIME_TO_EVENT_MAX_TIME // motor_num_time_pieces
 
         if self.use_sub_time_tokenization:
             token_to_time_token_mapping = tokenizer.token_to_time_token_mapping
```
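The new constructor arguments wire MOTOR time-to-event supervision and the linear-probing token into the collator. Below is a minimal sketch of how they might be passed; the `tokenizer` instance, the `max_length` value, and the decision to leave the remaining arguments at their defaults are assumptions for illustration, only the argument names and the `motor_tte_vocab_size > 0` constraint come from the diff.

```python
# A minimal sketch, assuming `tokenizer` is an already-trained CehrGptTokenizer
# (loading omitted) and that 1024 is the desired context length.
from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator

collator = CehrGptDataCollator(
    tokenizer=tokenizer,
    max_length=1024,
    include_motor_time_to_event=True,   # new in 0.1.2
    motor_tte_vocab_size=tokenizer.motor_tte_vocab_size,  # must be > 0, per the assert above
    motor_num_time_pieces=8,            # bins spanning TIME_TO_EVENT_MAX_TIME (3650 days)
    add_linear_prob_token=False,        # new in 0.1.2
    pretraining=True,
)
```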
```diff
@@ -88,6 +107,8 @@
                     ):
                         return -100
                     return time_to_visit
+                elif is_inpatient_hour_token(c):
+                    return extract_time_interval_in_hours(c) / 24
                 return -100
             except ValueError:
                 return -100
@@ -95,8 +116,8 @@
         return [float(default_value(_)) for _ in concept_ids]
 
     def __call__(self, examples):
-
-        examples = [self.generate_start_end_index(_) for _ in examples]
+        sample_packing = getattr(self, "sample_packing", False)
+        examples = [self.generate_start_end_index(_, sample_packing) for _ in examples]
         examples = [self.random_sort(_) for _ in examples]
         batch = {}
 
@@ -141,6 +162,22 @@
             f"batch['input_ids']: {batch['input_ids']} "
         )
 
+        if "epoch_times" in examples[0]:
+            batch_epoch_times = [
+                self._try_reverse_tensor(
+                    self._convert_to_tensor(example["epoch_times"])
+                )
+                for example in examples
+            ]
+            # Pad sequences to the max length in the batch
+            batch["epoch_times"] = self._try_reverse_tensor(
+                pad_sequence(
+                    batch_epoch_times,
+                    batch_first=True,
+                    padding_value=0,
+                ).to(torch.float32)
+            )
+
         if "position_ids" in examples[0]:
             batch_position_ids = [
                 self._try_reverse_tensor(
```
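The `epoch_times` batching above relies on `torch.nn.utils.rnn.pad_sequence` to right-pad each example's time array to the longest sequence in the batch. A small standalone illustration of that call, with made-up values:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three examples with different numbers of events (timestamps are made-up values).
epoch_times = [
    torch.tensor([1.6e9, 1.6e9 + 86400.0]),
    torch.tensor([1.7e9]),
    torch.tensor([1.5e9, 1.5e9 + 3600.0, 1.5e9 + 7200.0]),
]

# Right-pads with 0 to the longest sequence, giving a (batch, max_len) float tensor,
# matching padding_value=0 and the .to(torch.float32) cast in the collator.
padded = pad_sequence(epoch_times, batch_first=True, padding_value=0).to(torch.float32)
print(padded.shape)  # torch.Size([3, 3])
```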
```diff
@@ -153,7 +190,7 @@
                 pad_sequence(
                     batch_position_ids,
                     batch_first=True,
-                    padding_value=
+                    padding_value=0,
                 ).to(torch.int64)
             )
 
@@ -194,6 +231,126 @@
             )
         )
 
+        if self.include_motor_time_to_event:
+            examples_with_motor_tte = [
+                self.create_time_to_event_labels(_) for _ in examples
+            ]
+            batch_motor_time_to_event_vectors = [
+                self._try_reverse_tensor(
+                    self._convert_to_tensor(example["time_to_event_vectors"])
+                )
+                for example in examples_with_motor_tte
+            ]
+            batch_motor_event_indicators = [
+                self._try_reverse_tensor(
+                    self._convert_to_tensor(example["event_indicators"])
+                )
+                for example in examples_with_motor_tte
+            ]
+            batch_motor_time_to_event_to_include = [
+                self._try_reverse_tensor(
+                    self._convert_to_tensor(example["time_to_event_to_include"])
+                )
+                for example in examples_with_motor_tte
+            ]
+            batch_motor_time_indicators = [
+                self._try_reverse_tensor(
+                    self._convert_to_tensor(example["time_indicators"])
+                )
+                for example in examples_with_motor_tte
+            ]
+
+            batch_motor_time_to_event_vectors = torch.concat(
+                batch_motor_time_to_event_vectors, dim=0
+            ).to(torch.float32)
+
+            # If every example in the batch only contains one visit, there would be no labels generated for MOTOR TTE
+            # we only create the labels when any example has more than one visit
+            if batch_motor_time_to_event_vectors.dim() <= 1:
+                LOG.warning(
+                    "There are no MOTOR TTE labels generated for this batch "
+                    "because every example in this batch only contains one visit."
+                )
+            else:
+                batch_size = len(examples)
+                length, num_time_pieces, motor_tte_vocab_size = (
+                    batch_motor_time_to_event_vectors.shape
+                )
+                padded_length = batch_size - length % batch_size
+                batch["motor_time_to_event_vectors"] = (
+                    torch.concat(
+                        [
+                            batch_motor_time_to_event_vectors,
+                            torch.full(
+                                (padded_length, num_time_pieces, motor_tte_vocab_size),
+                                0.0,
+                            ),
+                        ],
+                        dim=0,
+                    )
+                    .reshape((batch_size, -1, num_time_pieces, motor_tte_vocab_size))
+                    .to(torch.float32)
+                )
+
+                # Motor event indicators that indicate there is an event occurred in this time interval
+                batch_motor_event_indicators = torch.concat(
+                    batch_motor_event_indicators, dim=0
+                ).to(torch.bool)
+                batch["motor_event_indicators"] = (
+                    torch.concat(
+                        [
+                            batch_motor_event_indicators,
+                            torch.full(
+                                (padded_length, num_time_pieces, motor_tte_vocab_size),
+                                False,
+                            ),
+                        ],
+                        dim=0,
+                    )
+                    .reshape((batch_size, -1, num_time_pieces, motor_tte_vocab_size))
+                    .to(torch.bool)
+                )
+
+                # Input to indicate whether the visit should be included for TTE predictions
+                batch_motor_time_to_event_to_include = torch.concat(
+                    batch_motor_time_to_event_to_include, dim=0
+                ).to(torch.bool)
+                batch["motor_time_to_event_to_include"] = (
+                    torch.concat(
+                        [
+                            batch_motor_time_to_event_to_include,
+                            torch.full((padded_length,), False),
+                        ],
+                        dim=0,
+                    ).to(torch.bool)
+                ).reshape((batch_size, -1))
+
+                # Motor time indicators that indicate whether there are neither clinical events nor censor events
+                batch_motor_time_indicators = torch.concat(
+                    batch_motor_time_indicators, dim=0
+                ).to(torch.bool)
+                batch["motor_time_indicators"] = (
+                    torch.concat(
+                        [
+                            batch_motor_time_indicators,
+                            torch.full(
+                                (padded_length, num_time_pieces, motor_tte_vocab_size),
+                                False,
+                            ),
+                        ],
+                        dim=0,
+                    )
+                    .reshape((batch_size, -1, num_time_pieces, motor_tte_vocab_size))
+                    .to(torch.bool)
+                )
+
+                batch["motor_end_index"] = torch.concat(
+                    [
+                        torch.full((length, 1), 1, dtype=torch.int32),
+                        torch.full((padded_length, 1), 0, dtype=torch.int32),
+                    ]
+                ).reshape((batch_size, -1))
+
         if self.include_values:
             batch_value_indicators = [
                 self._try_reverse_tensor(
```
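The visit-level label tensors above are concatenated across examples and then padded so the total visit count divides evenly by the batch size before being reshaped to `(batch_size, visits_per_example, num_time_pieces, vocab)`. A small worked sketch of that arithmetic, with illustrative sizes not taken from the package:

```python
import torch

# Illustrative sizes only.
batch_size = 4            # len(examples)
length = 10               # total visits collected across the batch
num_time_pieces = 8
motor_tte_vocab_size = 5

padded_length = batch_size - length % batch_size   # 4 - (10 % 4) = 2

labels = torch.zeros((length, num_time_pieces, motor_tte_vocab_size))
padding = torch.zeros((padded_length, num_time_pieces, motor_tte_vocab_size))
stacked = torch.concat([labels, padding], dim=0)    # 12 rows total

# 12 rows reshape cleanly into 4 examples x 3 visit slots each.
reshaped = stacked.reshape((batch_size, -1, num_time_pieces, motor_tte_vocab_size))
print(reshaped.shape)  # torch.Size([4, 3, 8, 5])
```

Note that when `length` is already a multiple of `batch_size`, the formula above pads a full extra `batch_size` rows rather than zero.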
```diff
@@ -281,6 +438,193 @@
 
         return batch
 
+    def create_time_to_event_labels(self, record: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Generates time-to-event (TTE) labels and censoring indicators for each visit in a patient's timeline.
+
+        Processes the input sequence in reverse to compute the number of days from each visit (marked by [VE])
+        to the occurrence of future motor-related events.
+
+        Args:
+            record (Dict[str, Any]): A dictionary containing the encoded patient sequence with the key "input_ids".
+                This sequence includes [VS], [VE], time delta tokens, and motor TTE concept codes.
+
+        Returns:
+            Dict[str, Any]: The updated input record with added keys:
+                - "time_to_event_vectors": np.ndarray of shape [num_visits, motor_vocab_size], containing time-to-event values
+                - "event_indicators": np.ndarray of shape [num_visits, motor_vocab_size], where 0 = event occurred, 1 = censored
+        """
+        input_ids = record["input_ids"]
+        sample_packing = getattr(self, "sample_packing", False)
+
+        if isinstance(input_ids, torch.Tensor):
+            input_ids = input_ids.detach().tolist()
+
+        # This potentially contains packed samples, we need to handle that
+        packed_concept_ids = self.tokenizer.decode(input_ids, skip_special_tokens=False)
+        pad_indices = []
+        if sample_packing:
+            # We start from the first index
+            for i in range(len(packed_concept_ids)):
+                if packed_concept_ids[i] == self.tokenizer.pad_token:
+                    # If we encounter consecutive pads, we should break out of the loop
+                    if pad_indices and pad_indices[-1] == self.tokenizer.pad_token:
+                        break
+                    pad_indices.append(i)
+
+        # If we did not find a pad, that means the whole sequence belongs to one sample
+        if len(pad_indices) == 0:
+            pad_indices.append(len(packed_concept_ids))
+
+        timepiece_time_to_event_vectors = []
+        timepiece_event_indicators = []
+        timepiece_indicators = []
+        time_to_event_to_includes = []
+
+        for start_index, end_index in zip([0] + pad_indices[:-1], pad_indices):
+            concept_ids = packed_concept_ids[start_index:end_index]
+            if concept_ids[0] == self.tokenizer.pad_token:
+                concept_ids.pop(0)
+            time_to_event_vectors = []
+            global_event_indicators = []
+
+            # First collect TTE data in reverse chronological order
+            censor_times = []
+            time_to_event_data: List[Dict[str, int]] = []
+            time_to_event_dict: Dict[str, int] = {}
+            time_to_event_to_include: List[bool] = []
+            next_future_visit_concepts = set()
+            time_interval = 0
+
+            # Reverse walk through concept_ids to calculate TTE from each [VE] point
+            for concept_id in reversed(concept_ids):
+                if is_visit_end(concept_id):
+                    # Update TTE for existing concepts, or add new ones seen in this visit
+                    for existing_concept_id in list(time_to_event_dict.keys()):
+                        if existing_concept_id in next_future_visit_concepts:
+                            time_to_event_dict[existing_concept_id] = time_interval
+                        else:
+                            time_to_event_dict[existing_concept_id] += time_interval
+
+                    for next_concept_id in next_future_visit_concepts:
+                        if next_concept_id not in time_to_event_dict:
+                            time_to_event_dict[next_concept_id] = time_interval
+
+                    # If the next visit occurs on the same day as the previous one, we don't want to do TTE for the
+                    # previous visit
+                    time_to_event_to_include.append(time_interval > 0)
+                    time_to_event_data.append(copy.deepcopy(time_to_event_dict))
+                    # Record the censor time at the end of the visit
+                    if censor_times:
+                        censor_times.append(censor_times[-1] + time_interval)
+                    else:
+                        censor_times.append(time_interval)
+                    time_interval = 0
+                    next_future_visit_concepts.clear()
+
+                elif is_att_token(concept_id):
+                    time_interval += extract_time_interval_in_days(concept_id)
+
+                elif self.tokenizer.is_motor_time_to_event_code(concept_id):
+                    next_future_visit_concepts.add(concept_id)
+
+            if len(time_to_event_data) == 0:
+                LOG.info(
+                    "Vist end event is not detected for this sample, and is skipped for MOTOR tasks."
+                    "It's likely this sample contains a long admission. length: %s, concept_ids[-10:] %s",
+                    len(concept_ids),
+                    concept_ids[-10:],
+                )
+                continue
+
+            # Reverse back to chronological order for final labels
+            time_to_event_data.reverse()
+            censor_times.reverse()
+            time_to_event_to_include.reverse()
+
+            for censor_time, visit_tte_data in zip(censor_times, time_to_event_data):
+                time_to_event_vector = np.full(
+                    self.tokenizer.motor_tte_vocab_size,
+                    fill_value=censor_time,
+                    dtype=np.int32,
+                )
+                event_indicator = np.zeros(
+                    self.tokenizer.motor_tte_vocab_size,
+                    dtype=np.int32,
+                )
+                visit_token_ids = [
+                    self.tokenizer.get_motor_token_id(concept_id)
+                    for concept_id in visit_tte_data.keys()
+                ]
+                visit_tte_values = list(visit_tte_data.values())
+
+                time_to_event_vector[visit_token_ids] = visit_tte_values
+                event_indicator[visit_token_ids] = 1  # not censored (event occurred)
+
+                time_to_event_vectors.append(time_to_event_vector)
+                global_event_indicators.append(event_indicator)
+
+            time_to_event_vectors = np.asarray(time_to_event_vectors)
+            global_event_indicators = np.asarray(global_event_indicators).astype(bool)
+            n_visits = len(time_to_event_vectors)
+
+            timepiece_time_to_event_vector = np.full(
+                (
+                    self.motor_num_time_pieces,
+                    n_visits,
+                    self.tokenizer.motor_tte_vocab_size,
+                ),
+                fill_value=0,
+                dtype=np.int32,
+            )
+            timepiece_event_indicator = np.zeros(
+                (
+                    self.motor_num_time_pieces,
+                    n_visits,
+                    self.tokenizer.motor_tte_vocab_size,
+                ),
+                dtype=bool,
+            )
+            timepiece_indicator = np.zeros(
+                (
+                    self.motor_num_time_pieces,
+                    n_visits,
+                    self.tokenizer.motor_tte_vocab_size,
+                ),
+                dtype=bool,
+            )
+
+            # Putting the event time and censor time into the corresponding time bins
+            for bin_num in range(self.motor_num_time_pieces):
+                start = self.motor_time_interval * bin_num
+                end = self.motor_time_interval * (bin_num + 1)
+                time_in_bin = np.clip(time_to_event_vectors - start, 0, end - start)
+                timepiece_time_to_event_vector[bin_num] = time_in_bin
+                event_indicator = (
+                    global_event_indicators
+                    & (start <= time_to_event_vectors)
+                    & (time_to_event_vectors < end)
+                )
+                timepiece_event_indicator[bin_num] = event_indicator
+                timepiece_indicator[bin_num] = time_in_bin > 0 | event_indicator
+
+            timepiece_time_to_event_vectors.append(
+                timepiece_time_to_event_vector.swapaxes(0, 1)
+            )
+            timepiece_event_indicators.append(timepiece_event_indicator.swapaxes(0, 1))
+            timepiece_indicators.append(timepiece_indicator.swapaxes(0, 1))
+            time_to_event_to_includes.append(np.asarray(time_to_event_to_include))
+
+        record["time_to_event_vectors"] = np.concatenate(
+            timepiece_time_to_event_vectors, axis=0
+        )
+        record["event_indicators"] = np.concatenate(timepiece_event_indicators, axis=0)
+        record["time_indicators"] = np.concatenate(timepiece_indicators, axis=0)
+        record["time_to_event_to_include"] = np.concatenate(
+            time_to_event_to_includes, axis=0
+        )
+        return record
+
     def random_sort(self, record: Dict[str, Any]) -> Dict[str, Any]:
 
         if not self.shuffle_records:
```
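The new method splits the horizon `TIME_TO_EVENT_MAX_TIME = 3650` days into `motor_num_time_pieces` equal bins (`motor_time_interval = 3650 // 8 = 456` days with the default), and the binning loop records how much of each bin elapses before the event or censor time, plus an indicator for the bin containing the event. A self-contained sketch of that binning for a single time value; the numbers are chosen purely for illustration:

```python
import numpy as np

TIME_TO_EVENT_MAX_TIME = 3650   # days, as defined at the top of the module
motor_num_time_pieces = 8       # default number of bins
motor_time_interval = TIME_TO_EVENT_MAX_TIME // motor_num_time_pieces  # 456 days

time_to_event = np.array([700])     # illustrative: event (or censor) 700 days out
event_occurred = np.array([True])   # illustrative: not censored

for bin_num in range(motor_num_time_pieces):
    start = motor_time_interval * bin_num
    end = motor_time_interval * (bin_num + 1)
    # Days of this bin that elapse before the event: the full 456 days for bin 0,
    # a partial 244 days for bin 1, and 0 for every later bin.
    time_in_bin = np.clip(time_to_event - start, 0, end - start)
    # The event indicator fires only in the bin that contains the event time (bin 1 here).
    indicator = event_occurred & (start <= time_to_event) & (time_to_event < end)
    print(bin_num, int(time_in_bin[0]), bool(indicator[0]))
```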
```diff
@@ -317,14 +661,16 @@
         return record
 
     def generate_start_end_index(
-        self,
+        self,
+        record: Dict[str, Any],
+        sample_packing: bool,
+        max_length_allowed: Optional[int] = None,
     ) -> Dict[str, Any]:
         """Adding the start and end indices to extract a portion of the patient sequence."""
         # concept_ids will be used to for time to event predictions and identifying the visit starts
         max_length_allowed = (
             self.max_length if max_length_allowed is None else max_length_allowed
         )
-        sample_packing = getattr(self, "sample_packing", False)
         input_ids = record["input_ids"]
         if isinstance(input_ids, torch.Tensor):
             input_ids = input_ids.detach().tolist()
@@ -333,7 +679,9 @@
 
         # Subtract one for the [END] token when sample_packing is not enabled
         new_max_length = (
-            max_length_allowed
+            max_length_allowed - 1
+            if not sample_packing and self.pretraining
+            else max_length_allowed
         )
 
         if self.include_ttv_prediction:
@@ -341,15 +689,34 @@
             [self._convert_to_tensor(self._convert_time_to_event(concept_ids))]
         )
 
+        # If linear token exists, we will use it, otherwise we default to the OOV token
+        linear_token_id = (
+            self.tokenizer.linear_token_id
+            if self.tokenizer.linear_token_id
+            else self.tokenizer.oov_token_id
+        )
+        eos_token = (
+            linear_token_id
+            if self.add_linear_prob_token
+            else self.tokenizer.end_token_id
+        )
+
         # Return the record directly if the actual sequence length is less than the max sequence
         if seq_length <= new_max_length:
-            if not sample_packing:
+            if not sample_packing and self.pretraining:
                 record["input_ids"] = torch.concat(
                     [
                         self._convert_to_tensor(record["input_ids"]),
-                        self._convert_to_tensor([
+                        self._convert_to_tensor([eos_token]),
                     ]
                 )
+                if "epoch_times" in record:
+                    record["epoch_times"] = torch.concat(
+                        [
+                            self._convert_to_tensor(record["epoch_times"]),
+                            self._convert_to_tensor([record["epoch_times"][-1]]),
+                        ]
+                    )
                 if self.include_values:
                     record["value_indicators"] = torch.concat(
                         [
@@ -372,7 +739,6 @@
                             self._convert_to_tensor([-100.0]),
                         ]
                     )
-
             return record
 
         if self.pretraining:
@@ -386,6 +752,10 @@
             record["input_ids"] = self._convert_to_tensor(
                 record["input_ids"][start_index : end_index + 1]
             )
+            if "epoch_times" in record:
+                record["epoch_times"] = self._convert_to_tensor(
+                    record["epoch_times"][start_index : end_index + 1]
+                )
             if self.include_values:
                 record["value_indicators"] = self._convert_to_tensor(
                     record["value_indicators"][start_index : end_index + 1]
@@ -406,7 +776,8 @@
             for i in reversed(list(range(0, end_index))):
                 current_token = record["input_ids"][i]
                 if current_token == self.ve_token_id:
-
+                    # Plus one because slicing is right exclusive
+                    end_index = i + 1
                     break
 
             record["input_ids"] = record["input_ids"][0:end_index]
@@ -415,6 +786,14 @@
             if sample_packing and "attention_mask" in record:
                 record["attention_mask"] = record["attention_mask"][0:end_index]
 
+            if sample_packing and "position_ids" in record:
+                record["position_ids"] = record["position_ids"][0:end_index]
+
+            if "epoch_times" in record:
+                record["epoch_times"] = self._convert_to_tensor(
+                    record["epoch_times"][0:end_index]
+                )
+
             if self.include_values:
                 record["value_indicators"] = self._convert_to_tensor(
                     record["value_indicators"][0:end_index]
@@ -447,6 +826,17 @@
                             ),
                         ]
                     )
+                    if "epoch_times" in record:
+                        record["epoch_times"] = torch.concat(
+                            [
+                                torch.zeros(
+                                    [record["epoch_times"][0]], dtype=torch.float32
+                                ),
+                                self._convert_to_tensor(
+                                    record["epoch_times"][token_index:seq_length]
+                                ),
+                            ]
+                        )
                     if self.include_values:
                         record["value_indicators"] = torch.concat(
                             [
@@ -485,7 +875,7 @@
                     )
                     break
             else:
-                start_index = seq_length - new_max_length
+                start_index = max(seq_length - new_max_length, 0)
                 end_index = seq_length
                 for i in range(start_index, end_index):
                     current_token = record["input_ids"][i]
@@ -495,6 +885,13 @@
                         record["attention_mask"] = record["attention_mask"][
                             i:end_index
                         ]
+                        if sample_packing and "position_ids" in record:
+                            record["position_ids"] = record["position_ids"][i:end_index]
+
+                        if "epoch_times" in record:
+                            record["epoch_times"] = self._convert_to_tensor(
+                                record["epoch_times"][i:end_index]
+                            )
                         if self.include_values:
                             record["value_indicators"] = record["value_indicators"][
                                 i:end_index
@@ -514,6 +911,12 @@
                 record["attention_mask"] = record["attention_mask"][
                     -new_max_length:
                 ]
+            if sample_packing and "position_ids" in record:
+                record["position_ids"] = record["position_ids"][-new_max_length:]
+            if "epoch_times" in record:
+                record["epoch_times"] = self._convert_to_tensor(
+                    record["epoch_times"][-new_max_length:]
+                )
             if self.include_values:
                 record["value_indicators"] = record["value_indicators"][
                     -new_max_length:
@@ -524,36 +927,6 @@
                         -new_max_length:
                     ]
 
-        if not sample_packing:
-            # Finally we add the end token to the end of the sequence
-            record["input_ids"] = torch.concat(
-                [
-                    self._convert_to_tensor(record["input_ids"]),
-                    self._convert_to_tensor([self.tokenizer.end_token_id]),
-                ]
-            )
-            if self.include_values:
-                record["value_indicators"] = torch.concat(
-                    [
-                        self._convert_to_tensor(record["value_indicators"]),
-                        self._convert_to_tensor([False]),
-                    ]
-                ).to(torch.bool)
-                record["values"] = torch.concat(
-                    [
-                        self._convert_to_tensor(record["values"]),
-                        self._convert_to_tensor(
-                            [self.tokenizer.pad_value_token_id]
-                        ),
-                    ]
-                )
-            if self.include_ttv_prediction:
-                record["time_to_visits"] = torch.concat(
-                    [
-                        record["time_to_visits"],
-                        self._convert_to_tensor([-100.0]),
-                    ]
-                )
         return record
 
 
@@ -584,34 +957,46 @@ class SamplePackingCehrGptDataCollator(CehrGptDataCollator):
 
         for idx, example in enumerate(examples):
 
-            #
+            # We only add an end token if the patient sequence could fit in the entire context window
             add_end_token = (
                 len(example["input_ids"]) <= self.max_position_embeddings
                 and self.add_end_token_in_sample_packing
             )
-
+            # If the sample length exceeds the model's capacity, truncate this example
            if len(example["input_ids"]) > self.max_position_embeddings:
                 example = self.generate_start_end_index(
-                    example, self.max_position_embeddings
+                    example, False, self.max_position_embeddings
                 )
 
+            add_eos_token = add_end_token | self.add_linear_prob_token
+            additional_tokens = []
+            if add_end_token:
+                additional_tokens.append(self.tokenizer.end_token_id)
+            elif self.add_linear_prob_token:
+                # Backward compatible
+                linear_prob_token_id = (
+                    self.tokenizer.linear_token_id
+                    if self.tokenizer.linear_token_id is not None
+                    else self.tokenizer.oov_token_id
+                )
+                additional_tokens.append(linear_prob_token_id)
+            additional_tokens.append(self.tokenizer.pad_token_id)
             input_ids = example["input_ids"]
             # We add [END] [PAD], we want to attend to [END], adding [END] is important for sequence generation.
             # If the sequence length of the sequence is less than the context window, we add both [END][PAD], otherwise
             # we only add [PAD] token to the end of the sequence because it's not finished
-            current_input_ids.extend(
-                list(input_ids)
-                + (
-                    [self.tokenizer.end_token_id, self.tokenizer.pad_token_id]
-                    if add_end_token
-                    else [self.tokenizer.pad_token_id]
-                )
-            )
+            current_input_ids.extend(list(input_ids) + additional_tokens)
             current_attention_mask.extend(
-                np.ones_like(input_ids).tolist() + ([1, 0] if
+                np.ones_like(input_ids).tolist() + ([1, 0] if add_eos_token else [0])
+            )
+            num_tokens_to_pad = 1 + int(add_eos_token)
+            current_position_ids.extend(
+                np.clip(
+                    list(range(len(input_ids) + num_tokens_to_pad)),
+                    0,
+                    self.max_position_embeddings - 1,
+                )
             )
-            num_tokens_to_pad = 1 + int(add_end_token)
-            current_position_ids.extend(list(range(len(input_ids) + num_tokens_to_pad)))
             if self.include_values:
                 current_value_indicators.extend(
                     list(example["value_indicators"]) + [False] * num_tokens_to_pad
```
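In sample packing, each example is now followed either by an [END] token plus a [PAD] separator (when it fits the context window), or by a linear-probing token when `add_linear_prob_token` is set, or by a lone [PAD]; positions are clipped to the model's maximum position embedding. A toy sketch of the resulting layout for the [END] path; the token ids and sizes are invented for illustration and the linear-probe branch is analogous:

```python
import numpy as np

# Invented ids for illustration only: END=1, PAD=0.
END, PAD = 1, 0
max_position_embeddings = 8

def pack(example_ids, add_end_token):
    """Mimics the per-example step: append [END][PAD] or just [PAD], build mask and positions."""
    additional = [END, PAD] if add_end_token else [PAD]
    input_ids = list(example_ids) + additional
    # Attend to [END] but never to the [PAD] separator.
    attention_mask = [1] * len(example_ids) + ([1, 0] if add_end_token else [0])
    num_tokens_to_pad = 1 + int(add_end_token)
    position_ids = np.clip(
        list(range(len(example_ids) + num_tokens_to_pad)),
        0,
        max_position_embeddings - 1,
    ).tolist()
    return input_ids, attention_mask, position_ids

print(pack([11, 12, 13], add_end_token=True))
# ([11, 12, 13, 1, 0], [1, 1, 1, 1, 0], [0, 1, 2, 3, 4])
```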
```diff
@@ -633,9 +1018,10 @@ class SamplePackingCehrGptDataCollator(CehrGptDataCollator):
             if "classifier_label" in example:
                 current_labels.append(example["classifier_label"])
 
-        assert (
-
-
+        assert len(current_input_ids) <= self.max_tokens_per_batch, (
+            f"The total number of tokens in the packed sequence should be less than {self.max_tokens_per_batch}\n"
+            f"But the total number of tokens is: {len(current_input_ids)}"
+        )
         packed_example = {
             "input_ids": current_input_ids,
             "attention_mask": current_attention_mask,
```
|