cehrgpt 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (34)
  1. cehrgpt/analysis/htn_treatment_pathway.py +546 -0
  2. cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
  3. cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
  4. cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
  5. cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
  6. cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
  7. cehrgpt/data/cehrgpt_data_processor.py +549 -0
  8. cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
  9. cehrgpt/data/hf_cehrgpt_dataset_collator.py +286 -629
  10. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +60 -14
  11. cehrgpt/generation/cehrgpt_conditional_generation.py +316 -0
  12. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +35 -15
  13. cehrgpt/generation/omop_converter_batch.py +11 -4
  14. cehrgpt/gpt_utils.py +73 -3
  15. cehrgpt/models/activations.py +27 -0
  16. cehrgpt/models/config.py +6 -2
  17. cehrgpt/models/gpt2.py +560 -0
  18. cehrgpt/models/hf_cehrgpt.py +193 -459
  19. cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
  20. cehrgpt/omop/ontology.py +154 -0
  21. cehrgpt/runners/data_utils.py +17 -6
  22. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +33 -79
  23. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
  24. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +58 -34
  25. cehrgpt/runners/hyperparameter_search_util.py +180 -69
  26. cehrgpt/runners/sample_packing_trainer.py +11 -2
  27. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +27 -31
  28. cehrgpt-0.1.3.dist-info/METADATA +238 -0
  29. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +33 -22
  30. cehrgpt-0.1.1.dist-info/METADATA +0 -115
  31. /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
  32. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
  33. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
  34. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
cehrgpt/data/hf_cehrgpt_dataset_collator.py
@@ -1,27 +1,13 @@
-import copy
-import random
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
 
 import numpy as np
 import torch
 from torch.nn.utils.rnn import pad_sequence
 from transformers.utils import logging
 
-from cehrgpt.gpt_utils import (
-    DEMOGRAPHIC_PROMPT_SIZE,
-    collect_demographic_prompts_at_visits,
-    extract_time_interval_in_days,
-    extract_time_interval_in_hours,
-    is_att_token,
-    is_inpatient_att_token,
-    is_inpatient_hour_token,
-    is_visit_end,
-    random_slice_gpt_sequence,
-)
+from cehrgpt.data.cehrgpt_data_processor import CehrGptDataProcessor
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
 
-TIME_TO_EVENT_MAX_TIME = 3650
-INPATIENT_STAY_DURATION_LIMIT = 30
 LOG = logging.get_logger("transformers")
 
 
@@ -30,13 +16,14 @@ class CehrGptDataCollator:
         self,
         tokenizer: CehrGptTokenizer,
         max_length: int,
-        shuffle_records: bool = False,
         include_values: bool = False,
+        shuffle_records: bool = False,
         include_ttv_prediction: bool = False,
         use_sub_time_tokenization: bool = False,
         include_motor_time_to_event: bool = False,
         motor_tte_vocab_size: int = 0,
         motor_num_time_pieces: int = 8,
+        motor_sampling_probability: float = 0.5,
         pretraining: bool = True,
         include_demographics: bool = False,
         add_linear_prob_token: bool = False,
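Note for downstream callers: 0.1.3 swaps the positions of shuffle_records and include_values and adds motor_sampling_probability, so a 0.1.1-era caller passing these positionally would silently flip the two booleans. A toy illustration of the hazard (stand-in functions, not the real class):

```python
def collator_v011(tokenizer, max_length, shuffle_records=False, include_values=False):
    return {"shuffle_records": shuffle_records, "include_values": include_values}


def collator_v013(tokenizer, max_length, include_values=False, shuffle_records=False):
    return {"shuffle_records": shuffle_records, "include_values": include_values}


args = (None, 1024, True)  # third positional arg meant shuffle_records=True in 0.1.1
print(collator_v011(*args))  # {'shuffle_records': True, 'include_values': False}
print(collator_v013(*args))  # {'shuffle_records': False, 'include_values': True}
```

Keyword arguments are the safe way to construct the collator across both versions.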
@@ -47,13 +34,12 @@ class CehrGptDataCollator:
         self.vs_token_id = tokenizer.vs_token_id
         self.ve_token_id = tokenizer.ve_token_id
 
-        self.shuffle_records = shuffle_records
         self.include_values = include_values
         self.include_ttv_prediction = include_ttv_prediction
         self.use_sub_time_tokenization = use_sub_time_tokenization
         self.pretraining = pretraining
         self.include_demographics = include_demographics
-        self.add_linear_prob_token = add_linear_prob_token
+        self.motor_code_cache: Dict[str, List[str]] = dict()
 
         # MOTOR TTE configuration
         if include_motor_time_to_event:
@@ -66,8 +52,14 @@ class CehrGptDataCollator:
         self.include_motor_time_to_event = include_motor_time_to_event
         self.motor_tte_vocab_size = motor_tte_vocab_size
         self.motor_num_time_pieces = motor_num_time_pieces
-        self.motor_time_interval = TIME_TO_EVENT_MAX_TIME // motor_num_time_pieces
-
+        self.motor_time_bins = (
+            self.tokenizer.get_motor_time_bins(motor_num_time_pieces)
+            if self.include_motor_time_to_event
+            else []
+        )
+        # Convert the time bins to seconds
+        self.motor_time_bins = [time_bin * 86400 for time_bin in self.motor_time_bins]
+        LOG.info("self.motor_time_bins: %s", self.motor_time_bins)
         if self.use_sub_time_tokenization:
             token_to_time_token_mapping = tokenizer.token_to_time_token_mapping
             if not token_to_time_token_mapping:
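get_motor_time_bins is tokenizer API introduced in this release; instead of the uniform 0.1.1 bins (TIME_TO_EVENT_MAX_TIME // motor_num_time_pieces, in days), the collator now takes bin edges from the tokenizer and rescales them to seconds. A minimal sketch of the conversion, assuming (hypothetically) that the method returns motor_num_time_pieces + 1 monotonically increasing edges expressed in days:

```python
import numpy as np

# Hypothetical day-valued edges for 8 time pieces (9 edges); the real values
# come from the trained tokenizer, not from this list.
motor_time_bins_days = [0, 1, 7, 30, 90, 180, 365, 1825, 3650]
motor_time_bins = [time_bin * 86400 for time_bin in motor_time_bins_days]  # days -> seconds
print(np.diff(motor_time_bins))  # per-piece widths in seconds; no longer uniform
```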
@@ -83,6 +75,18 @@ class CehrGptDataCollator:
                 list(token_to_time_token_mapping.values()), dtype=torch.int64
             )
 
+        self.cehrgpt_data_processor = CehrGptDataProcessor(
+            tokenizer=tokenizer,
+            max_length=self.max_length,
+            shuffle_records=shuffle_records,
+            include_ttv_prediction=include_ttv_prediction,
+            include_values=include_values,
+            include_motor_time_to_event=include_motor_time_to_event,
+            motor_sampling_probability=motor_sampling_probability,
+            pretraining=pretraining,
+            add_linear_prob_token=add_linear_prob_token,
+        )
+
     def _try_reverse_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
         if not self.pretraining:
             return torch.flip(tensor, dims=[-1])
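This constructor wiring is the heart of the refactor: per-example work (truncation, optional shuffling, TTE extraction) moves out of the collator into CehrGptDataProcessor, and the collator keeps only batching. A minimal sketch of the delegation pattern with toy stand-ins (only transform is visible in this diff; the real processor does far more):

```python
from typing import Any, Dict, List


class ToyProcessor:
    """Stand-in for CehrGptDataProcessor: per-example transformation."""

    def __init__(self, max_length: int):
        self.max_length = max_length

    def transform(self, example: Dict[str, Any]) -> Dict[str, Any]:
        example["input_ids"] = example["input_ids"][: self.max_length]
        return example


class ToyCollator:
    """Stand-in for the collator: applies the processor, then batches."""

    def __init__(self, max_length: int):
        self.processor = ToyProcessor(max_length)

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
        examples = [self.processor.transform(e) for e in examples]
        return {"input_ids": [e["input_ids"] for e in examples]}


print(ToyCollator(4)([{"input_ids": [1, 2, 3, 4, 5, 6]}]))  # {'input_ids': [[1, 2, 3, 4]]}
```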
@@ -95,30 +99,120 @@ class CehrGptDataCollator:
         else:
             return torch.tensor(features)
 
-    @staticmethod
-    def _convert_time_to_event(concept_ids):
-        def default_value(c):
-            try:
-                if is_att_token(c):
-                    time_to_visit = extract_time_interval_in_days(c)
-                    if (
-                        is_inpatient_att_token(c)
-                        and time_to_visit > INPATIENT_STAY_DURATION_LIMIT
-                    ):
-                        return -100
-                    return time_to_visit
-                elif is_inpatient_hour_token(c):
-                    return extract_time_interval_in_hours(c) / 24
-                return -100
-            except ValueError:
-                return -100
-
-        return [float(default_value(_)) for _ in concept_ids]
+    def create_time_to_event_tensors_ultra_optimized(
+        self, record: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """Ultra-optimized version using advanced vectorization techniques."""
+        motor_row_indices = record["motor_row_indices"]
+        motor_col_indices = record["motor_col_indices"]
+        motor_values = record["motor_values"]
+        motor_censor_times = record["motor_censor_times"]
+
+        if len(motor_row_indices) == 0:
+            # Handle empty case - use tuples for better performance
+            empty_shape = (
+                0,
+                self.motor_num_time_pieces,
+                self.tokenizer.motor_tte_vocab_size,
+            )
+            record["motor_tte_times"] = np.zeros(empty_shape, dtype=np.float32)
+            record["motor_tte_event_indicators"] = np.zeros(empty_shape, dtype=bool)
+            record["motor_tte_masks"] = np.zeros(empty_shape, dtype=bool)
+            return record
+
+        # Convert to numpy arrays once and get dimensions
+        motor_row_indices = np.asarray(motor_row_indices, dtype=np.int32)
+        motor_col_indices = np.asarray(motor_col_indices, dtype=np.int32)
+        motor_values = np.asarray(motor_values, dtype=np.float32)
+        motor_censor_times = np.asarray(motor_censor_times, dtype=np.float32)
+
+        n_tte_predictions = len(motor_censor_times)  # More direct than unique()
+        vocab_size = self.tokenizer.motor_tte_vocab_size
+        n_time_pieces = self.motor_num_time_pieces
+
+        # Create time_vectors more efficiently without broadcasting copy
+        time_vectors = np.tile(
+            motor_censor_times[:, np.newaxis], (1, vocab_size)
+        ).astype(np.float32)
+        event_indicators = np.zeros((n_tte_predictions, vocab_size), dtype=bool)
+
+        # Vectorized assignment (already optimal)
+        time_vectors[motor_row_indices, motor_col_indices] = motor_values
+        event_indicators[motor_row_indices, motor_col_indices] = True
+
+        # Early return if no predictions
+        if n_tte_predictions == 0:
+            empty_shape = (0, n_time_pieces, vocab_size)
+            record["motor_tte_times"] = np.zeros(empty_shape, dtype=np.float32)
+            record["motor_tte_event_indicators"] = np.zeros(empty_shape, dtype=bool)
+            record["motor_tte_masks"] = np.zeros(empty_shape, dtype=bool)
+            return record
+
+        # Cache motor_time_bins as numpy array to avoid repeated conversion
+        if not hasattr(self, "_motor_time_bins_array"):
+            self._motor_time_bins_array = np.asarray(
+                self.motor_time_bins, dtype=np.float32
+            )
+
+        motor_time_bins = self._motor_time_bins_array
+        start_times = motor_time_bins[:-1]
+        end_times = motor_time_bins[1:]
+        bin_widths = end_times - start_times  # Pre-compute bin widths
+
+        # ELIMINATED TRANSPOSE: Compute directly in target shape (n_pred, n_bins, vocab)
+        # Reshape for broadcasting in target order
+        time_vectors_3d = time_vectors[:, np.newaxis, :]  # (n_pred, 1, vocab)
+        event_indicators_3d = event_indicators[:, np.newaxis, :]  # (n_pred, 1, vocab)
+
+        # Broadcast time bins to match target shape
+        start_times_broadcast = start_times[np.newaxis, :, np.newaxis]  # (1, n_bins, 1)
+        bin_widths_broadcast = bin_widths[np.newaxis, :, np.newaxis]  # (1, n_bins, 1)
+
+        # Compute directly in target shape (n_pred, n_bins, vocab)
+        time_diff = time_vectors_3d - start_times_broadcast
+        time_in_bin = np.clip(time_diff, 0, bin_widths_broadcast)
+
+        # Optimized mask computation
+        mask = time_in_bin > 0
+
+        # More efficient log computation with better constant
+        log_constant = 1e-8  # Better numerical stability than 1e-10
+        time_in_bin_log = np.where(
+            mask, np.log2(np.maximum(time_in_bin, log_constant)), -np.inf
+        )
+
+        # Event indicator computation in target shape
+        end_times_broadcast = motor_time_bins[1:][np.newaxis, :, np.newaxis]
+        time_in_range = (time_vectors_3d >= start_times_broadcast) & (
+            time_vectors_3d < end_times_broadcast
+        )
+        event_in_bin = event_indicators_3d & time_in_range
+
+        # Combined mask computation
+        final_mask = mask | event_in_bin
+
+        # Direct assignment - NO TRANSPOSE NEEDED!
+        record["motor_tte_times"] = time_in_bin_log
+        record["motor_tte_event_indicators"] = event_in_bin
+        record["motor_tte_masks"] = final_mask
+
+        # Validation (keep as is - important for correctness)
+        assert (
+            sum(record["motor_tte_task_indicators"]) == n_tte_predictions
+        ), f'sum(record["motor_tte_task_indicators"]) == n_tte_predictions must be true'
+
+        # Clean up input data
+        del record["motor_row_indices"]
+        del record["motor_col_indices"]
+        del record["motor_values"]
+
+        return record
 
     def __call__(self, examples):
-        sample_packing = getattr(self, "sample_packing", False)
-        examples = [self.generate_start_end_index(_, sample_packing) for _ in examples]
-        examples = [self.random_sort(_) for _ in examples]
+
+        if not getattr(self, "sample_packing", False):
+            examples = [self.cehrgpt_data_processor.transform(_) for _ in examples]
+
         batch = {}
 
         # Assume that each example in the batch is a dictionary with 'input_ids' and 'attention_mask'
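A worked example of the binning above, for one prediction point, three time pieces, and a three-code vocabulary (the bin edges are made-up seconds, not real model values): every code accrues exposure time per bin via the clip, and the event indicator fires only in the bin that contains the event time.

```python
import numpy as np

bins = np.array([0.0, 10.0, 50.0, 100.0], dtype=np.float32)  # 3 time pieces
censor_time = 60.0  # codes with no event are censored at 60s
time_vectors = np.full((1, 3), censor_time, dtype=np.float32)
event_indicators = np.zeros((1, 3), dtype=bool)
time_vectors[0, 1] = 25.0  # code 1 occurs 25s after the prediction point
event_indicators[0, 1] = True

starts, widths = bins[:-1], np.diff(bins)
t3d = time_vectors[:, None, :]  # (n_pred, 1, vocab)
time_in_bin = np.clip(t3d - starts[None, :, None], 0, widths[None, :, None])
event_in_bin = event_indicators[:, None, :] & (t3d >= starts[None, :, None]) & (
    t3d < bins[1:][None, :, None]
)
print(time_in_bin[0, :, 1])   # [10. 15.  0.] -> code 1 accrues time up to 25s
print(event_in_bin[0, :, 1])  # [False  True False] -> event lands in bin [10, 50)
print(time_in_bin[0, :, 0])   # [10. 40. 10.] -> censored code accrues up to 60s
```

The log2 transform is then applied only where time_in_bin is positive, with -inf elsewhere, exactly as in the method above.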
@@ -162,21 +256,31 @@ class CehrGptDataCollator:
                 f"batch['input_ids']: {batch['input_ids']} "
             )
 
-        if "position_ids" in examples[0]:
-            batch_position_ids = [
-                self._try_reverse_tensor(
-                    self._convert_to_tensor(example["position_ids"])
-                )
-                for example in examples
-            ]
-            # Pad sequences to the max length in the batch
-            batch["position_ids"] = self._try_reverse_tensor(
-                pad_sequence(
-                    batch_position_ids,
-                    batch_first=True,
-                    padding_value=0,
-                ).to(torch.int64)
-            )
+        batch_ages = [
+            self._try_reverse_tensor(self._convert_to_tensor(example["ages"]))
+            for example in examples
+        ]
+        # Pad sequences to the max length in the batch
+        batch["ages"] = self._try_reverse_tensor(
+            pad_sequence(
+                batch_ages,
+                batch_first=True,
+                padding_value=0,
+            ).to(torch.int64)
+        )
+
+        batch_epoch_times = [
+            self._try_reverse_tensor(self._convert_to_tensor(example["epoch_times"]))
+            for example in examples
+        ]
+        # Pad sequences to the max length in the batch
+        batch["epoch_times"] = self._try_reverse_tensor(
+            pad_sequence(
+                batch_epoch_times,
+                batch_first=True,
+                padding_value=0,
+            ).to(torch.float32)
+        )
 
         if self.pretraining:
             batch["labels"] = torch.where(
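The _try_reverse_tensor wrapper on both sides of pad_sequence (unchanged by this diff) is what turns PyTorch's right padding into left padding for fine-tuning (pretraining=False): each sequence is flipped, right-padded, then flipped back. A self-contained demonstration:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
flipped = [torch.flip(s, dims=[-1]) for s in seqs]
left_padded = torch.flip(
    pad_sequence(flipped, batch_first=True, padding_value=0), dims=[-1]
)
print(left_padded)  # tensor([[1, 2, 3], [0, 4, 5]]) -> padding on the left
```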
@@ -217,54 +321,51 @@ class CehrGptDataCollator:
 
         if self.include_motor_time_to_event:
             examples_with_motor_tte = [
-                self.create_time_to_event_labels(_) for _ in examples
+                self.create_time_to_event_tensors_ultra_optimized(_) for _ in examples
             ]
-            batch_motor_time_to_event_vectors = [
+            # print(f"Creating MOTOR TTE tensors took {time.time() - start} seconds")
+            motor_tte_times = [
                 self._try_reverse_tensor(
-                    self._convert_to_tensor(example["time_to_event_vectors"])
+                    self._convert_to_tensor(example["motor_tte_times"])
                 )
                 for example in examples_with_motor_tte
             ]
-            batch_motor_event_indicators = [
+            motor_tte_event_indicators = [
                 self._try_reverse_tensor(
-                    self._convert_to_tensor(example["event_indicators"])
+                    self._convert_to_tensor(example["motor_tte_event_indicators"])
                 )
                 for example in examples_with_motor_tte
            ]
-            batch_motor_time_to_event_to_include = [
+            motor_tte_task_indicators = [
                 self._try_reverse_tensor(
-                    self._convert_to_tensor(example["time_to_event_to_include"])
+                    self._convert_to_tensor(example["motor_tte_task_indicators"])
                 )
                 for example in examples_with_motor_tte
             ]
-            batch_motor_time_indicators = [
+            motor_tte_masks = [
                 self._try_reverse_tensor(
-                    self._convert_to_tensor(example["time_indicators"])
+                    self._convert_to_tensor(example["motor_tte_masks"])
                 )
                 for example in examples_with_motor_tte
             ]
 
-            batch_motor_time_to_event_vectors = torch.concat(
-                batch_motor_time_to_event_vectors, dim=0
-            ).to(torch.float32)
+            motor_tte_times = torch.concat(motor_tte_times, dim=0).to(torch.float32)
 
             # If every example in the batch only contains one visit, there would be no labels generated for MOTOR TTE
             # we only create the labels when any example has more than one visit
-            if batch_motor_time_to_event_vectors.dim() <= 1:
+            if motor_tte_times.dim() <= 1:
                 LOG.warning(
                     "There are no MOTOR TTE labels generated for this batch "
                     "because every example in this batch only contains one visit."
                 )
             else:
                 batch_size = len(examples)
-                length, num_time_pieces, motor_tte_vocab_size = (
-                    batch_motor_time_to_event_vectors.shape
-                )
+                length, num_time_pieces, motor_tte_vocab_size = motor_tte_times.shape
                 padded_length = batch_size - length % batch_size
-                batch["motor_time_to_event_vectors"] = (
+                batch["motor_tte_times"] = (
                     torch.concat(
                         [
-                            batch_motor_time_to_event_vectors,
+                            motor_tte_times,
                             torch.full(
                                 (padded_length, num_time_pieces, motor_tte_vocab_size),
                                 0.0,
@@ -277,13 +378,12 @@ class CehrGptDataCollator:
                 )
 
                 # Motor event indicators that indicate there is an event occurred in this time interval
-                batch_motor_event_indicators = torch.concat(
-                    batch_motor_event_indicators, dim=0
-                ).to(torch.bool)
-                batch["motor_event_indicators"] = (
+                batch["motor_tte_event_indicators"] = (
                     torch.concat(
                         [
-                            batch_motor_event_indicators,
+                            torch.concat(motor_tte_event_indicators, dim=0).to(
+                                torch.bool
+                            ),
                             torch.full(
                                 (padded_length, num_time_pieces, motor_tte_vocab_size),
                                 False,
@@ -296,27 +396,17 @@ class CehrGptDataCollator:
                 )
 
                 # Input to indicate whether the visit should be included for TTE predictions
-                batch_motor_time_to_event_to_include = torch.concat(
-                    batch_motor_time_to_event_to_include, dim=0
+                batch["motor_tte_task_indicators"] = pad_sequence(
+                    motor_tte_task_indicators,
+                    batch_first=True,
+                    padding_value=False,
                 ).to(torch.bool)
-                batch["motor_time_to_event_to_include"] = (
-                    torch.concat(
-                        [
-                            batch_motor_time_to_event_to_include,
-                            torch.full((padded_length,), False),
-                        ],
-                        dim=0,
-                    ).to(torch.bool)
-                ).reshape((batch_size, -1))
 
                 # Motor time indicators that indicate whether there are neither clinical events nor censor events
-                batch_motor_time_indicators = torch.concat(
-                    batch_motor_time_indicators, dim=0
-                ).to(torch.bool)
-                batch["motor_time_indicators"] = (
+                batch["motor_tte_masks"] = (
                     torch.concat(
                         [
-                            batch_motor_time_indicators,
+                            torch.concat(motor_tte_masks, dim=0).to(torch.bool),
                             torch.full(
                                 (padded_length, num_time_pieces, motor_tte_vocab_size),
                                 False,
@@ -422,564 +512,118 @@ class CehrGptDataCollator:
 
         return batch
 
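The MOTOR blocks above concatenate prediction rows from the whole batch along dim 0, then append padded_length = batch_size - length % batch_size dummy rows so the total is a multiple of the batch size (presumably so the tensors can be reshaped per example downstream; note that when length is already a multiple, a full extra batch_size rows of padding are added). A small sketch of the arithmetic:

```python
import torch

batch_size, num_time_pieces, vocab_size = 4, 2, 3
rows = torch.randn(10, num_time_pieces, vocab_size)  # 10 prediction points in the batch
padded_length = batch_size - rows.shape[0] % batch_size  # 4 - 10 % 4 = 2
rows = torch.concat(
    [rows, torch.full((padded_length, num_time_pieces, vocab_size), 0.0)], dim=0
)
print(rows.shape)  # torch.Size([12, 2, 3])
print(rows.reshape(batch_size, -1, num_time_pieces, vocab_size).shape)  # [4, 3, 2, 3]
```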
-    def create_time_to_event_labels(self, record: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Generates time-to-event (TTE) labels and censoring indicators for each visit in a patient's timeline.
-
-        Processes the input sequence in reverse to compute the number of days from each visit (marked by [VE])
-        to the occurrence of future motor-related events.
-
-        Args:
-            record (Dict[str, Any]): A dictionary containing the encoded patient sequence with the key "input_ids".
-                This sequence includes [VS], [VE], time delta tokens, and motor TTE concept codes.
-
-        Returns:
-            Dict[str, Any]: The updated input record with added keys:
-                - "time_to_event_vectors": np.ndarray of shape [num_visits, motor_vocab_size], containing time-to-event values
-                - "event_indicators": np.ndarray of shape [num_visits, motor_vocab_size], where 0 = event occurred, 1 = censored
-        """
-        input_ids = record["input_ids"]
-        sample_packing = getattr(self, "sample_packing", False)
-
-        if isinstance(input_ids, torch.Tensor):
-            input_ids = input_ids.detach().tolist()
-
-        # This potentially contains packed samples, we need to handle that
-        packed_concept_ids = self.tokenizer.decode(input_ids, skip_special_tokens=False)
-        pad_indices = []
-        if sample_packing:
-            # We start from the first index
-            for i in range(len(packed_concept_ids)):
-                if packed_concept_ids[i] == self.tokenizer.pad_token:
-                    # If we encounter consecutive pads, we should break out of the loop
-                    if pad_indices and pad_indices[-1] == self.tokenizer.pad_token:
-                        break
-                    pad_indices.append(i)
-
-        # If we did not find a pad, that means the whole sequence belongs to one sample
-        if len(pad_indices) == 0:
-            pad_indices.append(len(packed_concept_ids))
-
-        timepiece_time_to_event_vectors = []
-        timepiece_event_indicators = []
-        timepiece_indicators = []
-        time_to_event_to_includes = []
-
-        for start_index, end_index in zip([0] + pad_indices[:-1], pad_indices):
-            concept_ids = packed_concept_ids[start_index:end_index]
-            if concept_ids[0] == self.tokenizer.pad_token:
-                concept_ids.pop(0)
-            time_to_event_vectors = []
-            global_event_indicators = []
-
-            # First collect TTE data in reverse chronological order
-            censor_times = []
-            time_to_event_data: List[Dict[str, int]] = []
-            time_to_event_dict: Dict[str, int] = {}
-            time_to_event_to_include: List[bool] = []
-            next_future_visit_concepts = set()
-            time_interval = 0
-
-            # Reverse walk through concept_ids to calculate TTE from each [VE] point
-            for concept_id in reversed(concept_ids):
-                if is_visit_end(concept_id):
-                    # Update TTE for existing concepts, or add new ones seen in this visit
-                    for existing_concept_id in list(time_to_event_dict.keys()):
-                        if existing_concept_id in next_future_visit_concepts:
-                            time_to_event_dict[existing_concept_id] = time_interval
-                        else:
-                            time_to_event_dict[existing_concept_id] += time_interval
-
-                    for next_concept_id in next_future_visit_concepts:
-                        if next_concept_id not in time_to_event_dict:
-                            time_to_event_dict[next_concept_id] = time_interval
-
-                    # If the next visit occurs on the same day as the previous one, we don't want to do TTE for the
-                    # previous visit
-                    time_to_event_to_include.append(time_interval > 0)
-                    time_to_event_data.append(copy.deepcopy(time_to_event_dict))
-                    # Record the censor time at the end of the visit
-                    if censor_times:
-                        censor_times.append(censor_times[-1] + time_interval)
-                    else:
-                        censor_times.append(time_interval)
-                    time_interval = 0
-                    next_future_visit_concepts.clear()
-
-                elif is_att_token(concept_id):
-                    time_interval += extract_time_interval_in_days(concept_id)
-
-                elif self.tokenizer.is_motor_time_to_event_code(concept_id):
-                    next_future_visit_concepts.add(concept_id)
-
-            if len(time_to_event_data) == 0:
-                LOG.info(
-                    "Vist end event is not detected for this sample, and is skipped for MOTOR tasks."
-                    "It's likely this sample contains a long admission. length: %s, concept_ids[-10:] %s",
-                    len(concept_ids),
-                    concept_ids[-10:],
-                )
-                continue
-
-            # Reverse back to chronological order for final labels
-            time_to_event_data.reverse()
-            censor_times.reverse()
-            time_to_event_to_include.reverse()
-
-            for censor_time, visit_tte_data in zip(censor_times, time_to_event_data):
-                time_to_event_vector = np.full(
-                    self.tokenizer.motor_tte_vocab_size,
-                    fill_value=censor_time,
-                    dtype=np.int32,
-                )
-                event_indicator = np.zeros(
-                    self.tokenizer.motor_tte_vocab_size,
-                    dtype=np.int32,
-                )
-                visit_token_ids = [
-                    self.tokenizer.get_motor_token_id(concept_id)
-                    for concept_id in visit_tte_data.keys()
-                ]
-                visit_tte_values = list(visit_tte_data.values())
-
-                time_to_event_vector[visit_token_ids] = visit_tte_values
-                event_indicator[visit_token_ids] = 1  # not censored (event occurred)
-
-                time_to_event_vectors.append(time_to_event_vector)
-                global_event_indicators.append(event_indicator)
-
-            time_to_event_vectors = np.asarray(time_to_event_vectors)
-            global_event_indicators = np.asarray(global_event_indicators).astype(bool)
-            n_visits = len(time_to_event_vectors)
-
-            timepiece_time_to_event_vector = np.full(
-                (
-                    self.motor_num_time_pieces,
-                    n_visits,
-                    self.tokenizer.motor_tte_vocab_size,
-                ),
-                fill_value=0,
-                dtype=np.int32,
-            )
-            timepiece_event_indicator = np.zeros(
-                (
-                    self.motor_num_time_pieces,
-                    n_visits,
-                    self.tokenizer.motor_tte_vocab_size,
-                ),
-                dtype=bool,
-            )
-            timepiece_indicator = np.zeros(
-                (
-                    self.motor_num_time_pieces,
-                    n_visits,
-                    self.tokenizer.motor_tte_vocab_size,
-                ),
-                dtype=bool,
-            )
-
-            # Putting the event time and censor time into the corresponding time bins
-            for bin_num in range(self.motor_num_time_pieces):
-                start = self.motor_time_interval * bin_num
-                end = self.motor_time_interval * (bin_num + 1)
-                time_in_bin = np.clip(time_to_event_vectors - start, 0, end - start)
-                timepiece_time_to_event_vector[bin_num] = time_in_bin
-                event_indicator = (
-                    global_event_indicators
-                    & (start <= time_to_event_vectors)
-                    & (time_to_event_vectors < end)
-                )
-                timepiece_event_indicator[bin_num] = event_indicator
-                timepiece_indicator[bin_num] = time_in_bin > 0 | event_indicator
-
-            timepiece_time_to_event_vectors.append(
-                timepiece_time_to_event_vector.swapaxes(0, 1)
-            )
-            timepiece_event_indicators.append(timepiece_event_indicator.swapaxes(0, 1))
-            timepiece_indicators.append(timepiece_indicator.swapaxes(0, 1))
-            time_to_event_to_includes.append(np.asarray(time_to_event_to_include))
-
-        record["time_to_event_vectors"] = np.concatenate(
-            timepiece_time_to_event_vectors, axis=0
-        )
-        record["event_indicators"] = np.concatenate(timepiece_event_indicators, axis=0)
-        record["time_indicators"] = np.concatenate(timepiece_indicators, axis=0)
-        record["time_to_event_to_include"] = np.concatenate(
-            time_to_event_to_includes, axis=0
-        )
-        return record
-
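For reference, the reverse walk being deleted here, in miniature: [VE] marks each prediction point, day tokens accumulate into a running interval, and codes seen in later visits receive time-to-event values relative to each earlier visit end. A toy rendition with string tokens (not real CEHR-GPT vocabulary):

```python
import copy

tokens = ["VS", "c1", "VE", "D7", "VS", "c2", "VE", "D3", "VS", "c1", "VE"]
tte, next_visit, interval, labels = {}, set(), 0, []
for tok in reversed(tokens):
    if tok == "VE":  # a prediction point
        for code in list(tte):
            tte[code] = interval if code in next_visit else tte[code] + interval
        for code in next_visit:
            tte.setdefault(code, interval)
        labels.append(copy.deepcopy(tte))
        interval, next_visit = 0, set()
    elif tok.startswith("D"):  # time-delta token
        interval += int(tok[1:])
    elif tok != "VS":  # a target concept code
        next_visit.add(tok)
labels.reverse()
print(labels)  # [{'c1': 10, 'c2': 7}, {'c1': 3}, {}]
```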
-    def random_sort(self, record: Dict[str, Any]) -> Dict[str, Any]:
-
-        if not self.shuffle_records:
-            return record
-
-        if "record_ranks" not in record:
-            return record
-
-        sorting_column = record["record_ranks"]
-        random_order = np.random.rand(len(sorting_column))
-
-        if self.include_values:
-            iterator = zip(
-                sorting_column,
-                random_order,
-                record["input_ids"],
-                record["value_indicators"],
-                record["values"],
-            )
-            sorted_list = sorted(iterator, key=lambda tup2: (tup2[0], tup2[1], tup2[2]))
-            _, _, sorted_input_ids, sorted_value_indicators, sorted_values = zip(
-                *list(sorted_list)
-            )
-            record["input_ids"] = self._convert_to_tensor(sorted_input_ids)
-            record["value_indicators"] = self._convert_to_tensor(
-                sorted_value_indicators
-            )
-            record["values"] = self._convert_to_tensor(sorted_values)
-        else:
-            iterator = zip(sorting_column, random_order, record["input_ids"])
-            sorted_list = sorted(iterator, key=lambda tup2: (tup2[0], tup2[1], tup2[2]))
-            _, _, sorted_input_ids = zip(*list(sorted_list))
-            record["input_ids"] = self._convert_to_tensor(sorted_input_ids)
-        return record
-
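random_sort (also deleted; the shuffle_records flag is now forwarded to CehrGptDataProcessor) relied on a composite sort key of (rank, random draw, token): records sharing a rank are shuffled among themselves while rank order is preserved. A standalone sketch:

```python
import numpy as np

ranks = [1, 1, 1, 2, 2]
tokens = ["a", "b", "c", "d", "e"]
rng = np.random.default_rng(0)
order = rng.random(len(ranks))  # random tie-breaker within each rank
print([t for _, _, t in sorted(zip(ranks, order, tokens))])
# e.g. ['c', 'b', 'a', 'd', 'e'] -- rank-2 tokens never precede rank-1 tokens
```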
-    def generate_start_end_index(
-        self,
-        record: Dict[str, Any],
-        sample_packing: bool,
-        max_length_allowed: Optional[int] = None,
-    ) -> Dict[str, Any]:
-        """Adding the start and end indices to extract a portion of the patient sequence."""
-        # concept_ids will be used to for time to event predictions and identifying the visit starts
-        max_length_allowed = (
-            self.max_length if max_length_allowed is None else max_length_allowed
-        )
-        input_ids = record["input_ids"]
-        if isinstance(input_ids, torch.Tensor):
-            input_ids = input_ids.detach().tolist()
-        concept_ids = self.tokenizer.decode(input_ids, skip_special_tokens=False)
-        seq_length = len(record["input_ids"])
-
-        # Subtract one for the [END] token when sample_packing is not enabled
-        new_max_length = (
-            max_length_allowed if sample_packing else max_length_allowed - 1
-        )
-
-        if self.include_ttv_prediction:
-            record["time_to_visits"] = torch.concat(
-                [self._convert_to_tensor(self._convert_time_to_event(concept_ids))]
-            )
-
-        # If linear token exists, we will use it, otherwise we default to the OOV token
-        linear_token_id = (
-            self.tokenizer.linear_token_id
-            if self.tokenizer.linear_token_id
-            else self.tokenizer.oov_token_id
-        )
-        eos_token = (
-            linear_token_id
-            if self.add_linear_prob_token
-            else self.tokenizer.end_token_id
-        )
-
-        # Return the record directly if the actual sequence length is less than the max sequence
-        if seq_length <= new_max_length:
-            if not sample_packing:
-                record["input_ids"] = torch.concat(
-                    [
-                        self._convert_to_tensor(record["input_ids"]),
-                        self._convert_to_tensor([eos_token]),
-                    ]
-                )
-                if self.include_values:
-                    record["value_indicators"] = torch.concat(
-                        [
-                            self._convert_to_tensor(record["value_indicators"]),
-                            self._convert_to_tensor([False]),
-                        ]
-                    ).to(torch.bool)
-                    record["values"] = torch.concat(
-                        [
-                            self._convert_to_tensor(record["values"]),
-                            self._convert_to_tensor(
-                                [self.tokenizer.pad_value_token_id]
-                            ),
-                        ]
-                    )
-                if self.include_ttv_prediction:
-                    record["time_to_visits"] = torch.concat(
-                        [
-                            record["time_to_visits"],
-                            self._convert_to_tensor([-100.0]),
-                        ]
-                    )
-            return record
-
-        if self.pretraining:
-            # There is a 50% chance we randomly slice out a portion of the patient history and update the demographic
-            # prompt depending on the new starting point
-            if random.random() < 0.5 and not sample_packing:
-                start_index, end_index, demographic_tokens = random_slice_gpt_sequence(
-                    concept_ids, new_max_length
-                )
-                if start_index != end_index:
-                    record["input_ids"] = self._convert_to_tensor(
-                        record["input_ids"][start_index : end_index + 1]
-                    )
-                    if self.include_values:
-                        record["value_indicators"] = self._convert_to_tensor(
-                            record["value_indicators"][start_index : end_index + 1]
-                        ).to(torch.bool)
-                        record["values"] = self._convert_to_tensor(
-                            record["values"][start_index : end_index + 1]
-                        )
-                    if self.include_ttv_prediction:
-                        record["time_to_visits"] = self._convert_to_tensor(
-                            self._convert_time_to_event(
-                                concept_ids[start_index : end_index + 1]
-                            )
-                        )
-                    return record
-
-            # The default employs a right truncation strategy, where the demographic prompt is reserved
-            end_index = new_max_length
-            for i in reversed(list(range(0, end_index))):
-                current_token = record["input_ids"][i]
-                if current_token == self.ve_token_id:
-                    # Plus one because slicing is right exclusive
-                    end_index = i + 1
-                    break
-
-            record["input_ids"] = record["input_ids"][0:end_index]
-
-            # We want to make sure we take the subset of attention_mask in sample packing if this field is available
-            if sample_packing and "attention_mask" in record:
-                record["attention_mask"] = record["attention_mask"][0:end_index]
-
-            if sample_packing and "position_ids" in record:
-                record["position_ids"] = record["position_ids"][0:end_index]
-
-            if self.include_values:
-                record["value_indicators"] = self._convert_to_tensor(
-                    record["value_indicators"][0:end_index]
-                ).to(torch.bool)
-                record["values"] = self._convert_to_tensor(
-                    record["values"][0:end_index]
-                )
-            if self.include_ttv_prediction:
-                record["time_to_visits"] = self._convert_to_tensor(
-                    self._convert_time_to_event(concept_ids[0:end_index])
-                )
-            return record
-        else:
-            if self.include_demographics and not sample_packing:
-                # We employ a left truncation strategy, where the most recent patient history is reserved for fine-tuning
-                demographic_prompts_at_visits = collect_demographic_prompts_at_visits(
-                    concept_ids
-                )
-                for token_index, demographic_prompt in demographic_prompts_at_visits:
-                    if (
-                        seq_length - token_index
-                        <= new_max_length - DEMOGRAPHIC_PROMPT_SIZE
-                    ):
-                        demographic_tokens = self.tokenizer.encode(demographic_prompt)
-                        record["input_ids"] = torch.concat(
-                            [
-                                self._convert_to_tensor(demographic_tokens),
-                                self._convert_to_tensor(
-                                    record["input_ids"][token_index:seq_length]
-                                ),
-                            ]
-                        )
-                        if self.include_values:
-                            record["value_indicators"] = torch.concat(
-                                [
-                                    torch.zeros(
-                                        [DEMOGRAPHIC_PROMPT_SIZE], dtype=torch.int32
-                                    ).to(torch.bool),
-                                    self._convert_to_tensor(
-                                        record["value_indicators"][
-                                            token_index:seq_length
-                                        ]
-                                    ),
-                                ]
-                            )
-                            record["values"] = torch.concat(
-                                [
-                                    torch.zeros(
-                                        [DEMOGRAPHIC_PROMPT_SIZE], dtype=torch.int32
-                                    )
-                                    .to(torch.int32)
-                                    .fill_(self.tokenizer.pad_value_token_id),
-                                    self._convert_to_tensor(
-                                        record["values"][token_index:seq_length]
-                                    ),
-                                ]
-                            )
-                        if self.include_ttv_prediction:
-                            record["time_to_visits"] = torch.concat(
-                                [
-                                    torch.zeros(
-                                        [DEMOGRAPHIC_PROMPT_SIZE], dtype=torch.int32
-                                    )
-                                    .to(torch.float32)
-                                    .fill_(-100.0),
-                                    record["time_to_visits"][token_index:seq_length],
-                                ]
-                            )
-                        break
-            else:
-                start_index = seq_length - new_max_length
-                end_index = seq_length
-                for i in range(start_index, end_index):
-                    current_token = record["input_ids"][i]
-                    if current_token == self.vs_token_id:
-                        record["input_ids"] = record["input_ids"][i:end_index]
-                        if sample_packing and "attention_mask" in record:
-                            record["attention_mask"] = record["attention_mask"][
-                                i:end_index
-                            ]
-                        if sample_packing and "position_ids" in record:
-                            record["position_ids"] = record["position_ids"][i:end_index]
-                        if self.include_values:
-                            record["value_indicators"] = record["value_indicators"][
-                                i:end_index
-                            ]
-                            record["values"] = record["values"][i:end_index]
-                        if self.include_ttv_prediction:
-                            record["time_to_visits"] = record["time_to_visits"][
-                                i:end_index
-                            ]
-                        break
-
-            # This could happen when the last visit contains more than new_max_length number of tokens
-            # We simply take the last new_max_length number of tokens from the patient sequence
-            if len(record["input_ids"]) > new_max_length:
-                record["input_ids"] = record["input_ids"][-new_max_length:]
-                if sample_packing and "attention_mask" in record:
-                    record["attention_mask"] = record["attention_mask"][
-                        -new_max_length:
-                    ]
-                if sample_packing and "position_ids" in record:
-                    record["position_ids"] = record["position_ids"][-new_max_length:]
-                if self.include_values:
-                    record["value_indicators"] = record["value_indicators"][
-                        -new_max_length:
-                    ]
-                    record["values"] = record["values"][-new_max_length:]
-                if self.include_ttv_prediction:
-                    record["time_to_visits"] = record["time_to_visits"][
-                        -new_max_length:
-                    ]
-
-            if not sample_packing:
-                # Finally we add the end token to the end of the sequence
-                record["input_ids"] = torch.concat(
-                    [
-                        self._convert_to_tensor(record["input_ids"]),
-                        self._convert_to_tensor([eos_token]),
-                    ]
-                )
-                if self.include_values:
-                    record["value_indicators"] = torch.concat(
-                        [
-                            self._convert_to_tensor(record["value_indicators"]),
-                            self._convert_to_tensor([False]),
-                        ]
-                    ).to(torch.bool)
-                    record["values"] = torch.concat(
-                        [
-                            self._convert_to_tensor(record["values"]),
-                            self._convert_to_tensor(
-                                [self.tokenizer.pad_value_token_id]
-                            ),
-                        ]
-                    )
-                if self.include_ttv_prediction:
-                    record["time_to_visits"] = torch.concat(
-                        [
-                            record["time_to_visits"],
-                            self._convert_to_tensor([-100.0]),
-                        ]
-                    )
-            return record
-
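generate_start_end_index is the largest removal: its two truncation strategies (and the random history slicing for pretraining) presumably now live behind CehrGptDataProcessor.transform. Reduced to essentials: pretraining right-truncates so the window ends on a visit-end token, while fine-tuning left-truncates so the window starts on a visit-start token. A list-based sketch with string stand-ins for the special tokens:

```python
def right_truncate(tokens, max_len, ve="VE"):
    # Keep the sequence head; cut back to the last visit end within max_len.
    end = max_len
    for i in reversed(range(end)):
        if tokens[i] == ve:
            end = i + 1
            break
    return tokens[:end]


def left_truncate(tokens, max_len, vs="VS"):
    # Keep the most recent history; advance to the first visit start in range.
    for i in range(len(tokens) - max_len, len(tokens)):
        if tokens[i] == vs:
            return tokens[i:]
    return tokens[-max_len:]


seq = ["VS", "a", "VE", "D7", "VS", "b", "c", "VE"]
print(right_truncate(seq, 6))  # ['VS', 'a', 'VE']
print(left_truncate(seq, 6))   # ['VS', 'b', 'c', 'VE']
```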
 
 class SamplePackingCehrGptDataCollator(CehrGptDataCollator):
     def __init__(self, max_tokens, max_position_embeddings, *args, **kwargs):
         self.max_tokens_per_batch = max_tokens
         self.max_position_embeddings = max_position_embeddings
         self.sample_packing = True
-        self.add_end_token_in_sample_packing = kwargs.pop(
-            "add_end_token_in_sample_packing", False
-        )
         super(SamplePackingCehrGptDataCollator, self).__init__(*args, **kwargs)
+        self.cehrgpt_data_processor.max_length = self.max_position_embeddings
 
     def __call__(self, examples):
         current_input_ids = []
         current_attention_mask = []
-        current_position_ids = []
+        current_ages = []
+        current_epoch_times = []
         current_value_indicators = []
         current_values = []
 
+        # MOTOR inputs
+        current_motor_censor_times = []
+        current_motor_row_indices = []
+        current_motor_col_indices = []
+        current_motor_values = []
+        current_motor_tte_task_indicators = []
+
         # Demographics
         current_person_ids = []
         current_index_dates = []
 
         # Binary classification inputs
-        current_ages = []
+        current_prediction_ages = []
         current_labels = []
 
         for idx, example in enumerate(examples):
-
-            # We only add an end token if the patient sequence could fit in the entire context window
-            add_end_token = (
-                len(example["input_ids"]) <= self.max_position_embeddings
-                and self.add_end_token_in_sample_packing
-            )
-            # If the sample length exceeds the model's capacity, truncate this example
-            if len(example["input_ids"]) > self.max_position_embeddings:
-                example = self.generate_start_end_index(
-                    example, False, self.max_position_embeddings
-                )
-
-            add_eos_token = add_end_token | self.add_linear_prob_token
-            additional_tokens = []
-            if add_end_token:
-                additional_tokens.append(self.tokenizer.end_token_id)
-            elif self.add_linear_prob_token:
-                # Backward compatible
-                linear_prob_token_id = (
-                    self.tokenizer.linear_token_id
-                    if self.tokenizer.linear_token_id is not None
-                    else self.tokenizer.oov_token_id
-                )
-                additional_tokens.append(linear_prob_token_id)
-            additional_tokens.append(self.tokenizer.pad_token_id)
+            example = self.cehrgpt_data_processor.transform(example)
             input_ids = example["input_ids"]
             # We add [END] [PAD], we want to attend to [END], adding [END] is important for sequence generation.
             # If the sequence length of the sequence is less than the context window, we add both [END][PAD], otherwise
             # we only add [PAD] token to the end of the sequence because it's not finished
-            current_input_ids.extend(list(input_ids) + additional_tokens)
-            current_attention_mask.extend(
-                np.ones_like(input_ids).tolist() + ([1, 0] if add_eos_token else [0])
+            current_input_ids.extend(list(input_ids) + [self.tokenizer.pad_token_id])
+            current_attention_mask.extend(np.ones_like(input_ids).tolist() + [0])
+
+            ages = (
+                example["ages"].tolist()
+                if isinstance(example["ages"], torch.Tensor)
+                else list(example["ages"])
             )
-            num_tokens_to_pad = 1 + int(add_eos_token)
-            current_position_ids.extend(
-                np.clip(
-                    list(range(len(input_ids) + num_tokens_to_pad)),
-                    0,
-                    self.max_position_embeddings - 1,
-                )
+            current_ages.extend(ages + [max(ages)])
+
+            epoch_times = (
+                example["epoch_times"].tolist()
+                if isinstance(example["epoch_times"], torch.Tensor)
+                else list(example["epoch_times"])
             )
+            current_epoch_times.extend(epoch_times + [max(epoch_times)])
+
             if self.include_values:
                 current_value_indicators.extend(
-                    list(example["value_indicators"]) + [False] * num_tokens_to_pad
+                    (
+                        example["value_indicators"].tolist()
+                        if isinstance(example["value_indicators"], torch.Tensor)
+                        else list(example["value_indicators"])
+                    )
+                    + [False]
                 )
                 current_values.extend(
-                    list(example["values"])
-                    + [self.tokenizer.pad_value_token_id] * num_tokens_to_pad
+                    (
+                        example["values"].tolist()
+                        if isinstance(example["values"], torch.Tensor)
+                        else list(example["values"])
+                    )
+                    + [self.tokenizer.pad_value_token_id]
+                )
+
+            if self.include_motor_time_to_event:
+                current_max_motor_row_index = len(np.unique(current_motor_row_indices))
+                motor_row_indices = (
+                    example["motor_row_indices"].tolist()
+                    if isinstance(example["motor_row_indices"], torch.Tensor)
+                    else list(example["motor_row_indices"])
+                )
+                current_motor_row_indices.extend(
+                    list(
+                        map(
+                            lambda offset: offset + current_max_motor_row_index,
+                            motor_row_indices,
+                        )
+                    )
+                )
+                current_motor_col_indices.extend(
+                    example["motor_col_indices"].tolist()
+                    if isinstance(example["motor_col_indices"], torch.Tensor)
+                    else list(example["motor_col_indices"])
+                )
+                current_motor_values.extend(
+                    example["motor_values"].tolist()
+                    if isinstance(example["motor_values"], torch.Tensor)
+                    else list(example["motor_values"])
+                )
+                current_motor_censor_times.extend(
+                    example["motor_censor_times"].tolist()
+                    if isinstance(example["motor_censor_times"], torch.Tensor)
+                    else list(example["motor_censor_times"])
+                )
+                current_motor_tte_task_indicators.extend(
+                    (
+                        example["motor_tte_task_indicators"].tolist()
+                        if isinstance(
+                            example["motor_tte_task_indicators"], torch.Tensor
+                        )
+                        else list(example["motor_tte_task_indicators"])
+                    )
+                    + [False]
                 )
 
             if "person_id" in example:
@@ -989,7 +633,7 @@ class SamplePackingCehrGptDataCollator(CehrGptDataCollator):
                 current_index_dates.append(example["index_date"])
 
             if "age_at_index" in example:
-                current_ages.append(example["age_at_index"])
+                current_prediction_ages.append(example["age_at_index"])
 
             if "classifier_label" in example:
                 current_labels.append(example["classifier_label"])
@@ -1001,20 +645,33 @@ class SamplePackingCehrGptDataCollator(CehrGptDataCollator):
         packed_example = {
             "input_ids": current_input_ids,
             "attention_mask": current_attention_mask,
-            "position_ids": current_position_ids,
+            "ages": current_ages,
+            "epoch_times": current_epoch_times,
         }
+
         if self.include_values:
-            packed_example.update({"value_indicators": current_value_indicators})
-            packed_example.update({"values": current_values})
+            packed_example.update(
+                {"value_indicators": current_value_indicators, "values": current_values}
+            )
+        if self.include_motor_time_to_event:
+            packed_example.update(
+                {
+                    "motor_censor_times": current_motor_censor_times,
+                    "motor_row_indices": current_motor_row_indices,
+                    "motor_col_indices": current_motor_col_indices,
+                    "motor_values": current_motor_values,
+                    "motor_tte_task_indicators": current_motor_tte_task_indicators,
+                }
+            )
 
         if current_labels:
             packed_example.update(
                 {
                     "person_id": current_person_ids,
                     "index_date": current_index_dates,
-                    "age_at_index": current_ages,
+                    "age_at_index": current_prediction_ages,
                     "classifier_label": current_labels,
                 }
             )
-
+        # print(f"Packing examples took {time.time() - start} seconds")
         return super().__call__([packed_example])
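A condensed sketch of the packing scheme implemented above: sequences are joined with a single [PAD] separator whose attention value is 0, and each example's sparse MOTOR row indices are shifted by the number of prediction rows already packed so the COO triplets stay valid across the whole pack. (The real code derives the offset with len(np.unique(current_motor_row_indices)); this sketch carries an explicit row count instead.)

```python
PAD = 0
examples = [
    {"input_ids": [5, 6, 7], "motor_row_indices": [0, 0, 1], "n_rows": 2},
    {"input_ids": [8, 9], "motor_row_indices": [0], "n_rows": 1},
]
packed_ids, packed_mask, packed_rows, row_offset = [], [], [], 0
for ex in examples:
    packed_ids.extend(ex["input_ids"] + [PAD])
    packed_mask.extend([1] * len(ex["input_ids"]) + [0])
    packed_rows.extend(r + row_offset for r in ex["motor_row_indices"])
    row_offset += ex["n_rows"]
print(packed_ids)   # [5, 6, 7, 0, 8, 9, 0]
print(packed_mask)  # [1, 1, 1, 0, 1, 1, 0]
print(packed_rows)  # [0, 0, 1, 2]
```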