cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
- cehrgpt/data/sample_packing_sampler.py +181 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +35 -0
- cehrgpt/models/hf_cehrgpt.py +470 -106
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
- cehrgpt/runners/data_utils.py +358 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
- cehrgpt/runners/hyperparameter_search_util.py +10 -8
- cehrgpt/runners/sample_packing_trainer.py +185 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
cehrgpt/data/hf_cehrgpt_dataset_collator.py
@@ -1,21 +1,28 @@
+import copy
 import random
-from typing import Any, Dict
+from typing import Any, Dict, List, Optional
 
 import numpy as np
 import torch
 from torch.nn.utils.rnn import pad_sequence
+from transformers.utils import logging
 
 from cehrgpt.gpt_utils import (
     DEMOGRAPHIC_PROMPT_SIZE,
     collect_demographic_prompts_at_visits,
     extract_time_interval_in_days,
+    extract_time_interval_in_hours,
     is_att_token,
     is_inpatient_att_token,
+    is_inpatient_hour_token,
+    is_visit_end,
     random_slice_gpt_sequence,
 )
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
 
+TIME_TO_EVENT_MAX_TIME = 3650
 INPATIENT_STAY_DURATION_LIMIT = 30
+LOG = logging.get_logger("transformers")
 
 
 class CehrGptDataCollator:
@@ -27,20 +34,18 @@ class CehrGptDataCollator:
         include_values: bool = False,
         include_ttv_prediction: bool = False,
         use_sub_time_tokenization: bool = False,
+        include_motor_time_to_event: bool = False,
+        motor_tte_vocab_size: int = 0,
+        motor_num_time_pieces: int = 8,
         pretraining: bool = True,
         include_demographics: bool = False,
+        add_linear_prob_token: bool = False,
     ):
         self.tokenizer = tokenizer
         self.max_length = max_length
-
-
-
-        self.vs_token_id = tokenizer._convert_token_to_id("VS")
-        if self.vs_token_id == tokenizer._oov_token_id:
-            self.vs_token_id = tokenizer._convert_token_to_id("[VS]")
-        self.ve_token_id = tokenizer._convert_token_to_id("VE")
-        if self.ve_token_id == tokenizer._oov_token_id:
-            self.ve_token_id = tokenizer._convert_token_to_id("[VE]")
+
+        self.vs_token_id = tokenizer.vs_token_id
+        self.ve_token_id = tokenizer.ve_token_id
 
         self.shuffle_records = shuffle_records
         self.include_values = include_values
@@ -48,6 +53,20 @@ class CehrGptDataCollator:
         self.use_sub_time_tokenization = use_sub_time_tokenization
         self.pretraining = pretraining
         self.include_demographics = include_demographics
+        self.add_linear_prob_token = add_linear_prob_token
+
+        # MOTOR TTE configuration
+        if include_motor_time_to_event:
+            assert motor_tte_vocab_size > 0, (
+                f"motor_tte_vocab_size must be greater than 0 "
+                f"when include_motor_time_to_event is set to True. "
+                f"But motor_tte_vocab_size: {motor_tte_vocab_size} is provided"
+            )
+
+        self.include_motor_time_to_event = include_motor_time_to_event
+        self.motor_tte_vocab_size = motor_tte_vocab_size
+        self.motor_num_time_pieces = motor_num_time_pieces
+        self.motor_time_interval = TIME_TO_EVENT_MAX_TIME // motor_num_time_pieces
 
         if self.use_sub_time_tokenization:
             token_to_time_token_mapping = tokenizer.token_to_time_token_mapping
@@ -88,6 +107,8 @@ class CehrGptDataCollator:
                     ):
                         return -100
                     return time_to_visit
+                elif is_inpatient_hour_token(c):
+                    return extract_time_interval_in_hours(c) / 24
                 return -100
             except ValueError:
                 return -100
@@ -95,8 +116,8 @@ class CehrGptDataCollator:
         return [float(default_value(_)) for _ in concept_ids]
 
     def __call__(self, examples):
-
-        examples = [self.generate_start_end_index(_) for _ in examples]
+        sample_packing = getattr(self, "sample_packing", False)
+        examples = [self.generate_start_end_index(_, sample_packing) for _ in examples]
         examples = [self.random_sort(_) for _ in examples]
         batch = {}
 
@@ -105,9 +126,12 @@ class CehrGptDataCollator:
             self._try_reverse_tensor(self._convert_to_tensor(example["input_ids"]))
             for example in examples
         ]
+
         batch_attention_mask = [
             self._try_reverse_tensor(
-                torch.
+                self._convert_to_tensor(example["attention_mask"]).to(torch.float)
+                if "attention_mask" in example
+                else torch.ones_like(
                     self._convert_to_tensor(example["input_ids"]), dtype=torch.float
                 )
             )
@@ -128,16 +152,40 @@ class CehrGptDataCollator:
         )
         assert batch["input_ids"].shape[1] <= self.max_length
         assert batch["attention_mask"].shape[1] <= self.max_length
+        assert batch["attention_mask"].shape[1] == batch["input_ids"].shape[1], (
+            f'batch["attention_mask"].shape[1]: {batch["attention_mask"].shape[1]}, '
+            f'batch["input_ids"].shape[1]: {batch["input_ids"].shape[1]}'
+        )
+        assert batch["input_ids"].max() < self.tokenizer.vocab_size, (
+            f"batch['input_ids'].max(): {batch['input_ids'].max()} must be smaller than "
+            f"self.tokenizer.vocab_size: {self.tokenizer.vocab_size}. "
+            f"batch['input_ids']: {batch['input_ids']} "
+        )
 
-        if
-
+        if "position_ids" in examples[0]:
+            batch_position_ids = [
+                self._try_reverse_tensor(
+                    self._convert_to_tensor(example["position_ids"])
+                )
+                for example in examples
+            ]
+            # Pad sequences to the max length in the batch
+            batch["position_ids"] = self._try_reverse_tensor(
                 pad_sequence(
-
+                    batch_position_ids,
                     batch_first=True,
-                    padding_value
+                    padding_value=0,
                 ).to(torch.int64)
             )
 
+        if self.pretraining:
+            batch["labels"] = torch.where(
+                (batch["input_ids"] != self.tokenizer.pad_token_id)
+                & batch["attention_mask"].to(torch.bool),
+                batch["input_ids"],
+                -100,
+            )
+
         if self.use_sub_time_tokenization:
             time_token_indicators = torch.isin(batch["input_ids"], self.time_tokens)
             masked_tokens = batch["input_ids"].clone()
@@ -167,10 +215,130 @@ class CehrGptDataCollator:
                 )
             )
 
+        if self.include_motor_time_to_event:
+            examples_with_motor_tte = [
+                self.create_time_to_event_labels(_) for _ in examples
+            ]
+            batch_motor_time_to_event_vectors = [
+                self._try_reverse_tensor(
+                    self._convert_to_tensor(example["time_to_event_vectors"])
+                )
+                for example in examples_with_motor_tte
+            ]
+            batch_motor_event_indicators = [
+                self._try_reverse_tensor(
+                    self._convert_to_tensor(example["event_indicators"])
+                )
+                for example in examples_with_motor_tte
+            ]
+            batch_motor_time_to_event_to_include = [
+                self._try_reverse_tensor(
+                    self._convert_to_tensor(example["time_to_event_to_include"])
+                )
+                for example in examples_with_motor_tte
+            ]
+            batch_motor_time_indicators = [
+                self._try_reverse_tensor(
+                    self._convert_to_tensor(example["time_indicators"])
+                )
+                for example in examples_with_motor_tte
+            ]
+
+            batch_motor_time_to_event_vectors = torch.concat(
+                batch_motor_time_to_event_vectors, dim=0
+            ).to(torch.float32)
+
+            # If every example in the batch only contains one visit, there would be no labels generated for MOTOR TTE
+            # we only create the labels when any example has more than one visit
+            if batch_motor_time_to_event_vectors.dim() <= 1:
+                LOG.warning(
+                    "There are no MOTOR TTE labels generated for this batch "
+                    "because every example in this batch only contains one visit."
+                )
+            else:
+                batch_size = len(examples)
+                length, num_time_pieces, motor_tte_vocab_size = (
+                    batch_motor_time_to_event_vectors.shape
+                )
+                padded_length = batch_size - length % batch_size
+                batch["motor_time_to_event_vectors"] = (
+                    torch.concat(
+                        [
+                            batch_motor_time_to_event_vectors,
+                            torch.full(
+                                (padded_length, num_time_pieces, motor_tte_vocab_size),
+                                0.0,
+                            ),
+                        ],
+                        dim=0,
+                    )
+                    .reshape((batch_size, -1, num_time_pieces, motor_tte_vocab_size))
+                    .to(torch.float32)
+                )
+
+                # Motor event indicators that indicate there is an event occurred in this time interval
+                batch_motor_event_indicators = torch.concat(
+                    batch_motor_event_indicators, dim=0
+                ).to(torch.bool)
+                batch["motor_event_indicators"] = (
+                    torch.concat(
+                        [
+                            batch_motor_event_indicators,
+                            torch.full(
+                                (padded_length, num_time_pieces, motor_tte_vocab_size),
+                                False,
+                            ),
+                        ],
+                        dim=0,
+                    )
+                    .reshape((batch_size, -1, num_time_pieces, motor_tte_vocab_size))
+                    .to(torch.bool)
+                )
+
+                # Input to indicate whether the visit should be included for TTE predictions
+                batch_motor_time_to_event_to_include = torch.concat(
+                    batch_motor_time_to_event_to_include, dim=0
+                ).to(torch.bool)
+                batch["motor_time_to_event_to_include"] = (
+                    torch.concat(
+                        [
+                            batch_motor_time_to_event_to_include,
+                            torch.full((padded_length,), False),
+                        ],
+                        dim=0,
+                    ).to(torch.bool)
+                ).reshape((batch_size, -1))
+
+                # Motor time indicators that indicate whether there are neither clinical events nor censor events
+                batch_motor_time_indicators = torch.concat(
+                    batch_motor_time_indicators, dim=0
+                ).to(torch.bool)
+                batch["motor_time_indicators"] = (
+                    torch.concat(
+                        [
+                            batch_motor_time_indicators,
+                            torch.full(
+                                (padded_length, num_time_pieces, motor_tte_vocab_size),
+                                False,
+                            ),
+                        ],
+                        dim=0,
+                    )
+                    .reshape((batch_size, -1, num_time_pieces, motor_tte_vocab_size))
+                    .to(torch.bool)
+                )
+
+                batch["motor_end_index"] = torch.concat(
+                    [
+                        torch.full((length, 1), 1, dtype=torch.int32),
+                        torch.full((padded_length, 1), 0, dtype=torch.int32),
+                    ]
+                ).reshape((batch_size, -1))
+
         if self.include_values:
             batch_value_indicators = [
                 self._try_reverse_tensor(
-                    self._convert_to_tensor(example["value_indicators"])
+                    self._convert_to_tensor(example["value_indicators"]).to(torch.bool)
                 )
                 for example in examples
             ]
@@ -178,7 +346,6 @@ class CehrGptDataCollator:
                 self._try_reverse_tensor(self._convert_to_tensor(example["values"]))
                 for example in examples
             ]
-
             batch["value_indicators"] = self._try_reverse_tensor(
                 pad_sequence(
                     batch_value_indicators, batch_first=True, padding_value=False
@@ -200,44 +367,248 @@ class CehrGptDataCollator:
                 batch["value_indicators"], batch["values"].clone(), -100
             )
 
+        bz = len(examples)
         if "person_id" in examples[0]:
-            batch["person_id"] =
-
-
-
-
-
-
+            batch["person_id"] = (
+                torch.cat(
+                    [
+                        self._convert_to_tensor(example["person_id"]).reshape(-1, 1)
+                        for example in examples
+                    ],
+                    dim=0,
+                )
+                .to(torch.int32)
+                .reshape(bz, -1)
+            )
 
         if "index_date" in examples[0]:
             batch["index_date"] = torch.cat(
                 [
-
+                    torch.tensor(example["index_date"], dtype=torch.float64).reshape(
+                        -1, 1
+                    )
                    for example in examples
                 ],
                 dim=0,
-            ).
+            ).reshape(bz, -1)
 
         if "age_at_index" in examples[0]:
-            batch["age_at_index"] =
-
-
-
-
-
-
+            batch["age_at_index"] = (
+                torch.cat(
+                    [
+                        self._convert_to_tensor(example["age_at_index"]).reshape(-1, 1)
+                        for example in examples
+                    ],
+                    dim=0,
+                )
+                .to(torch.float32)
+                .reshape(bz, -1)
+            )
 
         if "classifier_label" in examples[0]:
-            batch["classifier_label"] =
-
-
-
-
-
-
+            batch["classifier_label"] = (
+                torch.cat(
+                    [
+                        self._convert_to_tensor(example["classifier_label"]).reshape(
+                            -1, 1
+                        )
+                        for example in examples
+                    ],
+                    dim=0,
+                )
+                .to(torch.float32)
+                .reshape(bz, -1)
+            )
 
         return batch
 
+    def create_time_to_event_labels(self, record: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Generates time-to-event (TTE) labels and censoring indicators for each visit in a patient's timeline.
+
+        Processes the input sequence in reverse to compute the number of days from each visit (marked by [VE])
+        to the occurrence of future motor-related events.
+
+        Args:
+            record (Dict[str, Any]): A dictionary containing the encoded patient sequence with the key "input_ids".
+                This sequence includes [VS], [VE], time delta tokens, and motor TTE concept codes.
+
+        Returns:
+            Dict[str, Any]: The updated input record with added keys:
+                - "time_to_event_vectors": np.ndarray of shape [num_visits, motor_vocab_size], containing time-to-event values
+                - "event_indicators": np.ndarray of shape [num_visits, motor_vocab_size], where 0 = event occurred, 1 = censored
+        """
+        input_ids = record["input_ids"]
+        sample_packing = getattr(self, "sample_packing", False)
+
+        if isinstance(input_ids, torch.Tensor):
+            input_ids = input_ids.detach().tolist()
+
+        # This potentially contains packed samples, we need to handle that
+        packed_concept_ids = self.tokenizer.decode(input_ids, skip_special_tokens=False)
+        pad_indices = []
+        if sample_packing:
+            # We start from the first index
+            for i in range(len(packed_concept_ids)):
+                if packed_concept_ids[i] == self.tokenizer.pad_token:
+                    # If we encounter consecutive pads, we should break out of the loop
+                    if pad_indices and pad_indices[-1] == self.tokenizer.pad_token:
+                        break
+                    pad_indices.append(i)
+
+        # If we did not find a pad, that means the whole sequence belongs to one sample
+        if len(pad_indices) == 0:
+            pad_indices.append(len(packed_concept_ids))
+
+        timepiece_time_to_event_vectors = []
+        timepiece_event_indicators = []
+        timepiece_indicators = []
+        time_to_event_to_includes = []
+
+        for start_index, end_index in zip([0] + pad_indices[:-1], pad_indices):
+            concept_ids = packed_concept_ids[start_index:end_index]
+            if concept_ids[0] == self.tokenizer.pad_token:
+                concept_ids.pop(0)
+            time_to_event_vectors = []
+            global_event_indicators = []
+
+            # First collect TTE data in reverse chronological order
+            censor_times = []
+            time_to_event_data: List[Dict[str, int]] = []
+            time_to_event_dict: Dict[str, int] = {}
+            time_to_event_to_include: List[bool] = []
+            next_future_visit_concepts = set()
+            time_interval = 0
+
+            # Reverse walk through concept_ids to calculate TTE from each [VE] point
+            for concept_id in reversed(concept_ids):
+                if is_visit_end(concept_id):
+                    # Update TTE for existing concepts, or add new ones seen in this visit
+                    for existing_concept_id in list(time_to_event_dict.keys()):
+                        if existing_concept_id in next_future_visit_concepts:
+                            time_to_event_dict[existing_concept_id] = time_interval
+                        else:
+                            time_to_event_dict[existing_concept_id] += time_interval
+
+                    for next_concept_id in next_future_visit_concepts:
+                        if next_concept_id not in time_to_event_dict:
+                            time_to_event_dict[next_concept_id] = time_interval
+
+                    # If the next visit occurs on the same day as the previous one, we don't want to do TTE for the
+                    # previous visit
+                    time_to_event_to_include.append(time_interval > 0)
+                    time_to_event_data.append(copy.deepcopy(time_to_event_dict))
+                    # Record the censor time at the end of the visit
+                    if censor_times:
+                        censor_times.append(censor_times[-1] + time_interval)
+                    else:
+                        censor_times.append(time_interval)
+                    time_interval = 0
+                    next_future_visit_concepts.clear()
+
+                elif is_att_token(concept_id):
+                    time_interval += extract_time_interval_in_days(concept_id)
+
+                elif self.tokenizer.is_motor_time_to_event_code(concept_id):
+                    next_future_visit_concepts.add(concept_id)
+
+            if len(time_to_event_data) == 0:
+                LOG.info(
+                    "Vist end event is not detected for this sample, and is skipped for MOTOR tasks."
+                    "It's likely this sample contains a long admission. length: %s, concept_ids[-10:] %s",
+                    len(concept_ids),
+                    concept_ids[-10:],
+                )
+                continue
+
+            # Reverse back to chronological order for final labels
+            time_to_event_data.reverse()
+            censor_times.reverse()
+            time_to_event_to_include.reverse()
+
+            for censor_time, visit_tte_data in zip(censor_times, time_to_event_data):
+                time_to_event_vector = np.full(
+                    self.tokenizer.motor_tte_vocab_size,
+                    fill_value=censor_time,
+                    dtype=np.int32,
+                )
+                event_indicator = np.zeros(
+                    self.tokenizer.motor_tte_vocab_size,
+                    dtype=np.int32,
+                )
+                visit_token_ids = [
+                    self.tokenizer.get_motor_token_id(concept_id)
+                    for concept_id in visit_tte_data.keys()
+                ]
+                visit_tte_values = list(visit_tte_data.values())
+
+                time_to_event_vector[visit_token_ids] = visit_tte_values
+                event_indicator[visit_token_ids] = 1  # not censored (event occurred)
+
+                time_to_event_vectors.append(time_to_event_vector)
+                global_event_indicators.append(event_indicator)
+
+            time_to_event_vectors = np.asarray(time_to_event_vectors)
+            global_event_indicators = np.asarray(global_event_indicators).astype(bool)
+            n_visits = len(time_to_event_vectors)
+
+            timepiece_time_to_event_vector = np.full(
+                (
+                    self.motor_num_time_pieces,
+                    n_visits,
+                    self.tokenizer.motor_tte_vocab_size,
+                ),
+                fill_value=0,
+                dtype=np.int32,
+            )
+            timepiece_event_indicator = np.zeros(
+                (
+                    self.motor_num_time_pieces,
+                    n_visits,
+                    self.tokenizer.motor_tte_vocab_size,
+                ),
+                dtype=bool,
+            )
+            timepiece_indicator = np.zeros(
+                (
+                    self.motor_num_time_pieces,
+                    n_visits,
+                    self.tokenizer.motor_tte_vocab_size,
+                ),
+                dtype=bool,
+            )
+
+            # Putting the event time and censor time into the corresponding time bins
+            for bin_num in range(self.motor_num_time_pieces):
+                start = self.motor_time_interval * bin_num
+                end = self.motor_time_interval * (bin_num + 1)
+                time_in_bin = np.clip(time_to_event_vectors - start, 0, end - start)
+                timepiece_time_to_event_vector[bin_num] = time_in_bin
+                event_indicator = (
+                    global_event_indicators
+                    & (start <= time_to_event_vectors)
+                    & (time_to_event_vectors < end)
+                )
+                timepiece_event_indicator[bin_num] = event_indicator
+                timepiece_indicator[bin_num] = time_in_bin > 0 | event_indicator
+
+            timepiece_time_to_event_vectors.append(
+                timepiece_time_to_event_vector.swapaxes(0, 1)
+            )
+            timepiece_event_indicators.append(timepiece_event_indicator.swapaxes(0, 1))
+            timepiece_indicators.append(timepiece_indicator.swapaxes(0, 1))
+            time_to_event_to_includes.append(np.asarray(time_to_event_to_include))
+
+        record["time_to_event_vectors"] = np.concatenate(
+            timepiece_time_to_event_vectors, axis=0
+        )
+        record["event_indicators"] = np.concatenate(timepiece_event_indicators, axis=0)
+        record["time_indicators"] = np.concatenate(timepiece_indicators, axis=0)
+        record["time_to_event_to_include"] = np.concatenate(
+            time_to_event_to_includes, axis=0
+        )
+        return record
+
     def random_sort(self, record: Dict[str, Any]) -> Dict[str, Any]:
 
         if not self.shuffle_records:
@@ -273,53 +644,82 @@ class CehrGptDataCollator:
         record["input_ids"] = self._convert_to_tensor(sorted_input_ids)
         return record
 
-    def generate_start_end_index(
+    def generate_start_end_index(
+        self,
+        record: Dict[str, Any],
+        sample_packing: bool,
+        max_length_allowed: Optional[int] = None,
+    ) -> Dict[str, Any]:
         """Adding the start and end indices to extract a portion of the patient sequence."""
         # concept_ids will be used to for time to event predictions and identifying the visit starts
+        max_length_allowed = (
+            self.max_length if max_length_allowed is None else max_length_allowed
+        )
         input_ids = record["input_ids"]
         if isinstance(input_ids, torch.Tensor):
             input_ids = input_ids.detach().tolist()
         concept_ids = self.tokenizer.decode(input_ids, skip_special_tokens=False)
         seq_length = len(record["input_ids"])
-
+
+        # Subtract one for the [END] token when sample_packing is not enabled
+        new_max_length = (
+            max_length_allowed if sample_packing else max_length_allowed - 1
+        )
+
+        if self.include_ttv_prediction:
+            record["time_to_visits"] = torch.concat(
+                [self._convert_to_tensor(self._convert_time_to_event(concept_ids))]
+            )
+
+        # If linear token exists, we will use it, otherwise we default to the OOV token
+        linear_token_id = (
+            self.tokenizer.linear_token_id
+            if self.tokenizer.linear_token_id
+            else self.tokenizer.oov_token_id
+        )
+        eos_token = (
+            linear_token_id
+            if self.add_linear_prob_token
+            else self.tokenizer.end_token_id
+        )
 
         # Return the record directly if the actual sequence length is less than the max sequence
         if seq_length <= new_max_length:
-
-                [
-                    self._convert_to_tensor(record["input_ids"]),
-                    self._convert_to_tensor([self.tokenizer.end_token_id]),
-                ]
-            )
-            if self.include_values:
-                record["value_indicators"] = torch.concat(
-                    [
-                        self._convert_to_tensor(record["value_indicators"]),
-                        self._convert_to_tensor([False]),
-                    ]
-                ).to(torch.bool)
-                record["values"] = torch.concat(
-                    [
-                        self._convert_to_tensor(record["values"]),
-                        self._convert_to_tensor([self.tokenizer.pad_value_token_id]),
-                    ]
-                )
-            if self.include_ttv_prediction:
-                record["time_to_visits"] = torch.concat(
+            if not sample_packing:
+                record["input_ids"] = torch.concat(
                     [
-                        self._convert_to_tensor(
-
-                        ),
-                        self._convert_to_tensor([-100.0]),
+                        self._convert_to_tensor(record["input_ids"]),
+                        self._convert_to_tensor([eos_token]),
                     ]
                 )
-
+                if self.include_values:
+                    record["value_indicators"] = torch.concat(
+                        [
+                            self._convert_to_tensor(record["value_indicators"]),
+                            self._convert_to_tensor([False]),
+                        ]
+                    ).to(torch.bool)
+                    record["values"] = torch.concat(
+                        [
+                            self._convert_to_tensor(record["values"]),
+                            self._convert_to_tensor(
+                                [self.tokenizer.pad_value_token_id]
+                            ),
+                        ]
+                    )
+                if self.include_ttv_prediction:
+                    record["time_to_visits"] = torch.concat(
+                        [
+                            record["time_to_visits"],
+                            self._convert_to_tensor([-100.0]),
+                        ]
+                    )
            return record
 
         if self.pretraining:
             # There is a 50% chance we randomly slice out a portion of the patient history and update the demographic
             # prompt depending on the new starting point
-            if random.random() < 0.5:
+            if random.random() < 0.5 and not sample_packing:
                 start_index, end_index, demographic_tokens = random_slice_gpt_sequence(
                     concept_ids, new_max_length
                 )
@@ -347,10 +747,19 @@ class CehrGptDataCollator:
            for i in reversed(list(range(0, end_index))):
                 current_token = record["input_ids"][i]
                 if current_token == self.ve_token_id:
-
+                    # Plus one because slicing is right exclusive
+                    end_index = i + 1
                     break
 
            record["input_ids"] = record["input_ids"][0:end_index]
+
+            # We want to make sure we take the subset of attention_mask in sample packing if this field is available
+            if sample_packing and "attention_mask" in record:
+                record["attention_mask"] = record["attention_mask"][0:end_index]
+
+            if sample_packing and "position_ids" in record:
+                record["position_ids"] = record["position_ids"][0:end_index]
+
            if self.include_values:
                 record["value_indicators"] = self._convert_to_tensor(
                     record["value_indicators"][0:end_index]
@@ -364,7 +773,7 @@ class CehrGptDataCollator:
                 )
             return record
         else:
-            if self.include_demographics:
+            if self.include_demographics and not sample_packing:
                 # We employ a left truncation strategy, where the most recent patient history is reserved for fine-tuning
                 demographic_prompts_at_visits = collect_demographic_prompts_at_visits(
                     concept_ids
@@ -427,6 +836,12 @@ class CehrGptDataCollator:
                     current_token = record["input_ids"][i]
                     if current_token == self.vs_token_id:
                         record["input_ids"] = record["input_ids"][i:end_index]
+                        if sample_packing and "attention_mask" in record:
+                            record["attention_mask"] = record["attention_mask"][
+                                i:end_index
+                            ]
+                        if sample_packing and "position_ids" in record:
+                            record["position_ids"] = record["position_ids"][i:end_index]
                         if self.include_values:
                             record["value_indicators"] = record["value_indicators"][
                                 i:end_index
@@ -442,6 +857,12 @@ class CehrGptDataCollator:
             # We simply take the last new_max_length number of tokens from the patient sequence
             if len(record["input_ids"]) > new_max_length:
                 record["input_ids"] = record["input_ids"][-new_max_length:]
+                if sample_packing and "attention_mask" in record:
+                    record["attention_mask"] = record["attention_mask"][
+                        -new_max_length:
+                    ]
+                if sample_packing and "position_ids" in record:
+                    record["position_ids"] = record["position_ids"][-new_max_length:]
                 if self.include_values:
                     record["value_indicators"] = record["value_indicators"][
                         -new_max_length:
@@ -452,31 +873,148 @@ class CehrGptDataCollator:
                         -new_max_length:
                     ]
 
-
-
-                [
-                    self._convert_to_tensor(record["input_ids"]),
-                    self._convert_to_tensor([self.tokenizer.end_token_id]),
-                ]
-            )
-            if self.include_values:
-                record["value_indicators"] = torch.concat(
+            if not sample_packing:
+                # Finally we add the end token to the end of the sequence
+                record["input_ids"] = torch.concat(
                     [
-                        self._convert_to_tensor(record["
-                        self._convert_to_tensor([
-                    ]
-                ).to(torch.bool)
-                record["values"] = torch.concat(
-                    [
-                        self._convert_to_tensor(record["values"]),
-                        self._convert_to_tensor([self.tokenizer.pad_value_token_id]),
-                    ]
-                )
-            if self.include_ttv_prediction:
-                record["time_to_visits"] = torch.concat(
-                    [
-                        record["time_to_visits"],
-                        self._convert_to_tensor([-100.0]),
+                        self._convert_to_tensor(record["input_ids"]),
+                        self._convert_to_tensor([eos_token]),
                     ]
                 )
+                if self.include_values:
+                    record["value_indicators"] = torch.concat(
+                        [
+                            self._convert_to_tensor(record["value_indicators"]),
+                            self._convert_to_tensor([False]),
+                        ]
+                    ).to(torch.bool)
+                    record["values"] = torch.concat(
+                        [
+                            self._convert_to_tensor(record["values"]),
+                            self._convert_to_tensor(
+                                [self.tokenizer.pad_value_token_id]
+                            ),
+                        ]
+                    )
+                if self.include_ttv_prediction:
+                    record["time_to_visits"] = torch.concat(
+                        [
+                            record["time_to_visits"],
+                            self._convert_to_tensor([-100.0]),
+                        ]
+                    )
         return record
+
+
+class SamplePackingCehrGptDataCollator(CehrGptDataCollator):
+    def __init__(self, max_tokens, max_position_embeddings, *args, **kwargs):
+        self.max_tokens_per_batch = max_tokens
+        self.max_position_embeddings = max_position_embeddings
+        self.sample_packing = True
+        self.add_end_token_in_sample_packing = kwargs.pop(
+            "add_end_token_in_sample_packing", False
+        )
+        super(SamplePackingCehrGptDataCollator, self).__init__(*args, **kwargs)
+
+    def __call__(self, examples):
+        current_input_ids = []
+        current_attention_mask = []
+        current_position_ids = []
+        current_value_indicators = []
+        current_values = []
+
+        # Demographics
+        current_person_ids = []
+        current_index_dates = []
+
+        # Binary classification inputs
+        current_ages = []
+        current_labels = []
+
+        for idx, example in enumerate(examples):
+
+            # We only add an end token if the patient sequence could fit in the entire context window
+            add_end_token = (
+                len(example["input_ids"]) <= self.max_position_embeddings
+                and self.add_end_token_in_sample_packing
+            )
+            # If the sample length exceeds the model's capacity, truncate this example
+            if len(example["input_ids"]) > self.max_position_embeddings:
+                example = self.generate_start_end_index(
+                    example, False, self.max_position_embeddings
+                )
+
+            add_eos_token = add_end_token | self.add_linear_prob_token
+            additional_tokens = []
+            if add_end_token:
+                additional_tokens.append(self.tokenizer.end_token_id)
+            elif self.add_linear_prob_token:
+                # Backward compatible
+                linear_prob_token_id = (
+                    self.tokenizer.linear_token_id
+                    if self.tokenizer.linear_token_id is not None
+                    else self.tokenizer.oov_token_id
+                )
+                additional_tokens.append(linear_prob_token_id)
+            additional_tokens.append(self.tokenizer.pad_token_id)
+            input_ids = example["input_ids"]
+            # We add [END] [PAD], we want to attend to [END], adding [END] is important for sequence generation.
+            # If the sequence length of the sequence is less than the context window, we add both [END][PAD], otherwise
+            # we only add [PAD] token to the end of the sequence because it's not finished
+            current_input_ids.extend(list(input_ids) + additional_tokens)
+            current_attention_mask.extend(
+                np.ones_like(input_ids).tolist() + ([1, 0] if add_eos_token else [0])
+            )
+            num_tokens_to_pad = 1 + int(add_eos_token)
+            current_position_ids.extend(
+                np.clip(
+                    list(range(len(input_ids) + num_tokens_to_pad)),
+                    0,
+                    self.max_position_embeddings - 1,
+                )
+            )
+            if self.include_values:
+                current_value_indicators.extend(
+                    list(example["value_indicators"]) + [False] * num_tokens_to_pad
+                )
+                current_values.extend(
+                    list(example["values"])
+                    + [self.tokenizer.pad_value_token_id] * num_tokens_to_pad
+                )
+
+            if "person_id" in example:
+                current_person_ids.append(example["person_id"])
+
+            if "index_date" in example:
+                current_index_dates.append(example["index_date"])
+
+            if "age_at_index" in example:
+                current_ages.append(example["age_at_index"])
+
+            if "classifier_label" in example:
+                current_labels.append(example["classifier_label"])
+
+        assert len(current_input_ids) <= self.max_tokens_per_batch, (
+            f"The total number of tokens in the packed sequence should be less than {self.max_tokens_per_batch}\n"
+            f"But the total number of tokens is: {len(current_input_ids)}"
+        )
+        packed_example = {
+            "input_ids": current_input_ids,
+            "attention_mask": current_attention_mask,
+            "position_ids": current_position_ids,
+        }
+        if self.include_values:
+            packed_example.update({"value_indicators": current_value_indicators})
+            packed_example.update({"values": current_values})
+
+        if current_labels:
+            packed_example.update(
+                {
+                    "person_id": current_person_ids,
+                    "index_date": current_index_dates,
+                    "age_at_index": current_ages,
+                    "classifier_label": current_labels,
+                }
+            )
+
+        return super().__call__([packed_example])