cehrgpt 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/htn_treatment_pathway.py +546 -0
- cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
- cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
- cehrgpt/data/cehrgpt_data_processor.py +549 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +285 -652
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +38 -5
- cehrgpt/generation/cehrgpt_conditional_generation.py +2 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +20 -12
- cehrgpt/generation/omop_converter_batch.py +11 -4
- cehrgpt/gpt_utils.py +73 -3
- cehrgpt/models/activations.py +27 -0
- cehrgpt/models/config.py +6 -2
- cehrgpt/models/gpt2.py +560 -0
- cehrgpt/models/hf_cehrgpt.py +183 -460
- cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
- cehrgpt/omop/ontology.py +154 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -78
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -34
- cehrgpt/runners/hyperparameter_search_util.py +180 -69
- cehrgpt/runners/sample_packing_trainer.py +11 -2
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +8 -2
- cehrgpt-0.1.4.dist-info/METADATA +238 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/RECORD +32 -22
- cehrgpt-0.1.2.dist-info/METADATA +0 -209
- /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/WHEEL +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/top_level.txt +0 -0
cehrgpt/data/hf_cehrgpt_dataset_collator.py
@@ -1,27 +1,13 @@
-import copy
-import random
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
 
 import numpy as np
 import torch
 from torch.nn.utils.rnn import pad_sequence
 from transformers.utils import logging
 
-from cehrgpt.gpt_utils import (
-    DEMOGRAPHIC_PROMPT_SIZE,
-    collect_demographic_prompts_at_visits,
-    extract_time_interval_in_days,
-    extract_time_interval_in_hours,
-    is_att_token,
-    is_inpatient_att_token,
-    is_inpatient_hour_token,
-    is_visit_end,
-    random_slice_gpt_sequence,
-)
+from cehrgpt.data.cehrgpt_data_processor import CehrGptDataProcessor
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
 
-TIME_TO_EVENT_MAX_TIME = 3650
-INPATIENT_STAY_DURATION_LIMIT = 30
 LOG = logging.get_logger("transformers")
 
 
@@ -30,13 +16,14 @@ class CehrGptDataCollator:
         self,
         tokenizer: CehrGptTokenizer,
         max_length: int,
-        shuffle_records: bool = False,
         include_values: bool = False,
+        shuffle_records: bool = False,
         include_ttv_prediction: bool = False,
         use_sub_time_tokenization: bool = False,
         include_motor_time_to_event: bool = False,
         motor_tte_vocab_size: int = 0,
         motor_num_time_pieces: int = 8,
+        motor_sampling_probability: float = 0.5,
         pretraining: bool = True,
         include_demographics: bool = False,
         add_linear_prob_token: bool = False,
@@ -47,13 +34,12 @@
         self.vs_token_id = tokenizer.vs_token_id
         self.ve_token_id = tokenizer.ve_token_id
 
-        self.shuffle_records = shuffle_records
         self.include_values = include_values
         self.include_ttv_prediction = include_ttv_prediction
         self.use_sub_time_tokenization = use_sub_time_tokenization
         self.pretraining = pretraining
         self.include_demographics = include_demographics
-        self.[…]
+        self.motor_code_cache: Dict[str, List[str]] = dict()
 
         # MOTOR TTE configuration
         if include_motor_time_to_event:
@@ -66,8 +52,14 @@
         self.include_motor_time_to_event = include_motor_time_to_event
         self.motor_tte_vocab_size = motor_tte_vocab_size
         self.motor_num_time_pieces = motor_num_time_pieces
-        self.[…]
-        […]
+        self.motor_time_bins = (
+            self.tokenizer.get_motor_time_bins(motor_num_time_pieces)
+            if self.include_motor_time_to_event
+            else []
+        )
+        # Convert the time bins to seconds
+        self.motor_time_bins = [time_bin * 86400 for time_bin in self.motor_time_bins]
+        LOG.info("self.motor_time_bins: %s", self.motor_time_bins)
         if self.use_sub_time_tokenization:
             token_to_time_token_mapping = tokenizer.token_to_time_token_mapping
             if not token_to_time_token_mapping:
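In 0.1.2 the horizon was carved into fixed-width pieces via the deleted module-level constants; in 0.1.4 the bin edges come from the trained tokenizer and are converted from days to seconds. A minimal sketch of that conversion follows; `get_motor_time_bins` here is a hypothetical stand-in for the tokenizer method, assumed to return `num_time_pieces + 1` edges (the binning code later slices `bins[:-1]` and `bins[1:]`).

```python
import numpy as np

# Hypothetical stand-in for CehrGptTokenizer.get_motor_time_bins: the real
# edges come from the trained tokenizer. Log-spaced day-valued edges over a
# ~10-year horizon are one plausible shape for them.
def get_motor_time_bins(num_time_pieces: int, max_days: float = 3650.0) -> list:
    edges = np.logspace(0.0, np.log10(max_days), num=num_time_pieces)
    return [0.0] + edges.tolist()  # num_time_pieces + 1 edges, in days

bins_days = get_motor_time_bins(8)
# Mirror the constructor above: day-valued edges -> seconds.
bins_seconds = [time_bin * 86400 for time_bin in bins_days]
assert len(bins_seconds) == 8 + 1
```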
@@ -83,6 +75,18 @@
                 list(token_to_time_token_mapping.values()), dtype=torch.int64
             )
 
+        self.cehrgpt_data_processor = CehrGptDataProcessor(
+            tokenizer=tokenizer,
+            max_length=self.max_length,
+            shuffle_records=shuffle_records,
+            include_ttv_prediction=include_ttv_prediction,
+            include_values=include_values,
+            include_motor_time_to_event=include_motor_time_to_event,
+            motor_sampling_probability=motor_sampling_probability,
+            pretraining=pretraining,
+            add_linear_prob_token=add_linear_prob_token,
+        )
+
     def _try_reverse_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
         if not self.pretraining:
             return torch.flip(tensor, dims=[-1])
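The collator now builds a `CehrGptDataProcessor` once in its constructor and delegates per-example work (record shuffling, truncation, MOTOR label preparation) to it. A usage sketch, with the caveat that the tokenizer-loading call is an assumed Hugging Face-style API, not something this diff confirms:

```python
# Sketch only: argument names are taken from the constructor diff above;
# the tokenizer-loading call is an assumption.
from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer

tokenizer = CehrGptTokenizer.from_pretrained("path/to/tokenizer")  # assumed API
collator = CehrGptDataCollator(
    tokenizer=tokenizer,
    max_length=1024,
    include_motor_time_to_event=True,
    motor_tte_vocab_size=tokenizer.motor_tte_vocab_size,
    motor_num_time_pieces=8,
    motor_sampling_probability=0.5,  # new in 0.1.4
)
# __call__ now runs cehrgpt_data_processor.transform on each example before
# padding, so slicing and shuffling no longer happen inside the collator itself.
```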
@@ -95,30 +99,120 @@
         else:
             return torch.tensor(features)
 
-[… 19 deleted lines not captured in the extracted diff …]
+    def create_time_to_event_tensors_ultra_optimized(
+        self, record: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """Ultra-optimized version using advanced vectorization techniques."""
+        motor_row_indices = record["motor_row_indices"]
+        motor_col_indices = record["motor_col_indices"]
+        motor_values = record["motor_values"]
+        motor_censor_times = record["motor_censor_times"]
+
+        if len(motor_row_indices) == 0:
+            # Handle empty case - use tuples for better performance
+            empty_shape = (
+                0,
+                self.motor_num_time_pieces,
+                self.tokenizer.motor_tte_vocab_size,
+            )
+            record["motor_tte_times"] = np.zeros(empty_shape, dtype=np.float32)
+            record["motor_tte_event_indicators"] = np.zeros(empty_shape, dtype=bool)
+            record["motor_tte_masks"] = np.zeros(empty_shape, dtype=bool)
+            return record
+
+        # Convert to numpy arrays once and get dimensions
+        motor_row_indices = np.asarray(motor_row_indices, dtype=np.int32)
+        motor_col_indices = np.asarray(motor_col_indices, dtype=np.int32)
+        motor_values = np.asarray(motor_values, dtype=np.float32)
+        motor_censor_times = np.asarray(motor_censor_times, dtype=np.float32)
+
+        n_tte_predictions = len(motor_censor_times)  # More direct than unique()
+        vocab_size = self.tokenizer.motor_tte_vocab_size
+        n_time_pieces = self.motor_num_time_pieces
+
+        # Create time_vectors more efficiently without broadcasting copy
+        time_vectors = np.tile(
+            motor_censor_times[:, np.newaxis], (1, vocab_size)
+        ).astype(np.float32)
+        event_indicators = np.zeros((n_tte_predictions, vocab_size), dtype=bool)
+
+        # Vectorized assignment (already optimal)
+        time_vectors[motor_row_indices, motor_col_indices] = motor_values
+        event_indicators[motor_row_indices, motor_col_indices] = True
+
+        # Early return if no predictions
+        if n_tte_predictions == 0:
+            empty_shape = (0, n_time_pieces, vocab_size)
+            record["motor_tte_times"] = np.zeros(empty_shape, dtype=np.float32)
+            record["motor_tte_event_indicators"] = np.zeros(empty_shape, dtype=bool)
+            record["motor_tte_masks"] = np.zeros(empty_shape, dtype=bool)
+            return record
+
+        # Cache motor_time_bins as numpy array to avoid repeated conversion
+        if not hasattr(self, "_motor_time_bins_array"):
+            self._motor_time_bins_array = np.asarray(
+                self.motor_time_bins, dtype=np.float32
+            )
+
+        motor_time_bins = self._motor_time_bins_array
+        start_times = motor_time_bins[:-1]
+        end_times = motor_time_bins[1:]
+        bin_widths = end_times - start_times  # Pre-compute bin widths
+
+        # ELIMINATED TRANSPOSE: Compute directly in target shape (n_pred, n_bins, vocab)
+        # Reshape for broadcasting in target order
+        time_vectors_3d = time_vectors[:, np.newaxis, :]  # (n_pred, 1, vocab)
+        event_indicators_3d = event_indicators[:, np.newaxis, :]  # (n_pred, 1, vocab)
+
+        # Broadcast time bins to match target shape
+        start_times_broadcast = start_times[np.newaxis, :, np.newaxis]  # (1, n_bins, 1)
+        bin_widths_broadcast = bin_widths[np.newaxis, :, np.newaxis]  # (1, n_bins, 1)
+
+        # Compute directly in target shape (n_pred, n_bins, vocab)
+        time_diff = time_vectors_3d - start_times_broadcast
+        time_in_bin = np.clip(time_diff, 0, bin_widths_broadcast)
+
+        # Optimized mask computation
+        mask = time_in_bin > 0
+
+        # More efficient log computation with better constant
+        log_constant = 1e-8  # Better numerical stability than 1e-10
+        time_in_bin_log = np.where(
+            mask, np.log2(np.maximum(time_in_bin, log_constant)), -np.inf
+        )
+
+        # Event indicator computation in target shape
+        end_times_broadcast = motor_time_bins[1:][np.newaxis, :, np.newaxis]
+        time_in_range = (time_vectors_3d >= start_times_broadcast) & (
+            time_vectors_3d < end_times_broadcast
+        )
+        event_in_bin = event_indicators_3d & time_in_range
+
+        # Combined mask computation
+        final_mask = mask | event_in_bin
+
+        # Direct assignment - NO TRANSPOSE NEEDED!
+        record["motor_tte_times"] = time_in_bin_log
+        record["motor_tte_event_indicators"] = event_in_bin
+        record["motor_tte_masks"] = final_mask
+
+        # Validation (keep as is - important for correctness)
+        assert (
+            sum(record["motor_tte_task_indicators"]) == n_tte_predictions
+        ), f'sum(record["motor_tte_task_indicators"]) == n_tte_predictions must be true'
+
+        # Clean up input data
+        del record["motor_row_indices"]
+        del record["motor_col_indices"]
+        del record["motor_values"]
+
+        return record
 
     def __call__(self, examples):
-[… 3 deleted lines not captured in the extracted diff …]
+
+        if not getattr(self, "sample_packing", False):
+            examples = [self.cehrgpt_data_processor.transform(_) for _ in examples]
+
         batch = {}
 
         # Assume that each example in the batch is a dictionary with 'input_ids' and 'attention_mask'
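The new method replaces the old per-visit Python loop with one broadcast: each example arrives with a sparse COO-style encoding (`motor_row_indices`, `motor_col_indices`, `motor_values` for observed events, plus one censor time per prediction row) and leaves with dense `(n_predictions, n_time_pieces, vocab)` tensors. The standalone sketch below reproduces the core trick at toy sizes (2 prediction rows, 3 motor codes, 2 time pieces); it omits the log transform for brevity.

```python
import numpy as np

vocab_size, n_pred = 3, 2
motor_censor_times = np.array([5.0, 9.0], dtype=np.float32)  # one per prediction row
motor_row_indices = np.array([0, 1])  # which prediction row each event belongs to
motor_col_indices = np.array([2, 0])  # which motor code each event is
motor_values = np.array([3.0, 7.0], dtype=np.float32)  # observed event times
time_bins = np.array([0.0, 4.0, 10.0], dtype=np.float32)  # 2 pieces -> 3 edges

# Dense (n_pred, vocab): default every code to the censor time, then overwrite
# the observed events via fancy indexing -- the "vectorized assignment" above.
time_vectors = np.tile(motor_censor_times[:, None], (1, vocab_size))
event_indicators = np.zeros((n_pred, vocab_size), dtype=bool)
time_vectors[motor_row_indices, motor_col_indices] = motor_values
event_indicators[motor_row_indices, motor_col_indices] = True

# Broadcast to (n_pred, n_bins, vocab): time spent inside each bin, clipped
# to the bin width, exactly as in the diff above.
starts = time_bins[:-1][None, :, None]
widths = (time_bins[1:] - time_bins[:-1])[None, :, None]
time_in_bin = np.clip(time_vectors[:, None, :] - starts, 0, widths)
in_range = (time_vectors[:, None, :] >= starts) & (
    time_vectors[:, None, :] < starts + widths
)
event_in_bin = event_indicators[:, None, :] & in_range
mask = (time_in_bin > 0) | event_in_bin
print(time_in_bin.shape)  # (2, 2, 3)
```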
@@ -162,37 +256,31 @@
                 f"batch['input_ids']: {batch['input_ids']} "
             )
 
-[… 12 deleted lines not captured in the extracted diff …]
-                    padding_value=0,
-                ).to(torch.float32)
-            )
+        batch_ages = [
+            self._try_reverse_tensor(self._convert_to_tensor(example["ages"]))
+            for example in examples
+        ]
+        # Pad sequences to the max length in the batch
+        batch["ages"] = self._try_reverse_tensor(
+            pad_sequence(
+                batch_ages,
+                batch_first=True,
+                padding_value=0,
+            ).to(torch.int64)
+        )
 
-[… 12 deleted lines not captured in the extracted diff …]
-                    padding_value=0,
-                ).to(torch.int64)
-            )
+        batch_epoch_times = [
+            self._try_reverse_tensor(self._convert_to_tensor(example["epoch_times"]))
+            for example in examples
+        ]
+        # Pad sequences to the max length in the batch
+        batch["epoch_times"] = self._try_reverse_tensor(
+            pad_sequence(
+                batch_epoch_times,
+                batch_first=True,
+                padding_value=0,
+            ).to(torch.float32)
+        )
 
         if self.pretraining:
             batch["labels"] = torch.where(
@@ -233,54 +321,51 @@
 
         if self.include_motor_time_to_event:
             examples_with_motor_tte = [
-                self.create_time_to_event_labels(_) for _ in examples
+                self.create_time_to_event_tensors_ultra_optimized(_) for _ in examples
             ]
-            batch_motor_time_to_event_vectors = [
+            # print(f"Creating MOTOR TTE tensors took {time.time() - start} seconds")
+            motor_tte_times = [
                 self._try_reverse_tensor(
-                    self._convert_to_tensor(example["time_to_event_vectors"])
+                    self._convert_to_tensor(example["motor_tte_times"])
                 )
                 for example in examples_with_motor_tte
             ]
-            batch_motor_event_indicators = [
+            motor_tte_event_indicators = [
                 self._try_reverse_tensor(
-                    self._convert_to_tensor(example["event_indicators"])
+                    self._convert_to_tensor(example["motor_tte_event_indicators"])
                 )
                 for example in examples_with_motor_tte
             ]
-            batch_motor_time_to_event_to_include = [
+            motor_tte_task_indicators = [
                 self._try_reverse_tensor(
-                    self._convert_to_tensor(example["time_to_event_to_include"])
+                    self._convert_to_tensor(example["motor_tte_task_indicators"])
                 )
                 for example in examples_with_motor_tte
             ]
-            batch_motor_time_indicators = [
+            motor_tte_masks = [
                 self._try_reverse_tensor(
-                    self._convert_to_tensor(example["time_indicators"])
+                    self._convert_to_tensor(example["motor_tte_masks"])
                 )
                 for example in examples_with_motor_tte
             ]
 
-            batch_motor_time_to_event_vectors = torch.concat(
-                batch_motor_time_to_event_vectors, dim=0
-            ).to(torch.float32)
+            motor_tte_times = torch.concat(motor_tte_times, dim=0).to(torch.float32)
 
             # If every example in the batch only contains one visit, there would be no labels generated for MOTOR TTE
             # we only create the labels when any example has more than one visit
-            if batch_motor_time_to_event_vectors.dim() <= 1:
+            if motor_tte_times.dim() <= 1:
                 LOG.warning(
                     "There are no MOTOR TTE labels generated for this batch "
                     "because every example in this batch only contains one visit."
                 )
             else:
                 batch_size = len(examples)
-                length, num_time_pieces, motor_tte_vocab_size = (
-                    batch_motor_time_to_event_vectors.shape
-                )
+                length, num_time_pieces, motor_tte_vocab_size = motor_tte_times.shape
                 padded_length = batch_size - length % batch_size
-                batch["motor_time_to_event_vectors"] = (
+                batch["motor_tte_times"] = (
                     torch.concat(
                         [
-                            batch_motor_time_to_event_vectors,
+                            motor_tte_times,
                             torch.full(
                                 (padded_length, num_time_pieces, motor_tte_vocab_size),
                                 0.0,
@@ -293,13 +378,12 @@
                 )
 
                 # Motor event indicators that indicate there is an event occurred in this time interval
-                batch_motor_event_indicators = torch.concat(
-                    batch_motor_event_indicators, dim=0
-                ).to(torch.bool)
-                batch["motor_event_indicators"] = (
+                batch["motor_tte_event_indicators"] = (
                     torch.concat(
                         [
-                            batch_motor_event_indicators,
+                            torch.concat(motor_tte_event_indicators, dim=0).to(
+                                torch.bool
+                            ),
                             torch.full(
                                 (padded_length, num_time_pieces, motor_tte_vocab_size),
                                 False,
@@ -312,27 +396,17 @@
                 )
 
                 # Input to indicate whether the visit should be included for TTE predictions
-                […]
-                […]
+                batch["motor_tte_task_indicators"] = pad_sequence(
+                    motor_tte_task_indicators,
+                    batch_first=True,
+                    padding_value=False,
                 ).to(torch.bool)
-                batch["motor_time_to_event_to_include"] = (
-                    torch.concat(
-                        [
-                            batch_motor_time_to_event_to_include,
-                            torch.full((padded_length,), False),
-                        ],
-                        dim=0,
-                    ).to(torch.bool)
-                ).reshape((batch_size, -1))
 
                 # Motor time indicators that indicate whether there are neither clinical events nor censor events
-                batch_motor_time_indicators = torch.concat(
-                    batch_motor_time_indicators, dim=0
-                ).to(torch.bool)
-                batch["motor_time_indicators"] = (
+                batch["motor_tte_masks"] = (
                     torch.concat(
                         [
-                            batch_motor_time_indicators,
+                            torch.concat(motor_tte_masks, dim=0).to(torch.bool),
                             torch.full(
                                 (padded_length, num_time_pieces, motor_tte_vocab_size),
                                 False,
@@ -438,572 +512,118 @@
 
         return batch
 
-    def create_time_to_event_labels(self, record: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Generates time-to-event (TTE) labels and censoring indicators for each visit in a patient's timeline.
-
-        Processes the input sequence in reverse to compute the number of days from each visit (marked by [VE])
-        to the occurrence of future motor-related events.
-
-        Args:
-            record (Dict[str, Any]): A dictionary containing the encoded patient sequence with the key "input_ids".
-                This sequence includes [VS], [VE], time delta tokens, and motor TTE concept codes.
-
-        Returns:
-            Dict[str, Any]: The updated input record with added keys:
-                - "time_to_event_vectors": np.ndarray of shape [num_visits, motor_vocab_size], containing time-to-event values
-                - "event_indicators": np.ndarray of shape [num_visits, motor_vocab_size], where 0 = event occurred, 1 = censored
-        """
-        input_ids = record["input_ids"]
-        sample_packing = getattr(self, "sample_packing", False)
-
-        if isinstance(input_ids, torch.Tensor):
-            input_ids = input_ids.detach().tolist()
-
-        # This potentially contains packed samples, we need to handle that
-        packed_concept_ids = self.tokenizer.decode(input_ids, skip_special_tokens=False)
-        pad_indices = []
-        if sample_packing:
-            # We start from the first index
-            for i in range(len(packed_concept_ids)):
-                if packed_concept_ids[i] == self.tokenizer.pad_token:
-                    # If we encounter consecutive pads, we should break out of the loop
-                    if pad_indices and pad_indices[-1] == self.tokenizer.pad_token:
-                        break
-                    pad_indices.append(i)
-
-        # If we did not find a pad, that means the whole sequence belongs to one sample
-        if len(pad_indices) == 0:
-            pad_indices.append(len(packed_concept_ids))
-
-        timepiece_time_to_event_vectors = []
-        timepiece_event_indicators = []
-        timepiece_indicators = []
-        time_to_event_to_includes = []
-
-        for start_index, end_index in zip([0] + pad_indices[:-1], pad_indices):
-            concept_ids = packed_concept_ids[start_index:end_index]
-            if concept_ids[0] == self.tokenizer.pad_token:
-                concept_ids.pop(0)
-            time_to_event_vectors = []
-            global_event_indicators = []
-
-            # First collect TTE data in reverse chronological order
-            censor_times = []
-            time_to_event_data: List[Dict[str, int]] = []
-            time_to_event_dict: Dict[str, int] = {}
-            time_to_event_to_include: List[bool] = []
-            next_future_visit_concepts = set()
-            time_interval = 0
-
-            # Reverse walk through concept_ids to calculate TTE from each [VE] point
-            for concept_id in reversed(concept_ids):
-                if is_visit_end(concept_id):
-                    # Update TTE for existing concepts, or add new ones seen in this visit
-                    for existing_concept_id in list(time_to_event_dict.keys()):
-                        if existing_concept_id in next_future_visit_concepts:
-                            time_to_event_dict[existing_concept_id] = time_interval
-                        else:
-                            time_to_event_dict[existing_concept_id] += time_interval
-
-                    for next_concept_id in next_future_visit_concepts:
-                        if next_concept_id not in time_to_event_dict:
-                            time_to_event_dict[next_concept_id] = time_interval
-
-                    # If the next visit occurs on the same day as the previous one, we don't want to do TTE for the
-                    # previous visit
-                    time_to_event_to_include.append(time_interval > 0)
-                    time_to_event_data.append(copy.deepcopy(time_to_event_dict))
-                    # Record the censor time at the end of the visit
-                    if censor_times:
-                        censor_times.append(censor_times[-1] + time_interval)
-                    else:
-                        censor_times.append(time_interval)
-                    time_interval = 0
-                    next_future_visit_concepts.clear()
-
-                elif is_att_token(concept_id):
-                    time_interval += extract_time_interval_in_days(concept_id)
-
-                elif self.tokenizer.is_motor_time_to_event_code(concept_id):
-                    next_future_visit_concepts.add(concept_id)
-
-            if len(time_to_event_data) == 0:
-                LOG.info(
-                    "Vist end event is not detected for this sample, and is skipped for MOTOR tasks."
-                    "It's likely this sample contains a long admission. length: %s, concept_ids[-10:] %s",
-                    len(concept_ids),
-                    concept_ids[-10:],
-                )
-                continue
-
-            # Reverse back to chronological order for final labels
-            time_to_event_data.reverse()
-            censor_times.reverse()
-            time_to_event_to_include.reverse()
-
-            for censor_time, visit_tte_data in zip(censor_times, time_to_event_data):
-                time_to_event_vector = np.full(
-                    self.tokenizer.motor_tte_vocab_size,
-                    fill_value=censor_time,
-                    dtype=np.int32,
-                )
-                event_indicator = np.zeros(
-                    self.tokenizer.motor_tte_vocab_size,
-                    dtype=np.int32,
-                )
-                visit_token_ids = [
-                    self.tokenizer.get_motor_token_id(concept_id)
-                    for concept_id in visit_tte_data.keys()
-                ]
-                visit_tte_values = list(visit_tte_data.values())
-
-                time_to_event_vector[visit_token_ids] = visit_tte_values
-                event_indicator[visit_token_ids] = 1  # not censored (event occurred)
-
-                time_to_event_vectors.append(time_to_event_vector)
-                global_event_indicators.append(event_indicator)
-
-            time_to_event_vectors = np.asarray(time_to_event_vectors)
-            global_event_indicators = np.asarray(global_event_indicators).astype(bool)
-            n_visits = len(time_to_event_vectors)
-
-            timepiece_time_to_event_vector = np.full(
-                (
-                    self.motor_num_time_pieces,
-                    n_visits,
-                    self.tokenizer.motor_tte_vocab_size,
-                ),
-                fill_value=0,
-                dtype=np.int32,
-            )
-            timepiece_event_indicator = np.zeros(
-                (
-                    self.motor_num_time_pieces,
-                    n_visits,
-                    self.tokenizer.motor_tte_vocab_size,
-                ),
-                dtype=bool,
-            )
-            timepiece_indicator = np.zeros(
-                (
-                    self.motor_num_time_pieces,
-                    n_visits,
-                    self.tokenizer.motor_tte_vocab_size,
-                ),
-                dtype=bool,
-            )
-
-            # Putting the event time and censor time into the corresponding time bins
-            for bin_num in range(self.motor_num_time_pieces):
-                start = self.motor_time_interval * bin_num
-                end = self.motor_time_interval * (bin_num + 1)
-                time_in_bin = np.clip(time_to_event_vectors - start, 0, end - start)
-                timepiece_time_to_event_vector[bin_num] = time_in_bin
-                event_indicator = (
-                    global_event_indicators
-                    & (start <= time_to_event_vectors)
-                    & (time_to_event_vectors < end)
-                )
-                timepiece_event_indicator[bin_num] = event_indicator
-                timepiece_indicator[bin_num] = time_in_bin > 0 | event_indicator
-
-            timepiece_time_to_event_vectors.append(
-                timepiece_time_to_event_vector.swapaxes(0, 1)
-            )
-            timepiece_event_indicators.append(timepiece_event_indicator.swapaxes(0, 1))
-            timepiece_indicators.append(timepiece_indicator.swapaxes(0, 1))
-            time_to_event_to_includes.append(np.asarray(time_to_event_to_include))
-
-        record["time_to_event_vectors"] = np.concatenate(
-            timepiece_time_to_event_vectors, axis=0
-        )
-        record["event_indicators"] = np.concatenate(timepiece_event_indicators, axis=0)
-        record["time_indicators"] = np.concatenate(timepiece_indicators, axis=0)
-        record["time_to_event_to_include"] = np.concatenate(
-            time_to_event_to_includes, axis=0
-        )
-        return record
-
-    def random_sort(self, record: Dict[str, Any]) -> Dict[str, Any]:
-
-        if not self.shuffle_records:
-            return record
-
-        if "record_ranks" not in record:
-            return record
-
-        sorting_column = record["record_ranks"]
-        random_order = np.random.rand(len(sorting_column))
-
-        if self.include_values:
-            iterator = zip(
-                sorting_column,
-                random_order,
-                record["input_ids"],
-                record["value_indicators"],
-                record["values"],
-            )
-            sorted_list = sorted(iterator, key=lambda tup2: (tup2[0], tup2[1], tup2[2]))
-            _, _, sorted_input_ids, sorted_value_indicators, sorted_values = zip(
-                *list(sorted_list)
-            )
-            record["input_ids"] = self._convert_to_tensor(sorted_input_ids)
-            record["value_indicators"] = self._convert_to_tensor(
-                sorted_value_indicators
-            )
-            record["values"] = self._convert_to_tensor(sorted_values)
-        else:
-            iterator = zip(sorting_column, random_order, record["input_ids"])
-            sorted_list = sorted(iterator, key=lambda tup2: (tup2[0], tup2[1], tup2[2]))
-            _, _, sorted_input_ids = zip(*list(sorted_list))
-            record["input_ids"] = self._convert_to_tensor(sorted_input_ids)
-        return record
-
-    def generate_start_end_index(
-        self,
-        record: Dict[str, Any],
-        sample_packing: bool,
-        max_length_allowed: Optional[int] = None,
-    ) -> Dict[str, Any]:
-        """Adding the start and end indices to extract a portion of the patient sequence."""
-        # concept_ids will be used to for time to event predictions and identifying the visit starts
-        max_length_allowed = (
-            self.max_length if max_length_allowed is None else max_length_allowed
-        )
-        input_ids = record["input_ids"]
-        if isinstance(input_ids, torch.Tensor):
-            input_ids = input_ids.detach().tolist()
-        concept_ids = self.tokenizer.decode(input_ids, skip_special_tokens=False)
-        seq_length = len(record["input_ids"])
-
-        # Subtract one for the [END] token when sample_packing is not enabled
-        new_max_length = (
-            max_length_allowed - 1
-            if not sample_packing and self.pretraining
-            else max_length_allowed
-        )
-
-        if self.include_ttv_prediction:
-            record["time_to_visits"] = torch.concat(
-                [self._convert_to_tensor(self._convert_time_to_event(concept_ids))]
-            )
-
-        # If linear token exists, we will use it, otherwise we default to the OOV token
-        linear_token_id = (
-            self.tokenizer.linear_token_id
-            if self.tokenizer.linear_token_id
-            else self.tokenizer.oov_token_id
-        )
-        eos_token = (
-            linear_token_id
-            if self.add_linear_prob_token
-            else self.tokenizer.end_token_id
-        )
-
-        # Return the record directly if the actual sequence length is less than the max sequence
-        if seq_length <= new_max_length:
-            if not sample_packing and self.pretraining:
-                record["input_ids"] = torch.concat(
-                    [
-                        self._convert_to_tensor(record["input_ids"]),
-                        self._convert_to_tensor([eos_token]),
-                    ]
-                )
-                if "epoch_times" in record:
-                    record["epoch_times"] = torch.concat(
-                        [
-                            self._convert_to_tensor(record["epoch_times"]),
-                            self._convert_to_tensor([record["epoch_times"][-1]]),
-                        ]
-                    )
-                if self.include_values:
-                    record["value_indicators"] = torch.concat(
-                        [
-                            self._convert_to_tensor(record["value_indicators"]),
-                            self._convert_to_tensor([False]),
-                        ]
-                    ).to(torch.bool)
-                    record["values"] = torch.concat(
-                        [
-                            self._convert_to_tensor(record["values"]),
-                            self._convert_to_tensor(
-                                [self.tokenizer.pad_value_token_id]
-                            ),
-                        ]
-                    )
-                if self.include_ttv_prediction:
-                    record["time_to_visits"] = torch.concat(
-                        [
-                            record["time_to_visits"],
-                            self._convert_to_tensor([-100.0]),
-                        ]
-                    )
-            return record
-
-        if self.pretraining:
-            # There is a 50% chance we randomly slice out a portion of the patient history and update the demographic
-            # prompt depending on the new starting point
-            if random.random() < 0.5 and not sample_packing:
-                start_index, end_index, demographic_tokens = random_slice_gpt_sequence(
-                    concept_ids, new_max_length
-                )
-                if start_index != end_index:
-                    record["input_ids"] = self._convert_to_tensor(
-                        record["input_ids"][start_index : end_index + 1]
-                    )
-                    if "epoch_times" in record:
-                        record["epoch_times"] = self._convert_to_tensor(
-                            record["epoch_times"][start_index : end_index + 1]
-                        )
-                    if self.include_values:
-                        record["value_indicators"] = self._convert_to_tensor(
-                            record["value_indicators"][start_index : end_index + 1]
-                        ).to(torch.bool)
-                        record["values"] = self._convert_to_tensor(
-                            record["values"][start_index : end_index + 1]
-                        )
-                    if self.include_ttv_prediction:
-                        record["time_to_visits"] = self._convert_to_tensor(
-                            self._convert_time_to_event(
-                                concept_ids[start_index : end_index + 1]
-                            )
-                        )
-                    return record
-
-            # The default employs a right truncation strategy, where the demographic prompt is reserved
-            end_index = new_max_length
-            for i in reversed(list(range(0, end_index))):
-                current_token = record["input_ids"][i]
-                if current_token == self.ve_token_id:
-                    # Plus one because slicing is right exclusive
-                    end_index = i + 1
-                    break
-
-            record["input_ids"] = record["input_ids"][0:end_index]
-
-            # We want to make sure we take the subset of attention_mask in sample packing if this field is available
-            if sample_packing and "attention_mask" in record:
-                record["attention_mask"] = record["attention_mask"][0:end_index]
-
-            if sample_packing and "position_ids" in record:
-                record["position_ids"] = record["position_ids"][0:end_index]
-
-            if "epoch_times" in record:
-                record["epoch_times"] = self._convert_to_tensor(
-                    record["epoch_times"][0:end_index]
-                )
-
-            if self.include_values:
-                record["value_indicators"] = self._convert_to_tensor(
-                    record["value_indicators"][0:end_index]
-                ).to(torch.bool)
-                record["values"] = self._convert_to_tensor(
-                    record["values"][0:end_index]
-                )
-            if self.include_ttv_prediction:
-                record["time_to_visits"] = self._convert_to_tensor(
-                    self._convert_time_to_event(concept_ids[0:end_index])
-                )
-            return record
-        else:
-            if self.include_demographics and not sample_packing:
-                # We employ a left truncation strategy, where the most recent patient history is reserved for fine-tuning
-                demographic_prompts_at_visits = collect_demographic_prompts_at_visits(
-                    concept_ids
-                )
-                for token_index, demographic_prompt in demographic_prompts_at_visits:
-                    if (
-                        seq_length - token_index
-                        <= new_max_length - DEMOGRAPHIC_PROMPT_SIZE
-                    ):
-                        demographic_tokens = self.tokenizer.encode(demographic_prompt)
-                        record["input_ids"] = torch.concat(
-                            [
-                                self._convert_to_tensor(demographic_tokens),
-                                self._convert_to_tensor(
-                                    record["input_ids"][token_index:seq_length]
-                                ),
-                            ]
-                        )
-                        if "epoch_times" in record:
-                            record["epoch_times"] = torch.concat(
-                                [
-                                    torch.zeros(
-                                        [record["epoch_times"][0]], dtype=torch.float32
-                                    ),
-                                    self._convert_to_tensor(
-                                        record["epoch_times"][token_index:seq_length]
-                                    ),
-                                ]
-                            )
-                        if self.include_values:
-                            record["value_indicators"] = torch.concat(
-                                [
-                                    torch.zeros(
-                                        [DEMOGRAPHIC_PROMPT_SIZE], dtype=torch.int32
-                                    ).to(torch.bool),
-                                    self._convert_to_tensor(
-                                        record["value_indicators"][
-                                            token_index:seq_length
-                                        ]
-                                    ),
-                                ]
-                            )
-                            record["values"] = torch.concat(
-                                [
-                                    torch.zeros(
-                                        [DEMOGRAPHIC_PROMPT_SIZE], dtype=torch.int32
-                                    )
-                                    .to(torch.int32)
-                                    .fill_(self.tokenizer.pad_value_token_id),
-                                    self._convert_to_tensor(
-                                        record["values"][token_index:seq_length]
-                                    ),
-                                ]
-                            )
-                        if self.include_ttv_prediction:
-                            record["time_to_visits"] = torch.concat(
-                                [
-                                    torch.zeros(
-                                        [DEMOGRAPHIC_PROMPT_SIZE], dtype=torch.int32
-                                    )
-                                    .to(torch.float32)
-                                    .fill_(-100.0),
-                                    record["time_to_visits"][token_index:seq_length],
-                                ]
-                            )
-                        break
-            else:
-                start_index = max(seq_length - new_max_length, 0)
-                end_index = seq_length
-                for i in range(start_index, end_index):
-                    current_token = record["input_ids"][i]
-                    if current_token == self.vs_token_id:
-                        record["input_ids"] = record["input_ids"][i:end_index]
-                        if sample_packing and "attention_mask" in record:
-                            record["attention_mask"] = record["attention_mask"][
-                                i:end_index
-                            ]
-                        if sample_packing and "position_ids" in record:
-                            record["position_ids"] = record["position_ids"][i:end_index]
-
-                        if "epoch_times" in record:
-                            record["epoch_times"] = self._convert_to_tensor(
-                                record["epoch_times"][i:end_index]
-                            )
-                        if self.include_values:
-                            record["value_indicators"] = record["value_indicators"][
-                                i:end_index
-                            ]
-                            record["values"] = record["values"][i:end_index]
-                        if self.include_ttv_prediction:
-                            record["time_to_visits"] = record["time_to_visits"][
-                                i:end_index
-                            ]
-                        break
-
-            # This could happen when the last visit contains more than new_max_length number of tokens
-            # We simply take the last new_max_length number of tokens from the patient sequence
-            if len(record["input_ids"]) > new_max_length:
-                record["input_ids"] = record["input_ids"][-new_max_length:]
-                if sample_packing and "attention_mask" in record:
-                    record["attention_mask"] = record["attention_mask"][
-                        -new_max_length:
-                    ]
-                if sample_packing and "position_ids" in record:
-                    record["position_ids"] = record["position_ids"][-new_max_length:]
-                if "epoch_times" in record:
-                    record["epoch_times"] = self._convert_to_tensor(
-                        record["epoch_times"][-new_max_length:]
-                    )
-                if self.include_values:
-                    record["value_indicators"] = record["value_indicators"][
-                        -new_max_length:
-                    ]
-                    record["values"] = record["values"][-new_max_length:]
-                if self.include_ttv_prediction:
-                    record["time_to_visits"] = record["time_to_visits"][
-                        -new_max_length:
-                    ]
-
-        return record
-
 
 class SamplePackingCehrGptDataCollator(CehrGptDataCollator):
     def __init__(self, max_tokens, max_position_embeddings, *args, **kwargs):
         self.max_tokens_per_batch = max_tokens
         self.max_position_embeddings = max_position_embeddings
         self.sample_packing = True
-        self.add_end_token_in_sample_packing = kwargs.pop(
-            "add_end_token_in_sample_packing", False
-        )
         super(SamplePackingCehrGptDataCollator, self).__init__(*args, **kwargs)
+        self.cehrgpt_data_processor.max_length = self.max_position_embeddings
 
     def __call__(self, examples):
         current_input_ids = []
         current_attention_mask = []
-
+        current_ages = []
+        current_epoch_times = []
         current_value_indicators = []
         current_values = []
 
+        # MOTOR inputs
+        current_motor_censor_times = []
+        current_motor_row_indices = []
+        current_motor_col_indices = []
+        current_motor_values = []
+        current_motor_tte_task_indicators = []
+
         # Demographics
         current_person_ids = []
         current_index_dates = []
 
         # Binary classification inputs
-        […]
+        current_prediction_ages = []
         current_labels = []
 
         for idx, example in enumerate(examples):
-
-            # We only add an end token if the patient sequence could fit in the entire context window
-            add_end_token = (
-                len(example["input_ids"]) <= self.max_position_embeddings
-                and self.add_end_token_in_sample_packing
-            )
-            # If the sample length exceeds the model's capacity, truncate this example
-            if len(example["input_ids"]) > self.max_position_embeddings:
-                example = self.generate_start_end_index(
-                    example, False, self.max_position_embeddings
-                )
-
-            add_eos_token = add_end_token | self.add_linear_prob_token
-            additional_tokens = []
-            if add_end_token:
-                additional_tokens.append(self.tokenizer.end_token_id)
-            elif self.add_linear_prob_token:
-                # Backward compatible
-                linear_prob_token_id = (
-                    self.tokenizer.linear_token_id
-                    if self.tokenizer.linear_token_id is not None
-                    else self.tokenizer.oov_token_id
-                )
-                additional_tokens.append(linear_prob_token_id)
-            additional_tokens.append(self.tokenizer.pad_token_id)
+            example = self.cehrgpt_data_processor.transform(example)
             input_ids = example["input_ids"]
             # We add [END] [PAD], we want to attend to [END], adding [END] is important for sequence generation.
             # If the sequence length of the sequence is less than the context window, we add both [END][PAD], otherwise
             # we only add [PAD] token to the end of the sequence because it's not finished
-            current_input_ids.extend(list(input_ids) + additional_tokens)
-            current_attention_mask.extend(
-                […]
+            current_input_ids.extend(list(input_ids) + [self.tokenizer.pad_token_id])
+            current_attention_mask.extend(np.ones_like(input_ids).tolist() + [0])
+
+            ages = (
+                example["ages"].tolist()
+                if isinstance(example["ages"], torch.Tensor)
+                else list(example["ages"])
             )
-[… 6 deleted lines not captured in the extracted diff …]
-            )
+            current_ages.extend(ages + [max(ages)])
+
+            epoch_times = (
+                example["epoch_times"].tolist()
+                if isinstance(example["epoch_times"], torch.Tensor)
+                else list(example["epoch_times"])
             )
+            current_epoch_times.extend(epoch_times + [max(epoch_times)])
+
             if self.include_values:
                 current_value_indicators.extend(
-                    […]
+                    (
+                        example["value_indicators"].tolist()
+                        if isinstance(example["value_indicators"], torch.Tensor)
+                        else list(example["value_indicators"])
+                    )
+                    + [False]
                 )
                 current_values.extend(
-                    […]
-                    […]
+                    (
+                        example["values"].tolist()
+                        if isinstance(example["values"], torch.Tensor)
+                        else list(example["values"])
+                    )
+                    + [self.tokenizer.pad_value_token_id]
+                )
+
+            if self.include_motor_time_to_event:
+                current_max_motor_row_index = len(np.unique(current_motor_row_indices))
+                motor_row_indices = (
+                    example["motor_row_indices"].tolist()
+                    if isinstance(example["motor_row_indices"], torch.Tensor)
+                    else list(example["motor_row_indices"])
+                )
+                current_motor_row_indices.extend(
+                    list(
+                        map(
+                            lambda offset: offset + current_max_motor_row_index,
+                            motor_row_indices,
+                        )
+                    )
+                )
+                current_motor_col_indices.extend(
+                    example["motor_col_indices"].tolist()
+                    if isinstance(example["motor_col_indices"], torch.Tensor)
+                    else list(example["motor_col_indices"])
+                )
+                current_motor_values.extend(
+                    example["motor_values"].tolist()
+                    if isinstance(example["motor_values"], torch.Tensor)
+                    else list(example["motor_values"])
+                )
+                current_motor_censor_times.extend(
+                    example["motor_censor_times"].tolist()
+                    if isinstance(example["motor_censor_times"], torch.Tensor)
+                    else list(example["motor_censor_times"])
+                )
+                current_motor_tte_task_indicators.extend(
+                    (
+                        example["motor_tte_task_indicators"].tolist()
+                        if isinstance(
+                            example["motor_tte_task_indicators"], torch.Tensor
+                        )
+                        else list(example["motor_tte_task_indicators"])
+                    )
+                    + [False]
                 )
 
             if "person_id" in example:
@@ -1013,7 +633,7 @@ class SamplePackingCehrGptDataCollator(CehrGptDataCollator):
             current_index_dates.append(example["index_date"])
 
             if "age_at_index" in example:
-                […]
+                current_prediction_ages.append(example["age_at_index"])
 
             if "classifier_label" in example:
                 current_labels.append(example["classifier_label"])
@@ -1025,20 +645,33 @@
         packed_example = {
             "input_ids": current_input_ids,
             "attention_mask": current_attention_mask,
-            "[…]
+            "ages": current_ages,
+            "epoch_times": current_epoch_times,
         }
+
         if self.include_values:
-            packed_example.update(
-                […]
+            packed_example.update(
+                {"value_indicators": current_value_indicators, "values": current_values}
+            )
+        if self.include_motor_time_to_event:
+            packed_example.update(
+                {
+                    "motor_censor_times": current_motor_censor_times,
+                    "motor_row_indices": current_motor_row_indices,
+                    "motor_col_indices": current_motor_col_indices,
+                    "motor_values": current_motor_values,
+                    "motor_tte_task_indicators": current_motor_tte_task_indicators,
+                }
+            )
 
         if current_labels:
             packed_example.update(
                 {
                     "person_id": current_person_ids,
                     "index_date": current_index_dates,
-                    "age_at_index": […]
+                    "age_at_index": current_prediction_ages,
                     "classifier_label": current_labels,
                 }
             )
-
+        # print(f"Packing examples took {time.time() - start} seconds")
         return super().__call__([packed_example])
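One packing subtlety worth calling out: `motor_row_indices` are local to each example (row 0 is that example's first TTE prediction), so when examples are packed, the incoming indices are shifted by the number of rows already accumulated, computed in the diff as `len(np.unique(...))` over the packed indices. A minimal sketch:

```python
import numpy as np

packed_rows = [0, 0, 1]     # two distinct TTE rows already packed
incoming_rows = [0, 1, 1]   # local row ids from the next example

offset = len(np.unique(packed_rows))  # == 2, as in the collator
packed_rows.extend(r + offset for r in incoming_rows)

print(packed_rows)  # [0, 0, 1, 2, 3, 3]
```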