cehrgpt 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/htn_treatment_pathway.py +546 -0
- cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
- cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
- cehrgpt/data/cehrgpt_data_processor.py +549 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +286 -629
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +60 -14
- cehrgpt/generation/cehrgpt_conditional_generation.py +316 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +35 -15
- cehrgpt/generation/omop_converter_batch.py +11 -4
- cehrgpt/gpt_utils.py +73 -3
- cehrgpt/models/activations.py +27 -0
- cehrgpt/models/config.py +6 -2
- cehrgpt/models/gpt2.py +560 -0
- cehrgpt/models/hf_cehrgpt.py +193 -459
- cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
- cehrgpt/omop/ontology.py +154 -0
- cehrgpt/runners/data_utils.py +17 -6
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +33 -79
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +58 -34
- cehrgpt/runners/hyperparameter_search_util.py +180 -69
- cehrgpt/runners/sample_packing_trainer.py +11 -2
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +27 -31
- cehrgpt-0.1.3.dist-info/METADATA +238 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +33 -22
- cehrgpt-0.1.1.dist-info/METADATA +0 -115
- /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
cehrgpt/data/hf_cehrgpt_dataset_collator.py

@@ -1,27 +1,13 @@
-import
-import random
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List

 import numpy as np
 import torch
 from torch.nn.utils.rnn import pad_sequence
 from transformers.utils import logging

-from cehrgpt.
-    DEMOGRAPHIC_PROMPT_SIZE,
-    collect_demographic_prompts_at_visits,
-    extract_time_interval_in_days,
-    extract_time_interval_in_hours,
-    is_att_token,
-    is_inpatient_att_token,
-    is_inpatient_hour_token,
-    is_visit_end,
-    random_slice_gpt_sequence,
-)
+from cehrgpt.data.cehrgpt_data_processor import CehrGptDataProcessor
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer

-TIME_TO_EVENT_MAX_TIME = 3650
-INPATIENT_STAY_DURATION_LIMIT = 30
 LOG = logging.get_logger("transformers")


@@ -30,13 +16,14 @@ class CehrGptDataCollator:
         self,
         tokenizer: CehrGptTokenizer,
         max_length: int,
-        shuffle_records: bool = False,
         include_values: bool = False,
+        shuffle_records: bool = False,
         include_ttv_prediction: bool = False,
         use_sub_time_tokenization: bool = False,
         include_motor_time_to_event: bool = False,
         motor_tte_vocab_size: int = 0,
         motor_num_time_pieces: int = 8,
+        motor_sampling_probability: float = 0.5,
         pretraining: bool = True,
         include_demographics: bool = False,
         add_linear_prob_token: bool = False,
@@ -47,13 +34,12 @@ class CehrGptDataCollator:
         self.vs_token_id = tokenizer.vs_token_id
         self.ve_token_id = tokenizer.ve_token_id

-        self.shuffle_records = shuffle_records
         self.include_values = include_values
         self.include_ttv_prediction = include_ttv_prediction
         self.use_sub_time_tokenization = use_sub_time_tokenization
         self.pretraining = pretraining
         self.include_demographics = include_demographics
-        self.
+        self.motor_code_cache: Dict[str, List[str]] = dict()

         # MOTOR TTE configuration
         if include_motor_time_to_event:
@@ -66,8 +52,14 @@ class CehrGptDataCollator:
         self.include_motor_time_to_event = include_motor_time_to_event
         self.motor_tte_vocab_size = motor_tte_vocab_size
         self.motor_num_time_pieces = motor_num_time_pieces
-        self.
-
+        self.motor_time_bins = (
+            self.tokenizer.get_motor_time_bins(motor_num_time_pieces)
+            if self.include_motor_time_to_event
+            else []
+        )
+        # Convert the time bins to seconds
+        self.motor_time_bins = [time_bin * 86400 for time_bin in self.motor_time_bins]
+        LOG.info("self.motor_time_bins: %s", self.motor_time_bins)
         if self.use_sub_time_tokenization:
             token_to_time_token_mapping = tokenizer.token_to_time_token_mapping
             if not token_to_time_token_mapping:
@@ -83,6 +75,18 @@ class CehrGptDataCollator:
                 list(token_to_time_token_mapping.values()), dtype=torch.int64
             )

+        self.cehrgpt_data_processor = CehrGptDataProcessor(
+            tokenizer=tokenizer,
+            max_length=self.max_length,
+            shuffle_records=shuffle_records,
+            include_ttv_prediction=include_ttv_prediction,
+            include_values=include_values,
+            include_motor_time_to_event=include_motor_time_to_event,
+            motor_sampling_probability=motor_sampling_probability,
+            pretraining=pretraining,
+            add_linear_prob_token=add_linear_prob_token,
+        )
+
     def _try_reverse_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
         if not self.pretraining:
             return torch.flip(tensor, dims=[-1])
@@ -95,30 +99,120 @@ class CehrGptDataCollator:
         else:
             return torch.tensor(features)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def create_time_to_event_tensors_ultra_optimized(
+        self, record: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """Ultra-optimized version using advanced vectorization techniques."""
+        motor_row_indices = record["motor_row_indices"]
+        motor_col_indices = record["motor_col_indices"]
+        motor_values = record["motor_values"]
+        motor_censor_times = record["motor_censor_times"]
+
+        if len(motor_row_indices) == 0:
+            # Handle empty case - use tuples for better performance
+            empty_shape = (
+                0,
+                self.motor_num_time_pieces,
+                self.tokenizer.motor_tte_vocab_size,
+            )
+            record["motor_tte_times"] = np.zeros(empty_shape, dtype=np.float32)
+            record["motor_tte_event_indicators"] = np.zeros(empty_shape, dtype=bool)
+            record["motor_tte_masks"] = np.zeros(empty_shape, dtype=bool)
+            return record
+
+        # Convert to numpy arrays once and get dimensions
+        motor_row_indices = np.asarray(motor_row_indices, dtype=np.int32)
+        motor_col_indices = np.asarray(motor_col_indices, dtype=np.int32)
+        motor_values = np.asarray(motor_values, dtype=np.float32)
+        motor_censor_times = np.asarray(motor_censor_times, dtype=np.float32)
+
+        n_tte_predictions = len(motor_censor_times)  # More direct than unique()
+        vocab_size = self.tokenizer.motor_tte_vocab_size
+        n_time_pieces = self.motor_num_time_pieces
+
+        # Create time_vectors more efficiently without broadcasting copy
+        time_vectors = np.tile(
+            motor_censor_times[:, np.newaxis], (1, vocab_size)
+        ).astype(np.float32)
+        event_indicators = np.zeros((n_tte_predictions, vocab_size), dtype=bool)
+
+        # Vectorized assignment (already optimal)
+        time_vectors[motor_row_indices, motor_col_indices] = motor_values
+        event_indicators[motor_row_indices, motor_col_indices] = True
+
+        # Early return if no predictions
+        if n_tte_predictions == 0:
+            empty_shape = (0, n_time_pieces, vocab_size)
+            record["motor_tte_times"] = np.zeros(empty_shape, dtype=np.float32)
+            record["motor_tte_event_indicators"] = np.zeros(empty_shape, dtype=bool)
+            record["motor_tte_masks"] = np.zeros(empty_shape, dtype=bool)
+            return record
+
+        # Cache motor_time_bins as numpy array to avoid repeated conversion
+        if not hasattr(self, "_motor_time_bins_array"):
+            self._motor_time_bins_array = np.asarray(
+                self.motor_time_bins, dtype=np.float32
+            )
+
+        motor_time_bins = self._motor_time_bins_array
+        start_times = motor_time_bins[:-1]
+        end_times = motor_time_bins[1:]
+        bin_widths = end_times - start_times  # Pre-compute bin widths
+
+        # ELIMINATED TRANSPOSE: Compute directly in target shape (n_pred, n_bins, vocab)
+        # Reshape for broadcasting in target order
+        time_vectors_3d = time_vectors[:, np.newaxis, :]  # (n_pred, 1, vocab)
+        event_indicators_3d = event_indicators[:, np.newaxis, :]  # (n_pred, 1, vocab)
+
+        # Broadcast time bins to match target shape
+        start_times_broadcast = start_times[np.newaxis, :, np.newaxis]  # (1, n_bins, 1)
+        bin_widths_broadcast = bin_widths[np.newaxis, :, np.newaxis]  # (1, n_bins, 1)
+
+        # Compute directly in target shape (n_pred, n_bins, vocab)
+        time_diff = time_vectors_3d - start_times_broadcast
+        time_in_bin = np.clip(time_diff, 0, bin_widths_broadcast)
+
+        # Optimized mask computation
+        mask = time_in_bin > 0
+
+        # More efficient log computation with better constant
+        log_constant = 1e-8  # Better numerical stability than 1e-10
+        time_in_bin_log = np.where(
+            mask, np.log2(np.maximum(time_in_bin, log_constant)), -np.inf
+        )
+
+        # Event indicator computation in target shape
+        end_times_broadcast = motor_time_bins[1:][np.newaxis, :, np.newaxis]
+        time_in_range = (time_vectors_3d >= start_times_broadcast) & (
+            time_vectors_3d < end_times_broadcast
+        )
+        event_in_bin = event_indicators_3d & time_in_range
+
+        # Combined mask computation
+        final_mask = mask | event_in_bin
+
+        # Direct assignment - NO TRANSPOSE NEEDED!
+        record["motor_tte_times"] = time_in_bin_log
+        record["motor_tte_event_indicators"] = event_in_bin
+        record["motor_tte_masks"] = final_mask
+
+        # Validation (keep as is - important for correctness)
+        assert (
+            sum(record["motor_tte_task_indicators"]) == n_tte_predictions
+        ), f'sum(record["motor_tte_task_indicators"]) == n_tte_predictions must be true'
+
+        # Clean up input data
+        del record["motor_row_indices"]
+        del record["motor_col_indices"]
+        del record["motor_values"]
+
+        return record

     def __call__(self, examples):
-
-
-
+
+        if not getattr(self, "sample_packing", False):
+            examples = [self.cehrgpt_data_processor.transform(_) for _ in examples]
+
         batch = {}

         # Assume that each example in the batch is a dictionary with 'input_ids' and 'attention_mask'
@@ -162,21 +256,31 @@ class CehrGptDataCollator:
                 f"batch['input_ids']: {batch['input_ids']} "
             )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        )
+        batch_ages = [
+            self._try_reverse_tensor(self._convert_to_tensor(example["ages"]))
+            for example in examples
+        ]
+        # Pad sequences to the max length in the batch
+        batch["ages"] = self._try_reverse_tensor(
+            pad_sequence(
+                batch_ages,
+                batch_first=True,
+                padding_value=0,
+            ).to(torch.int64)
+        )
+
+        batch_epoch_times = [
+            self._try_reverse_tensor(self._convert_to_tensor(example["epoch_times"]))
+            for example in examples
+        ]
+        # Pad sequences to the max length in the batch
+        batch["epoch_times"] = self._try_reverse_tensor(
+            pad_sequence(
+                batch_epoch_times,
+                batch_first=True,
+                padding_value=0,
+            ).to(torch.float32)
+        )

         if self.pretraining:
             batch["labels"] = torch.where(
@@ -217,54 +321,51 @@ class CehrGptDataCollator:

         if self.include_motor_time_to_event:
             examples_with_motor_tte = [
-                self.
+                self.create_time_to_event_tensors_ultra_optimized(_) for _ in examples
             ]
-
+            # print(f"Creating MOTOR TTE tensors took {time.time() - start} seconds")
+            motor_tte_times = [
                 self._try_reverse_tensor(
-                    self._convert_to_tensor(example["
+                    self._convert_to_tensor(example["motor_tte_times"])
                 )
                 for example in examples_with_motor_tte
             ]
-
+            motor_tte_event_indicators = [
                 self._try_reverse_tensor(
-                    self._convert_to_tensor(example["
+                    self._convert_to_tensor(example["motor_tte_event_indicators"])
                 )
                 for example in examples_with_motor_tte
             ]
-
+            motor_tte_task_indicators = [
                 self._try_reverse_tensor(
-                    self._convert_to_tensor(example["
+                    self._convert_to_tensor(example["motor_tte_task_indicators"])
                 )
                 for example in examples_with_motor_tte
             ]
-
+            motor_tte_masks = [
                 self._try_reverse_tensor(
-                    self._convert_to_tensor(example["
+                    self._convert_to_tensor(example["motor_tte_masks"])
                 )
                 for example in examples_with_motor_tte
             ]

-
-                batch_motor_time_to_event_vectors, dim=0
-            ).to(torch.float32)
+            motor_tte_times = torch.concat(motor_tte_times, dim=0).to(torch.float32)

             # If every example in the batch only contains one visit, there would be no labels generated for MOTOR TTE
             # we only create the labels when any example has more than one visit
-            if
+            if motor_tte_times.dim() <= 1:
                 LOG.warning(
                     "There are no MOTOR TTE labels generated for this batch "
                     "because every example in this batch only contains one visit."
                 )
             else:
                 batch_size = len(examples)
-                length, num_time_pieces, motor_tte_vocab_size =
-                    batch_motor_time_to_event_vectors.shape
-                )
+                length, num_time_pieces, motor_tte_vocab_size = motor_tte_times.shape
                 padded_length = batch_size - length % batch_size
-                batch["
+                batch["motor_tte_times"] = (
                     torch.concat(
                         [
-
+                            motor_tte_times,
                             torch.full(
                                 (padded_length, num_time_pieces, motor_tte_vocab_size),
                                 0.0,
|
|
277
378
|
)
|
278
379
|
|
279
380
|
# Motor event indicators that indicate there is an event occurred in this time interval
|
280
|
-
|
281
|
-
batch_motor_event_indicators, dim=0
|
282
|
-
).to(torch.bool)
|
283
|
-
batch["motor_event_indicators"] = (
|
381
|
+
batch["motor_tte_event_indicators"] = (
|
284
382
|
torch.concat(
|
285
383
|
[
|
286
|
-
|
384
|
+
torch.concat(motor_tte_event_indicators, dim=0).to(
|
385
|
+
torch.bool
|
386
|
+
),
|
287
387
|
torch.full(
|
288
388
|
(padded_length, num_time_pieces, motor_tte_vocab_size),
|
289
389
|
False,
|
@@ -296,27 +396,17 @@ class CehrGptDataCollator:
|
|
296
396
|
)
|
297
397
|
|
298
398
|
# Input to indicate whether the visit should be included for TTE predictions
|
299
|
-
|
300
|
-
|
399
|
+
batch["motor_tte_task_indicators"] = pad_sequence(
|
400
|
+
motor_tte_task_indicators,
|
401
|
+
batch_first=True,
|
402
|
+
padding_value=False,
|
301
403
|
).to(torch.bool)
|
302
|
-
batch["motor_time_to_event_to_include"] = (
|
303
|
-
torch.concat(
|
304
|
-
[
|
305
|
-
batch_motor_time_to_event_to_include,
|
306
|
-
torch.full((padded_length,), False),
|
307
|
-
],
|
308
|
-
dim=0,
|
309
|
-
).to(torch.bool)
|
310
|
-
).reshape((batch_size, -1))
|
311
404
|
|
312
405
|
# Motor time indicators that indicate whether there are neither clinical events nor censor events
|
313
|
-
|
314
|
-
batch_motor_time_indicators, dim=0
|
315
|
-
).to(torch.bool)
|
316
|
-
batch["motor_time_indicators"] = (
|
406
|
+
batch["motor_tte_masks"] = (
|
317
407
|
torch.concat(
|
318
408
|
[
|
319
|
-
|
409
|
+
torch.concat(motor_tte_masks, dim=0).to(torch.bool),
|
320
410
|
torch.full(
|
321
411
|
(padded_length, num_time_pieces, motor_tte_vocab_size),
|
322
412
|
False,
|
@@ -422,564 +512,118 @@ class CehrGptDataCollator:

         return batch

-    def create_time_to_event_labels(self, record: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Generates time-to-event (TTE) labels and censoring indicators for each visit in a patient's timeline.
-
-        Processes the input sequence in reverse to compute the number of days from each visit (marked by [VE])
-        to the occurrence of future motor-related events.
-
-        Args:
-            record (Dict[str, Any]): A dictionary containing the encoded patient sequence with the key "input_ids".
-                This sequence includes [VS], [VE], time delta tokens, and motor TTE concept codes.
-
-        Returns:
-            Dict[str, Any]: The updated input record with added keys:
-                - "time_to_event_vectors": np.ndarray of shape [num_visits, motor_vocab_size], containing time-to-event values
-                - "event_indicators": np.ndarray of shape [num_visits, motor_vocab_size], where 0 = event occurred, 1 = censored
-        """
-        input_ids = record["input_ids"]
-        sample_packing = getattr(self, "sample_packing", False)
-
-        if isinstance(input_ids, torch.Tensor):
-            input_ids = input_ids.detach().tolist()
-
-        # This potentially contains packed samples, we need to handle that
-        packed_concept_ids = self.tokenizer.decode(input_ids, skip_special_tokens=False)
-        pad_indices = []
-        if sample_packing:
-            # We start from the first index
-            for i in range(len(packed_concept_ids)):
-                if packed_concept_ids[i] == self.tokenizer.pad_token:
-                    # If we encounter consecutive pads, we should break out of the loop
-                    if pad_indices and pad_indices[-1] == self.tokenizer.pad_token:
-                        break
-                    pad_indices.append(i)
-
-        # If we did not find a pad, that means the whole sequence belongs to one sample
-        if len(pad_indices) == 0:
-            pad_indices.append(len(packed_concept_ids))
-
-        timepiece_time_to_event_vectors = []
-        timepiece_event_indicators = []
-        timepiece_indicators = []
-        time_to_event_to_includes = []
-
-        for start_index, end_index in zip([0] + pad_indices[:-1], pad_indices):
-            concept_ids = packed_concept_ids[start_index:end_index]
-            if concept_ids[0] == self.tokenizer.pad_token:
-                concept_ids.pop(0)
-            time_to_event_vectors = []
-            global_event_indicators = []
-
-            # First collect TTE data in reverse chronological order
-            censor_times = []
-            time_to_event_data: List[Dict[str, int]] = []
-            time_to_event_dict: Dict[str, int] = {}
-            time_to_event_to_include: List[bool] = []
-            next_future_visit_concepts = set()
-            time_interval = 0
-
-            # Reverse walk through concept_ids to calculate TTE from each [VE] point
-            for concept_id in reversed(concept_ids):
-                if is_visit_end(concept_id):
-                    # Update TTE for existing concepts, or add new ones seen in this visit
-                    for existing_concept_id in list(time_to_event_dict.keys()):
-                        if existing_concept_id in next_future_visit_concepts:
-                            time_to_event_dict[existing_concept_id] = time_interval
-                        else:
-                            time_to_event_dict[existing_concept_id] += time_interval
-
-                    for next_concept_id in next_future_visit_concepts:
-                        if next_concept_id not in time_to_event_dict:
-                            time_to_event_dict[next_concept_id] = time_interval
-
-                    # If the next visit occurs on the same day as the previous one, we don't want to do TTE for the
-                    # previous visit
-                    time_to_event_to_include.append(time_interval > 0)
-                    time_to_event_data.append(copy.deepcopy(time_to_event_dict))
-                    # Record the censor time at the end of the visit
-                    if censor_times:
-                        censor_times.append(censor_times[-1] + time_interval)
-                    else:
-                        censor_times.append(time_interval)
-                    time_interval = 0
-                    next_future_visit_concepts.clear()
-
-                elif is_att_token(concept_id):
-                    time_interval += extract_time_interval_in_days(concept_id)
-
-                elif self.tokenizer.is_motor_time_to_event_code(concept_id):
-                    next_future_visit_concepts.add(concept_id)
-
-            if len(time_to_event_data) == 0:
-                LOG.info(
-                    "Vist end event is not detected for this sample, and is skipped for MOTOR tasks."
-                    "It's likely this sample contains a long admission. length: %s, concept_ids[-10:] %s",
-                    len(concept_ids),
-                    concept_ids[-10:],
-                )
-                continue
-
-            # Reverse back to chronological order for final labels
-            time_to_event_data.reverse()
-            censor_times.reverse()
-            time_to_event_to_include.reverse()
-
-            for censor_time, visit_tte_data in zip(censor_times, time_to_event_data):
-                time_to_event_vector = np.full(
-                    self.tokenizer.motor_tte_vocab_size,
-                    fill_value=censor_time,
-                    dtype=np.int32,
-                )
-                event_indicator = np.zeros(
-                    self.tokenizer.motor_tte_vocab_size,
-                    dtype=np.int32,
-                )
-                visit_token_ids = [
-                    self.tokenizer.get_motor_token_id(concept_id)
-                    for concept_id in visit_tte_data.keys()
-                ]
-                visit_tte_values = list(visit_tte_data.values())
-
-                time_to_event_vector[visit_token_ids] = visit_tte_values
-                event_indicator[visit_token_ids] = 1  # not censored (event occurred)
-
-                time_to_event_vectors.append(time_to_event_vector)
-                global_event_indicators.append(event_indicator)
-
-            time_to_event_vectors = np.asarray(time_to_event_vectors)
-            global_event_indicators = np.asarray(global_event_indicators).astype(bool)
-            n_visits = len(time_to_event_vectors)
-
-            timepiece_time_to_event_vector = np.full(
-                (
-                    self.motor_num_time_pieces,
-                    n_visits,
-                    self.tokenizer.motor_tte_vocab_size,
-                ),
-                fill_value=0,
-                dtype=np.int32,
-            )
-            timepiece_event_indicator = np.zeros(
-                (
-                    self.motor_num_time_pieces,
-                    n_visits,
-                    self.tokenizer.motor_tte_vocab_size,
-                ),
-                dtype=bool,
-            )
-            timepiece_indicator = np.zeros(
-                (
-                    self.motor_num_time_pieces,
-                    n_visits,
-                    self.tokenizer.motor_tte_vocab_size,
-                ),
-                dtype=bool,
-            )
-
-            # Putting the event time and censor time into the corresponding time bins
-            for bin_num in range(self.motor_num_time_pieces):
-                start = self.motor_time_interval * bin_num
-                end = self.motor_time_interval * (bin_num + 1)
-                time_in_bin = np.clip(time_to_event_vectors - start, 0, end - start)
-                timepiece_time_to_event_vector[bin_num] = time_in_bin
-                event_indicator = (
-                    global_event_indicators
-                    & (start <= time_to_event_vectors)
-                    & (time_to_event_vectors < end)
-                )
-                timepiece_event_indicator[bin_num] = event_indicator
-                timepiece_indicator[bin_num] = time_in_bin > 0 | event_indicator
-
-            timepiece_time_to_event_vectors.append(
-                timepiece_time_to_event_vector.swapaxes(0, 1)
-            )
-            timepiece_event_indicators.append(timepiece_event_indicator.swapaxes(0, 1))
-            timepiece_indicators.append(timepiece_indicator.swapaxes(0, 1))
-            time_to_event_to_includes.append(np.asarray(time_to_event_to_include))
-
-        record["time_to_event_vectors"] = np.concatenate(
-            timepiece_time_to_event_vectors, axis=0
-        )
-        record["event_indicators"] = np.concatenate(timepiece_event_indicators, axis=0)
-        record["time_indicators"] = np.concatenate(timepiece_indicators, axis=0)
-        record["time_to_event_to_include"] = np.concatenate(
-            time_to_event_to_includes, axis=0
-        )
-        return record
-
-    def random_sort(self, record: Dict[str, Any]) -> Dict[str, Any]:
-
-        if not self.shuffle_records:
-            return record
-
-        if "record_ranks" not in record:
-            return record
-
-        sorting_column = record["record_ranks"]
-        random_order = np.random.rand(len(sorting_column))
-
-        if self.include_values:
-            iterator = zip(
-                sorting_column,
-                random_order,
-                record["input_ids"],
-                record["value_indicators"],
-                record["values"],
-            )
-            sorted_list = sorted(iterator, key=lambda tup2: (tup2[0], tup2[1], tup2[2]))
-            _, _, sorted_input_ids, sorted_value_indicators, sorted_values = zip(
-                *list(sorted_list)
-            )
-            record["input_ids"] = self._convert_to_tensor(sorted_input_ids)
-            record["value_indicators"] = self._convert_to_tensor(
-                sorted_value_indicators
-            )
-            record["values"] = self._convert_to_tensor(sorted_values)
-        else:
-            iterator = zip(sorting_column, random_order, record["input_ids"])
-            sorted_list = sorted(iterator, key=lambda tup2: (tup2[0], tup2[1], tup2[2]))
-            _, _, sorted_input_ids = zip(*list(sorted_list))
-            record["input_ids"] = self._convert_to_tensor(sorted_input_ids)
-        return record
-
-    def generate_start_end_index(
-        self,
-        record: Dict[str, Any],
-        sample_packing: bool,
-        max_length_allowed: Optional[int] = None,
-    ) -> Dict[str, Any]:
-        """Adding the start and end indices to extract a portion of the patient sequence."""
-        # concept_ids will be used to for time to event predictions and identifying the visit starts
-        max_length_allowed = (
-            self.max_length if max_length_allowed is None else max_length_allowed
-        )
-        input_ids = record["input_ids"]
-        if isinstance(input_ids, torch.Tensor):
-            input_ids = input_ids.detach().tolist()
-        concept_ids = self.tokenizer.decode(input_ids, skip_special_tokens=False)
-        seq_length = len(record["input_ids"])
-
-        # Subtract one for the [END] token when sample_packing is not enabled
-        new_max_length = (
-            max_length_allowed if sample_packing else max_length_allowed - 1
-        )
-
-        if self.include_ttv_prediction:
-            record["time_to_visits"] = torch.concat(
-                [self._convert_to_tensor(self._convert_time_to_event(concept_ids))]
-            )
-
-        # If linear token exists, we will use it, otherwise we default to the OOV token
-        linear_token_id = (
-            self.tokenizer.linear_token_id
-            if self.tokenizer.linear_token_id
-            else self.tokenizer.oov_token_id
-        )
-        eos_token = (
-            linear_token_id
-            if self.add_linear_prob_token
-            else self.tokenizer.end_token_id
-        )
-
-        # Return the record directly if the actual sequence length is less than the max sequence
-        if seq_length <= new_max_length:
-            if not sample_packing:
-                record["input_ids"] = torch.concat(
-                    [
-                        self._convert_to_tensor(record["input_ids"]),
-                        self._convert_to_tensor([eos_token]),
-                    ]
-                )
-                if self.include_values:
-                    record["value_indicators"] = torch.concat(
-                        [
-                            self._convert_to_tensor(record["value_indicators"]),
-                            self._convert_to_tensor([False]),
-                        ]
-                    ).to(torch.bool)
-                    record["values"] = torch.concat(
-                        [
-                            self._convert_to_tensor(record["values"]),
-                            self._convert_to_tensor(
-                                [self.tokenizer.pad_value_token_id]
-                            ),
-                        ]
-                    )
-                if self.include_ttv_prediction:
-                    record["time_to_visits"] = torch.concat(
-                        [
-                            record["time_to_visits"],
-                            self._convert_to_tensor([-100.0]),
-                        ]
-                    )
-            return record
-
-        if self.pretraining:
-            # There is a 50% chance we randomly slice out a portion of the patient history and update the demographic
-            # prompt depending on the new starting point
-            if random.random() < 0.5 and not sample_packing:
-                start_index, end_index, demographic_tokens = random_slice_gpt_sequence(
-                    concept_ids, new_max_length
-                )
-                if start_index != end_index:
-                    record["input_ids"] = self._convert_to_tensor(
-                        record["input_ids"][start_index : end_index + 1]
-                    )
-                    if self.include_values:
-                        record["value_indicators"] = self._convert_to_tensor(
-                            record["value_indicators"][start_index : end_index + 1]
-                        ).to(torch.bool)
-                        record["values"] = self._convert_to_tensor(
-                            record["values"][start_index : end_index + 1]
-                        )
-                    if self.include_ttv_prediction:
-                        record["time_to_visits"] = self._convert_to_tensor(
-                            self._convert_time_to_event(
-                                concept_ids[start_index : end_index + 1]
-                            )
-                        )
-                    return record
-
-            # The default employs a right truncation strategy, where the demographic prompt is reserved
-            end_index = new_max_length
-            for i in reversed(list(range(0, end_index))):
-                current_token = record["input_ids"][i]
-                if current_token == self.ve_token_id:
-                    # Plus one because slicing is right exclusive
-                    end_index = i + 1
-                    break
-
-            record["input_ids"] = record["input_ids"][0:end_index]
-
-            # We want to make sure we take the subset of attention_mask in sample packing if this field is available
-            if sample_packing and "attention_mask" in record:
-                record["attention_mask"] = record["attention_mask"][0:end_index]
-
-            if sample_packing and "position_ids" in record:
-                record["position_ids"] = record["position_ids"][0:end_index]
-
-            if self.include_values:
-                record["value_indicators"] = self._convert_to_tensor(
-                    record["value_indicators"][0:end_index]
-                ).to(torch.bool)
-                record["values"] = self._convert_to_tensor(
-                    record["values"][0:end_index]
-                )
-            if self.include_ttv_prediction:
-                record["time_to_visits"] = self._convert_to_tensor(
-                    self._convert_time_to_event(concept_ids[0:end_index])
-                )
-            return record
-        else:
-            if self.include_demographics and not sample_packing:
-                # We employ a left truncation strategy, where the most recent patient history is reserved for fine-tuning
-                demographic_prompts_at_visits = collect_demographic_prompts_at_visits(
-                    concept_ids
-                )
-                for token_index, demographic_prompt in demographic_prompts_at_visits:
-                    if (
-                        seq_length - token_index
-                        <= new_max_length - DEMOGRAPHIC_PROMPT_SIZE
-                    ):
-                        demographic_tokens = self.tokenizer.encode(demographic_prompt)
-                        record["input_ids"] = torch.concat(
-                            [
-                                self._convert_to_tensor(demographic_tokens),
-                                self._convert_to_tensor(
-                                    record["input_ids"][token_index:seq_length]
-                                ),
-                            ]
-                        )
-                        if self.include_values:
-                            record["value_indicators"] = torch.concat(
-                                [
-                                    torch.zeros(
-                                        [DEMOGRAPHIC_PROMPT_SIZE], dtype=torch.int32
-                                    ).to(torch.bool),
-                                    self._convert_to_tensor(
-                                        record["value_indicators"][
-                                            token_index:seq_length
-                                        ]
-                                    ),
-                                ]
-                            )
-                            record["values"] = torch.concat(
-                                [
-                                    torch.zeros(
-                                        [DEMOGRAPHIC_PROMPT_SIZE], dtype=torch.int32
-                                    )
-                                    .to(torch.int32)
-                                    .fill_(self.tokenizer.pad_value_token_id),
-                                    self._convert_to_tensor(
-                                        record["values"][token_index:seq_length]
-                                    ),
-                                ]
-                            )
-                        if self.include_ttv_prediction:
-                            record["time_to_visits"] = torch.concat(
-                                [
-                                    torch.zeros(
-                                        [DEMOGRAPHIC_PROMPT_SIZE], dtype=torch.int32
-                                    )
-                                    .to(torch.float32)
-                                    .fill_(-100.0),
-                                    record["time_to_visits"][token_index:seq_length],
-                                ]
-                            )
-                        break
-            else:
-                start_index = seq_length - new_max_length
-                end_index = seq_length
-                for i in range(start_index, end_index):
-                    current_token = record["input_ids"][i]
-                    if current_token == self.vs_token_id:
-                        record["input_ids"] = record["input_ids"][i:end_index]
-                        if sample_packing and "attention_mask" in record:
-                            record["attention_mask"] = record["attention_mask"][
-                                i:end_index
-                            ]
-                        if sample_packing and "position_ids" in record:
-                            record["position_ids"] = record["position_ids"][i:end_index]
-                        if self.include_values:
-                            record["value_indicators"] = record["value_indicators"][
-                                i:end_index
-                            ]
-                            record["values"] = record["values"][i:end_index]
-                        if self.include_ttv_prediction:
-                            record["time_to_visits"] = record["time_to_visits"][
-                                i:end_index
-                            ]
-                        break
-
-                # This could happen when the last visit contains more than new_max_length number of tokens
-                # We simply take the last new_max_length number of tokens from the patient sequence
-                if len(record["input_ids"]) > new_max_length:
-                    record["input_ids"] = record["input_ids"][-new_max_length:]
-                    if sample_packing and "attention_mask" in record:
-                        record["attention_mask"] = record["attention_mask"][
-                            -new_max_length:
-                        ]
-                    if sample_packing and "position_ids" in record:
-                        record["position_ids"] = record["position_ids"][-new_max_length:]
-                    if self.include_values:
-                        record["value_indicators"] = record["value_indicators"][
-                            -new_max_length:
-                        ]
-                        record["values"] = record["values"][-new_max_length:]
-                    if self.include_ttv_prediction:
-                        record["time_to_visits"] = record["time_to_visits"][
-                            -new_max_length:
-                        ]
-
-            if not sample_packing:
-                # Finally we add the end token to the end of the sequence
-                record["input_ids"] = torch.concat(
-                    [
-                        self._convert_to_tensor(record["input_ids"]),
-                        self._convert_to_tensor([eos_token]),
-                    ]
-                )
-                if self.include_values:
-                    record["value_indicators"] = torch.concat(
-                        [
-                            self._convert_to_tensor(record["value_indicators"]),
-                            self._convert_to_tensor([False]),
-                        ]
-                    ).to(torch.bool)
-                    record["values"] = torch.concat(
-                        [
-                            self._convert_to_tensor(record["values"]),
-                            self._convert_to_tensor(
-                                [self.tokenizer.pad_value_token_id]
-                            ),
-                        ]
-                    )
-                if self.include_ttv_prediction:
-                    record["time_to_visits"] = torch.concat(
-                        [
-                            record["time_to_visits"],
-                            self._convert_to_tensor([-100.0]),
-                        ]
-                    )
-            return record
-

 class SamplePackingCehrGptDataCollator(CehrGptDataCollator):
     def __init__(self, max_tokens, max_position_embeddings, *args, **kwargs):
         self.max_tokens_per_batch = max_tokens
         self.max_position_embeddings = max_position_embeddings
         self.sample_packing = True
-        self.add_end_token_in_sample_packing = kwargs.pop(
-            "add_end_token_in_sample_packing", False
-        )
         super(SamplePackingCehrGptDataCollator, self).__init__(*args, **kwargs)
+        self.cehrgpt_data_processor.max_length = self.max_position_embeddings

     def __call__(self, examples):
         current_input_ids = []
         current_attention_mask = []
-
+        current_ages = []
+        current_epoch_times = []
         current_value_indicators = []
         current_values = []

+        # MOTOR inputs
+        current_motor_censor_times = []
+        current_motor_row_indices = []
+        current_motor_col_indices = []
+        current_motor_values = []
+        current_motor_tte_task_indicators = []
+
         # Demographics
         current_person_ids = []
         current_index_dates = []

         # Binary classification inputs
-
+        current_prediction_ages = []
         current_labels = []

         for idx, example in enumerate(examples):
-
-            # We only add an end token if the patient sequence could fit in the entire context window
-            add_end_token = (
-                len(example["input_ids"]) <= self.max_position_embeddings
-                and self.add_end_token_in_sample_packing
-            )
-            # If the sample length exceeds the model's capacity, truncate this example
-            if len(example["input_ids"]) > self.max_position_embeddings:
-                example = self.generate_start_end_index(
-                    example, False, self.max_position_embeddings
-                )
-
-            add_eos_token = add_end_token | self.add_linear_prob_token
-            additional_tokens = []
-            if add_end_token:
-                additional_tokens.append(self.tokenizer.end_token_id)
-            elif self.add_linear_prob_token:
-                # Backward compatible
-                linear_prob_token_id = (
-                    self.tokenizer.linear_token_id
-                    if self.tokenizer.linear_token_id is not None
-                    else self.tokenizer.oov_token_id
-                )
-                additional_tokens.append(linear_prob_token_id)
-            additional_tokens.append(self.tokenizer.pad_token_id)
+            example = self.cehrgpt_data_processor.transform(example)
             input_ids = example["input_ids"]
             # We add [END] [PAD], we want to attend to [END], adding [END] is important for sequence generation.
             # If the sequence length of the sequence is less than the context window, we add both [END][PAD], otherwise
             # we only add [PAD] token to the end of the sequence because it's not finished
-            current_input_ids.extend(list(input_ids) +
-            current_attention_mask.extend(
-
+            current_input_ids.extend(list(input_ids) + [self.tokenizer.pad_token_id])
+            current_attention_mask.extend(np.ones_like(input_ids).tolist() + [0])
+
+            ages = (
+                example["ages"].tolist()
+                if isinstance(example["ages"], torch.Tensor)
+                else list(example["ages"])
             )
-
-
-
-
-
-
-            )
+            current_ages.extend(ages + [max(ages)])
+
+            epoch_times = (
+                example["epoch_times"].tolist()
+                if isinstance(example["epoch_times"], torch.Tensor)
+                else list(example["epoch_times"])
             )
+            current_epoch_times.extend(epoch_times + [max(epoch_times)])
+
             if self.include_values:
                 current_value_indicators.extend(
-
+                    (
+                        example["value_indicators"].tolist()
+                        if isinstance(example["value_indicators"], torch.Tensor)
+                        else list(example["value_indicators"])
+                    )
+                    + [False]
                 )
                 current_values.extend(
-
-
+                    (
+                        example["values"].tolist()
+                        if isinstance(example["values"], torch.Tensor)
+                        else list(example["values"])
+                    )
+                    + [self.tokenizer.pad_value_token_id]
+                )
+
+            if self.include_motor_time_to_event:
+                current_max_motor_row_index = len(np.unique(current_motor_row_indices))
+                motor_row_indices = (
+                    example["motor_row_indices"].tolist()
+                    if isinstance(example["motor_row_indices"], torch.Tensor)
+                    else list(example["motor_row_indices"])
+                )
+                current_motor_row_indices.extend(
+                    list(
+                        map(
+                            lambda offset: offset + current_max_motor_row_index,
+                            motor_row_indices,
+                        )
+                    )
+                )
+                current_motor_col_indices.extend(
+                    example["motor_col_indices"].tolist()
+                    if isinstance(example["motor_col_indices"], torch.Tensor)
+                    else list(example["motor_col_indices"])
+                )
+                current_motor_values.extend(
+                    example["motor_values"].tolist()
+                    if isinstance(example["motor_values"], torch.Tensor)
+                    else list(example["motor_values"])
+                )
+                current_motor_censor_times.extend(
+                    example["motor_censor_times"].tolist()
+                    if isinstance(example["motor_censor_times"], torch.Tensor)
+                    else list(example["motor_censor_times"])
+                )
+                current_motor_tte_task_indicators.extend(
+                    (
+                        example["motor_tte_task_indicators"].tolist()
+                        if isinstance(
+                            example["motor_tte_task_indicators"], torch.Tensor
+                        )
+                        else list(example["motor_tte_task_indicators"])
+                    )
+                    + [False]
                 )

             if "person_id" in example:
@@ -989,7 +633,7 @@ class SamplePackingCehrGptDataCollator(CehrGptDataCollator):
                 current_index_dates.append(example["index_date"])

             if "age_at_index" in example:
-
+                current_prediction_ages.append(example["age_at_index"])

             if "classifier_label" in example:
                 current_labels.append(example["classifier_label"])
@@ -1001,20 +645,33 @@ class SamplePackingCehrGptDataCollator(CehrGptDataCollator):
         packed_example = {
             "input_ids": current_input_ids,
             "attention_mask": current_attention_mask,
-            "
+            "ages": current_ages,
+            "epoch_times": current_epoch_times,
         }
+
         if self.include_values:
-            packed_example.update(
-
+            packed_example.update(
+                {"value_indicators": current_value_indicators, "values": current_values}
+            )
+        if self.include_motor_time_to_event:
+            packed_example.update(
+                {
+                    "motor_censor_times": current_motor_censor_times,
+                    "motor_row_indices": current_motor_row_indices,
+                    "motor_col_indices": current_motor_col_indices,
+                    "motor_values": current_motor_values,
+                    "motor_tte_task_indicators": current_motor_tte_task_indicators,
+                }
+            )

         if current_labels:
             packed_example.update(
                 {
                     "person_id": current_person_ids,
                     "index_date": current_index_dates,
-                    "age_at_index":
+                    "age_at_index": current_prediction_ages,
                     "classifier_label": current_labels,
                 }
             )
-
+        # print(f"Packing examples took {time.time() - start} seconds")
        return super().__call__([packed_example])
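Note: the snippet below is not part of the diff. It is a minimal sketch of how the refactored collator might be instantiated after this release, using only the constructor keywords visible in the hunks above; the tokenizer loading path and the numeric values are illustrative assumptions, not taken from the package.

```python
# Hedged sketch: wiring up CehrGptDataCollator as refactored in 0.1.3.
# Keyword names mirror the __init__ signature shown in the diff; the tokenizer
# path, max_length, and other values below are placeholder assumptions.
from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer

# Assumed HF-style loader; substitute however the tokenizer is actually built.
tokenizer = CehrGptTokenizer.from_pretrained("path/to/cehrgpt_tokenizer")

collator = CehrGptDataCollator(
    tokenizer=tokenizer,
    max_length=2048,                      # placeholder context length
    include_values=False,
    shuffle_records=False,
    include_ttv_prediction=False,
    use_sub_time_tokenization=False,
    include_motor_time_to_event=True,
    motor_tte_vocab_size=tokenizer.motor_tte_vocab_size,
    motor_num_time_pieces=8,
    motor_sampling_probability=0.5,       # keyword introduced in this release
    pretraining=True,
)
# Per-example truncation, shuffling, and MOTOR label preparation now live in
# CehrGptDataProcessor, which the collator constructs internally and applies in
# __call__ whenever sample packing is not enabled.
```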