cehrgpt 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff compares publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between versions exactly as they appear in that registry.
Files changed (33)
  1. cehrgpt/analysis/htn_treatment_pathway.py +546 -0
  2. cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
  3. cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
  4. cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
  5. cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
  6. cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
  7. cehrgpt/data/cehrgpt_data_processor.py +549 -0
  8. cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
  9. cehrgpt/data/hf_cehrgpt_dataset_collator.py +285 -652
  10. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +38 -5
  11. cehrgpt/generation/cehrgpt_conditional_generation.py +2 -0
  12. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +20 -12
  13. cehrgpt/generation/omop_converter_batch.py +11 -4
  14. cehrgpt/gpt_utils.py +73 -3
  15. cehrgpt/models/activations.py +27 -0
  16. cehrgpt/models/config.py +6 -2
  17. cehrgpt/models/gpt2.py +560 -0
  18. cehrgpt/models/hf_cehrgpt.py +183 -460
  19. cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
  20. cehrgpt/omop/ontology.py +154 -0
  21. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -78
  22. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
  23. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -34
  24. cehrgpt/runners/hyperparameter_search_util.py +180 -69
  25. cehrgpt/runners/sample_packing_trainer.py +11 -2
  26. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +8 -2
  27. cehrgpt-0.1.4.dist-info/METADATA +238 -0
  28. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/RECORD +32 -22
  29. cehrgpt-0.1.2.dist-info/METADATA +0 -209
  30. /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
  31. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/WHEEL +0 -0
  32. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/licenses/LICENSE +0 -0
  33. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/top_level.txt +0 -0
@@ -1,27 +1,13 @@
-import copy
-import random
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
 
 import numpy as np
 import torch
 from torch.nn.utils.rnn import pad_sequence
 from transformers.utils import logging
 
-from cehrgpt.gpt_utils import (
-    DEMOGRAPHIC_PROMPT_SIZE,
-    collect_demographic_prompts_at_visits,
-    extract_time_interval_in_days,
-    extract_time_interval_in_hours,
-    is_att_token,
-    is_inpatient_att_token,
-    is_inpatient_hour_token,
-    is_visit_end,
-    random_slice_gpt_sequence,
-)
+from cehrgpt.data.cehrgpt_data_processor import CehrGptDataProcessor
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
 
-TIME_TO_EVENT_MAX_TIME = 3650
-INPATIENT_STAY_DURATION_LIMIT = 30
 LOG = logging.get_logger("transformers")
 
 
@@ -30,13 +16,14 @@ class CehrGptDataCollator:
         self,
         tokenizer: CehrGptTokenizer,
         max_length: int,
-        shuffle_records: bool = False,
         include_values: bool = False,
+        shuffle_records: bool = False,
         include_ttv_prediction: bool = False,
         use_sub_time_tokenization: bool = False,
         include_motor_time_to_event: bool = False,
         motor_tte_vocab_size: int = 0,
         motor_num_time_pieces: int = 8,
+        motor_sampling_probability: float = 0.5,
         pretraining: bool = True,
         include_demographics: bool = False,
         add_linear_prob_token: bool = False,
@@ -47,13 +34,12 @@ class CehrGptDataCollator:
         self.vs_token_id = tokenizer.vs_token_id
         self.ve_token_id = tokenizer.ve_token_id
 
-        self.shuffle_records = shuffle_records
         self.include_values = include_values
         self.include_ttv_prediction = include_ttv_prediction
         self.use_sub_time_tokenization = use_sub_time_tokenization
         self.pretraining = pretraining
         self.include_demographics = include_demographics
-        self.add_linear_prob_token = add_linear_prob_token
+        self.motor_code_cache: Dict[str, List[str]] = dict()
 
         # MOTOR TTE configuration
         if include_motor_time_to_event:
@@ -66,8 +52,14 @@ class CehrGptDataCollator:
         self.include_motor_time_to_event = include_motor_time_to_event
         self.motor_tte_vocab_size = motor_tte_vocab_size
         self.motor_num_time_pieces = motor_num_time_pieces
-        self.motor_time_interval = TIME_TO_EVENT_MAX_TIME // motor_num_time_pieces
-
+        self.motor_time_bins = (
+            self.tokenizer.get_motor_time_bins(motor_num_time_pieces)
+            if self.include_motor_time_to_event
+            else []
+        )
+        # Convert the time bins to seconds
+        self.motor_time_bins = [time_bin * 86400 for time_bin in self.motor_time_bins]
+        LOG.info("self.motor_time_bins: %s", self.motor_time_bins)
        if self.use_sub_time_tokenization:
             token_to_time_token_mapping = tokenizer.token_to_time_token_mapping
             if not token_to_time_token_mapping:
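The fixed 3650-day horizon is gone: the bins now come from `CehrGptTokenizer.get_motor_time_bins`, whose construction lives outside this file. Purely as a hypothetical illustration (the quantile scheme below is an assumption, not the library's actual method), bin boundaries in days could be derived from observed event times and then scaled to seconds exactly as the `* 86400` line above does:

```python
import numpy as np

def make_motor_time_bins(event_times_days: np.ndarray, num_time_pieces: int) -> list:
    # Hypothetical sketch: equal-frequency boundaries over observed event times
    quantiles = np.linspace(0.0, 1.0, num_time_pieces + 1)
    bins = np.quantile(event_times_days, quantiles)
    bins[0] = 0.0  # first boundary sits at the prediction time
    return bins.tolist()

rng = np.random.default_rng(0)
bins_days = make_motor_time_bins(rng.exponential(365.0, size=10_000), 8)
bins_seconds = [b * 86400 for b in bins_days]  # mirrors the collator's day-to-second step
print(len(bins_seconds))  # 9 boundaries delimit 8 time pieces
```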
@@ -83,6 +75,18 @@ class CehrGptDataCollator:
                 list(token_to_time_token_mapping.values()), dtype=torch.int64
             )
 
+        self.cehrgpt_data_processor = CehrGptDataProcessor(
+            tokenizer=tokenizer,
+            max_length=self.max_length,
+            shuffle_records=shuffle_records,
+            include_ttv_prediction=include_ttv_prediction,
+            include_values=include_values,
+            include_motor_time_to_event=include_motor_time_to_event,
+            motor_sampling_probability=motor_sampling_probability,
+            pretraining=pretraining,
+            add_linear_prob_token=add_linear_prob_token,
+        )
+
     def _try_reverse_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
         if not self.pretraining:
             return torch.flip(tensor, dims=[-1])
@@ -95,30 +99,120 @@ class CehrGptDataCollator:
         else:
             return torch.tensor(features)
 
-    @staticmethod
-    def _convert_time_to_event(concept_ids):
-        def default_value(c):
-            try:
-                if is_att_token(c):
-                    time_to_visit = extract_time_interval_in_days(c)
-                    if (
-                        is_inpatient_att_token(c)
-                        and time_to_visit > INPATIENT_STAY_DURATION_LIMIT
-                    ):
-                        return -100
-                    return time_to_visit
-                elif is_inpatient_hour_token(c):
-                    return extract_time_interval_in_hours(c) / 24
-                return -100
-            except ValueError:
-                return -100
-
-        return [float(default_value(_)) for _ in concept_ids]
+    def create_time_to_event_tensors_ultra_optimized(
+        self, record: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """Ultra-optimized version using advanced vectorization techniques."""
+        motor_row_indices = record["motor_row_indices"]
+        motor_col_indices = record["motor_col_indices"]
+        motor_values = record["motor_values"]
+        motor_censor_times = record["motor_censor_times"]
+
+        if len(motor_row_indices) == 0:
+            # Handle empty case - use tuples for better performance
+            empty_shape = (
+                0,
+                self.motor_num_time_pieces,
+                self.tokenizer.motor_tte_vocab_size,
+            )
+            record["motor_tte_times"] = np.zeros(empty_shape, dtype=np.float32)
+            record["motor_tte_event_indicators"] = np.zeros(empty_shape, dtype=bool)
+            record["motor_tte_masks"] = np.zeros(empty_shape, dtype=bool)
+            return record
+
+        # Convert to numpy arrays once and get dimensions
+        motor_row_indices = np.asarray(motor_row_indices, dtype=np.int32)
+        motor_col_indices = np.asarray(motor_col_indices, dtype=np.int32)
+        motor_values = np.asarray(motor_values, dtype=np.float32)
+        motor_censor_times = np.asarray(motor_censor_times, dtype=np.float32)
+
+        n_tte_predictions = len(motor_censor_times)  # More direct than unique()
+        vocab_size = self.tokenizer.motor_tte_vocab_size
+        n_time_pieces = self.motor_num_time_pieces
+
+        # Create time_vectors more efficiently without broadcasting copy
+        time_vectors = np.tile(
+            motor_censor_times[:, np.newaxis], (1, vocab_size)
+        ).astype(np.float32)
+        event_indicators = np.zeros((n_tte_predictions, vocab_size), dtype=bool)
+
+        # Vectorized assignment (already optimal)
+        time_vectors[motor_row_indices, motor_col_indices] = motor_values
+        event_indicators[motor_row_indices, motor_col_indices] = True
+
+        # Early return if no predictions
+        if n_tte_predictions == 0:
+            empty_shape = (0, n_time_pieces, vocab_size)
+            record["motor_tte_times"] = np.zeros(empty_shape, dtype=np.float32)
+            record["motor_tte_event_indicators"] = np.zeros(empty_shape, dtype=bool)
+            record["motor_tte_masks"] = np.zeros(empty_shape, dtype=bool)
+            return record
+
+        # Cache motor_time_bins as numpy array to avoid repeated conversion
+        if not hasattr(self, "_motor_time_bins_array"):
+            self._motor_time_bins_array = np.asarray(
+                self.motor_time_bins, dtype=np.float32
+            )
+
+        motor_time_bins = self._motor_time_bins_array
+        start_times = motor_time_bins[:-1]
+        end_times = motor_time_bins[1:]
+        bin_widths = end_times - start_times  # Pre-compute bin widths
+
+        # ELIMINATED TRANSPOSE: Compute directly in target shape (n_pred, n_bins, vocab)
+        # Reshape for broadcasting in target order
+        time_vectors_3d = time_vectors[:, np.newaxis, :]  # (n_pred, 1, vocab)
+        event_indicators_3d = event_indicators[:, np.newaxis, :]  # (n_pred, 1, vocab)
+
+        # Broadcast time bins to match target shape
+        start_times_broadcast = start_times[np.newaxis, :, np.newaxis]  # (1, n_bins, 1)
+        bin_widths_broadcast = bin_widths[np.newaxis, :, np.newaxis]  # (1, n_bins, 1)
+
+        # Compute directly in target shape (n_pred, n_bins, vocab)
+        time_diff = time_vectors_3d - start_times_broadcast
+        time_in_bin = np.clip(time_diff, 0, bin_widths_broadcast)
+
+        # Optimized mask computation
+        mask = time_in_bin > 0
+
+        # More efficient log computation with better constant
+        log_constant = 1e-8  # Better numerical stability than 1e-10
+        time_in_bin_log = np.where(
+            mask, np.log2(np.maximum(time_in_bin, log_constant)), -np.inf
+        )
+
+        # Event indicator computation in target shape
+        end_times_broadcast = motor_time_bins[1:][np.newaxis, :, np.newaxis]
+        time_in_range = (time_vectors_3d >= start_times_broadcast) & (
+            time_vectors_3d < end_times_broadcast
+        )
+        event_in_bin = event_indicators_3d & time_in_range
+
+        # Combined mask computation
+        final_mask = mask | event_in_bin
+
+        # Direct assignment - NO TRANSPOSE NEEDED!
+        record["motor_tte_times"] = time_in_bin_log
+        record["motor_tte_event_indicators"] = event_in_bin
+        record["motor_tte_masks"] = final_mask
+
+        # Validation (keep as is - important for correctness)
+        assert (
+            sum(record["motor_tte_task_indicators"]) == n_tte_predictions
+        ), f'sum(record["motor_tte_task_indicators"]) == n_tte_predictions must be true'
+
+        # Clean up input data
+        del record["motor_row_indices"]
+        del record["motor_col_indices"]
+        del record["motor_values"]
+
+        return record
 
     def __call__(self, examples):
-        sample_packing = getattr(self, "sample_packing", False)
-        examples = [self.generate_start_end_index(_, sample_packing) for _ in examples]
-        examples = [self.random_sort(_) for _ in examples]
+
+        if not getattr(self, "sample_packing", False):
+            examples = [self.cehrgpt_data_processor.transform(_) for _ in examples]
+
         batch = {}
 
         # Assume that each example in the batch is a dictionary with 'input_ids' and 'attention_mask'
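To see what `create_time_to_event_tensors_ultra_optimized` computes, here is a self-contained toy version of its core broadcasting step, with made-up sizes (2 prediction points, 3 MOTOR codes, 3 time pieces). Cells default to the censor time, recorded events overwrite individual cells, and broadcasting against the bin edges then yields time-in-bin, per-bin event indicators, and the validity mask:

```python
import numpy as np

bins = np.array([0.0, 10.0, 20.0, 30.0], dtype=np.float32)   # 4 edges -> 3 pieces
censor_times = np.array([25.0, 12.0], dtype=np.float32)      # one per prediction point
time_vectors = np.tile(censor_times[:, None], (1, 3))        # default: censored
events = np.zeros((2, 3), dtype=bool)
time_vectors[0, 1] = 7.0                                     # code 1 occurred at t=7 for row 0
events[0, 1] = True

starts, widths = bins[:-1], np.diff(bins)
t3 = time_vectors[:, None, :]                                # (n_pred, 1, vocab)
time_in_bin = np.clip(t3 - starts[None, :, None], 0, widths[None, :, None])
in_range = (t3 >= starts[None, :, None]) & (t3 < bins[1:][None, :, None])
event_in_bin = events[:, None, :] & in_range                 # event falls in this piece
mask = (time_in_bin > 0) | event_in_bin                      # pieces contributing to the loss

assert time_in_bin.shape == (2, 3, 3)                        # (n_pred, n_bins, vocab)
```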
@@ -162,37 +256,31 @@ class CehrGptDataCollator:
                 f"batch['input_ids']: {batch['input_ids']} "
             )
 
-        if "epoch_times" in examples[0]:
-            batch_epoch_times = [
-                self._try_reverse_tensor(
-                    self._convert_to_tensor(example["epoch_times"])
-                )
-                for example in examples
-            ]
-            # Pad sequences to the max length in the batch
-            batch["epoch_times"] = self._try_reverse_tensor(
-                pad_sequence(
-                    batch_epoch_times,
-                    batch_first=True,
-                    padding_value=0,
-                ).to(torch.float32)
-            )
+        batch_ages = [
+            self._try_reverse_tensor(self._convert_to_tensor(example["ages"]))
+            for example in examples
+        ]
+        # Pad sequences to the max length in the batch
+        batch["ages"] = self._try_reverse_tensor(
+            pad_sequence(
+                batch_ages,
+                batch_first=True,
+                padding_value=0,
+            ).to(torch.int64)
+        )
 
-        if "position_ids" in examples[0]:
-            batch_position_ids = [
-                self._try_reverse_tensor(
-                    self._convert_to_tensor(example["position_ids"])
-                )
-                for example in examples
-            ]
-            # Pad sequences to the max length in the batch
-            batch["position_ids"] = self._try_reverse_tensor(
-                pad_sequence(
-                    batch_position_ids,
-                    batch_first=True,
-                    padding_value=0,
-                ).to(torch.int64)
-            )
+        batch_epoch_times = [
+            self._try_reverse_tensor(self._convert_to_tensor(example["epoch_times"]))
+            for example in examples
+        ]
+        # Pad sequences to the max length in the batch
+        batch["epoch_times"] = self._try_reverse_tensor(
+            pad_sequence(
+                batch_epoch_times,
+                batch_first=True,
+                padding_value=0,
+            ).to(torch.float32)
+        )
 
         if self.pretraining:
             batch["labels"] = torch.where(
@@ -233,54 +321,51 @@ class CehrGptDataCollator:
 
         if self.include_motor_time_to_event:
             examples_with_motor_tte = [
-                self.create_time_to_event_labels(_) for _ in examples
+                self.create_time_to_event_tensors_ultra_optimized(_) for _ in examples
             ]
-            batch_motor_time_to_event_vectors = [
+            # print(f"Creating MOTOR TTE tensors took {time.time() - start} seconds")
+            motor_tte_times = [
                 self._try_reverse_tensor(
-                    self._convert_to_tensor(example["time_to_event_vectors"])
+                    self._convert_to_tensor(example["motor_tte_times"])
                 )
                 for example in examples_with_motor_tte
             ]
-            batch_motor_event_indicators = [
+            motor_tte_event_indicators = [
                 self._try_reverse_tensor(
-                    self._convert_to_tensor(example["event_indicators"])
+                    self._convert_to_tensor(example["motor_tte_event_indicators"])
                 )
                 for example in examples_with_motor_tte
             ]
-            batch_motor_time_to_event_to_include = [
+            motor_tte_task_indicators = [
                 self._try_reverse_tensor(
-                    self._convert_to_tensor(example["time_to_event_to_include"])
+                    self._convert_to_tensor(example["motor_tte_task_indicators"])
                 )
                 for example in examples_with_motor_tte
             ]
-            batch_motor_time_indicators = [
+            motor_tte_masks = [
                 self._try_reverse_tensor(
-                    self._convert_to_tensor(example["time_indicators"])
+                    self._convert_to_tensor(example["motor_tte_masks"])
                 )
                 for example in examples_with_motor_tte
             ]
 
-            batch_motor_time_to_event_vectors = torch.concat(
-                batch_motor_time_to_event_vectors, dim=0
-            ).to(torch.float32)
+            motor_tte_times = torch.concat(motor_tte_times, dim=0).to(torch.float32)
 
             # If every example in the batch only contains one visit, there would be no labels generated for MOTOR TTE
             # we only create the labels when any example has more than one visit
-            if batch_motor_time_to_event_vectors.dim() <= 1:
+            if motor_tte_times.dim() <= 1:
                 LOG.warning(
                     "There are no MOTOR TTE labels generated for this batch "
                     "because every example in this batch only contains one visit."
                 )
             else:
                 batch_size = len(examples)
-                length, num_time_pieces, motor_tte_vocab_size = (
-                    batch_motor_time_to_event_vectors.shape
-                )
+                length, num_time_pieces, motor_tte_vocab_size = motor_tte_times.shape
                 padded_length = batch_size - length % batch_size
-                batch["motor_time_to_event_vectors"] = (
+                batch["motor_tte_times"] = (
                     torch.concat(
                         [
-                            batch_motor_time_to_event_vectors,
+                            motor_tte_times,
                             torch.full(
                                 (padded_length, num_time_pieces, motor_tte_vocab_size),
                                 0.0,
@@ -293,13 +378,12 @@ class CehrGptDataCollator:
                 )
 
                 # Motor event indicators that indicate there is an event occurred in this time interval
-                batch_motor_event_indicators = torch.concat(
-                    batch_motor_event_indicators, dim=0
-                ).to(torch.bool)
-                batch["motor_event_indicators"] = (
+                batch["motor_tte_event_indicators"] = (
                     torch.concat(
                         [
-                            batch_motor_event_indicators,
+                            torch.concat(motor_tte_event_indicators, dim=0).to(
+                                torch.bool
+                            ),
                             torch.full(
                                 (padded_length, num_time_pieces, motor_tte_vocab_size),
                                 False,
@@ -312,27 +396,17 @@ class CehrGptDataCollator:
                 )
 
                 # Input to indicate whether the visit should be included for TTE predictions
-                batch_motor_time_to_event_to_include = torch.concat(
-                    batch_motor_time_to_event_to_include, dim=0
+                batch["motor_tte_task_indicators"] = pad_sequence(
+                    motor_tte_task_indicators,
+                    batch_first=True,
+                    padding_value=False,
                 ).to(torch.bool)
-                batch["motor_time_to_event_to_include"] = (
-                    torch.concat(
-                        [
-                            batch_motor_time_to_event_to_include,
-                            torch.full((padded_length,), False),
-                        ],
-                        dim=0,
-                    ).to(torch.bool)
-                ).reshape((batch_size, -1))
 
                 # Motor time indicators that indicate whether there are neither clinical events nor censor events
-                batch_motor_time_indicators = torch.concat(
-                    batch_motor_time_indicators, dim=0
-                ).to(torch.bool)
-                batch["motor_time_indicators"] = (
+                batch["motor_tte_masks"] = (
                     torch.concat(
                         [
-                            batch_motor_time_indicators,
+                            torch.concat(motor_tte_masks, dim=0).to(torch.bool),
                             torch.full(
                                 (padded_length, num_time_pieces, motor_tte_vocab_size),
                                 False,
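The `torch.concat`/`torch.full` pattern above pads the flattened stack of TTE prediction points up to a multiple of the batch size so it can later be split evenly across examples (the reshape itself falls outside the hunks shown). A reduced sketch of the same pattern; note that when `length` is already a multiple of `batch_size`, `batch_size - length % batch_size` appends a full extra `batch_size` rows, matching the expression in the code above:

```python
import torch

batch_size, length, n_bins, vocab = 4, 6, 2, 3
stacked = torch.randn(length, n_bins, vocab)  # flattened prediction points

# Pad the leading dimension to the next multiple of batch_size, then regroup
padded_length = batch_size - length % batch_size
padded = torch.concat([stacked, torch.full((padded_length, n_bins, vocab), 0.0)], dim=0)
per_example = padded.reshape(batch_size, -1, n_bins, vocab)
print(per_example.shape)  # torch.Size([4, 2, 2, 3])
```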
@@ -438,572 +512,118 @@ class CehrGptDataCollator:
 
         return batch
 
-    def create_time_to_event_labels(self, record: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Generates time-to-event (TTE) labels and censoring indicators for each visit in a patient's timeline.
-
-        Processes the input sequence in reverse to compute the number of days from each visit (marked by [VE])
-        to the occurrence of future motor-related events.
-
-        Args:
-            record (Dict[str, Any]): A dictionary containing the encoded patient sequence with the key "input_ids".
-                This sequence includes [VS], [VE], time delta tokens, and motor TTE concept codes.
-
-        Returns:
-            Dict[str, Any]: The updated input record with added keys:
-                - "time_to_event_vectors": np.ndarray of shape [num_visits, motor_vocab_size], containing time-to-event values
-                - "event_indicators": np.ndarray of shape [num_visits, motor_vocab_size], where 0 = event occurred, 1 = censored
-        """
-        input_ids = record["input_ids"]
-        sample_packing = getattr(self, "sample_packing", False)
-
-        if isinstance(input_ids, torch.Tensor):
-            input_ids = input_ids.detach().tolist()
-
-        # This potentially contains packed samples, we need to handle that
-        packed_concept_ids = self.tokenizer.decode(input_ids, skip_special_tokens=False)
-        pad_indices = []
-        if sample_packing:
-            # We start from the first index
-            for i in range(len(packed_concept_ids)):
-                if packed_concept_ids[i] == self.tokenizer.pad_token:
-                    # If we encounter consecutive pads, we should break out of the loop
-                    if pad_indices and pad_indices[-1] == self.tokenizer.pad_token:
-                        break
-                    pad_indices.append(i)
-
-        # If we did not find a pad, that means the whole sequence belongs to one sample
-        if len(pad_indices) == 0:
-            pad_indices.append(len(packed_concept_ids))
-
-        timepiece_time_to_event_vectors = []
-        timepiece_event_indicators = []
-        timepiece_indicators = []
-        time_to_event_to_includes = []
-
-        for start_index, end_index in zip([0] + pad_indices[:-1], pad_indices):
-            concept_ids = packed_concept_ids[start_index:end_index]
-            if concept_ids[0] == self.tokenizer.pad_token:
-                concept_ids.pop(0)
-            time_to_event_vectors = []
-            global_event_indicators = []
-
-            # First collect TTE data in reverse chronological order
-            censor_times = []
-            time_to_event_data: List[Dict[str, int]] = []
-            time_to_event_dict: Dict[str, int] = {}
-            time_to_event_to_include: List[bool] = []
-            next_future_visit_concepts = set()
-            time_interval = 0
-
-            # Reverse walk through concept_ids to calculate TTE from each [VE] point
-            for concept_id in reversed(concept_ids):
-                if is_visit_end(concept_id):
-                    # Update TTE for existing concepts, or add new ones seen in this visit
-                    for existing_concept_id in list(time_to_event_dict.keys()):
-                        if existing_concept_id in next_future_visit_concepts:
-                            time_to_event_dict[existing_concept_id] = time_interval
-                        else:
-                            time_to_event_dict[existing_concept_id] += time_interval
-
-                    for next_concept_id in next_future_visit_concepts:
-                        if next_concept_id not in time_to_event_dict:
-                            time_to_event_dict[next_concept_id] = time_interval
-
-                    # If the next visit occurs on the same day as the previous one, we don't want to do TTE for the
-                    # previous visit
-                    time_to_event_to_include.append(time_interval > 0)
-                    time_to_event_data.append(copy.deepcopy(time_to_event_dict))
-                    # Record the censor time at the end of the visit
-                    if censor_times:
-                        censor_times.append(censor_times[-1] + time_interval)
-                    else:
-                        censor_times.append(time_interval)
-                    time_interval = 0
-                    next_future_visit_concepts.clear()
-
-                elif is_att_token(concept_id):
-                    time_interval += extract_time_interval_in_days(concept_id)
-
-                elif self.tokenizer.is_motor_time_to_event_code(concept_id):
-                    next_future_visit_concepts.add(concept_id)
-
-            if len(time_to_event_data) == 0:
-                LOG.info(
-                    "Vist end event is not detected for this sample, and is skipped for MOTOR tasks."
-                    "It's likely this sample contains a long admission. length: %s, concept_ids[-10:] %s",
-                    len(concept_ids),
-                    concept_ids[-10:],
-                )
-                continue
-
-            # Reverse back to chronological order for final labels
-            time_to_event_data.reverse()
-            censor_times.reverse()
-            time_to_event_to_include.reverse()
-
-            for censor_time, visit_tte_data in zip(censor_times, time_to_event_data):
-                time_to_event_vector = np.full(
-                    self.tokenizer.motor_tte_vocab_size,
-                    fill_value=censor_time,
-                    dtype=np.int32,
-                )
-                event_indicator = np.zeros(
-                    self.tokenizer.motor_tte_vocab_size,
-                    dtype=np.int32,
-                )
-                visit_token_ids = [
-                    self.tokenizer.get_motor_token_id(concept_id)
-                    for concept_id in visit_tte_data.keys()
-                ]
-                visit_tte_values = list(visit_tte_data.values())
-
-                time_to_event_vector[visit_token_ids] = visit_tte_values
-                event_indicator[visit_token_ids] = 1  # not censored (event occurred)
-
-                time_to_event_vectors.append(time_to_event_vector)
-                global_event_indicators.append(event_indicator)
-
-            time_to_event_vectors = np.asarray(time_to_event_vectors)
-            global_event_indicators = np.asarray(global_event_indicators).astype(bool)
-            n_visits = len(time_to_event_vectors)
-
-            timepiece_time_to_event_vector = np.full(
-                (
-                    self.motor_num_time_pieces,
-                    n_visits,
-                    self.tokenizer.motor_tte_vocab_size,
-                ),
-                fill_value=0,
-                dtype=np.int32,
-            )
-            timepiece_event_indicator = np.zeros(
-                (
-                    self.motor_num_time_pieces,
-                    n_visits,
-                    self.tokenizer.motor_tte_vocab_size,
-                ),
-                dtype=bool,
-            )
-            timepiece_indicator = np.zeros(
-                (
-                    self.motor_num_time_pieces,
-                    n_visits,
-                    self.tokenizer.motor_tte_vocab_size,
-                ),
-                dtype=bool,
-            )
-
-            # Putting the event time and censor time into the corresponding time bins
-            for bin_num in range(self.motor_num_time_pieces):
-                start = self.motor_time_interval * bin_num
-                end = self.motor_time_interval * (bin_num + 1)
-                time_in_bin = np.clip(time_to_event_vectors - start, 0, end - start)
-                timepiece_time_to_event_vector[bin_num] = time_in_bin
-                event_indicator = (
-                    global_event_indicators
-                    & (start <= time_to_event_vectors)
-                    & (time_to_event_vectors < end)
-                )
-                timepiece_event_indicator[bin_num] = event_indicator
-                timepiece_indicator[bin_num] = time_in_bin > 0 | event_indicator
-
-            timepiece_time_to_event_vectors.append(
-                timepiece_time_to_event_vector.swapaxes(0, 1)
-            )
-            timepiece_event_indicators.append(timepiece_event_indicator.swapaxes(0, 1))
-            timepiece_indicators.append(timepiece_indicator.swapaxes(0, 1))
-            time_to_event_to_includes.append(np.asarray(time_to_event_to_include))
-
-        record["time_to_event_vectors"] = np.concatenate(
-            timepiece_time_to_event_vectors, axis=0
-        )
-        record["event_indicators"] = np.concatenate(timepiece_event_indicators, axis=0)
-        record["time_indicators"] = np.concatenate(timepiece_indicators, axis=0)
-        record["time_to_event_to_include"] = np.concatenate(
-            time_to_event_to_includes, axis=0
-        )
-        return record
-
-    def random_sort(self, record: Dict[str, Any]) -> Dict[str, Any]:
-
-        if not self.shuffle_records:
-            return record
-
-        if "record_ranks" not in record:
-            return record
-
-        sorting_column = record["record_ranks"]
-        random_order = np.random.rand(len(sorting_column))
-
-        if self.include_values:
-            iterator = zip(
-                sorting_column,
-                random_order,
-                record["input_ids"],
-                record["value_indicators"],
-                record["values"],
-            )
-            sorted_list = sorted(iterator, key=lambda tup2: (tup2[0], tup2[1], tup2[2]))
-            _, _, sorted_input_ids, sorted_value_indicators, sorted_values = zip(
-                *list(sorted_list)
-            )
-            record["input_ids"] = self._convert_to_tensor(sorted_input_ids)
-            record["value_indicators"] = self._convert_to_tensor(
-                sorted_value_indicators
-            )
-            record["values"] = self._convert_to_tensor(sorted_values)
-        else:
-            iterator = zip(sorting_column, random_order, record["input_ids"])
-            sorted_list = sorted(iterator, key=lambda tup2: (tup2[0], tup2[1], tup2[2]))
-            _, _, sorted_input_ids = zip(*list(sorted_list))
-            record["input_ids"] = self._convert_to_tensor(sorted_input_ids)
-        return record
-
-    def generate_start_end_index(
-        self,
-        record: Dict[str, Any],
-        sample_packing: bool,
-        max_length_allowed: Optional[int] = None,
-    ) -> Dict[str, Any]:
-        """Adding the start and end indices to extract a portion of the patient sequence."""
-        # concept_ids will be used to for time to event predictions and identifying the visit starts
-        max_length_allowed = (
-            self.max_length if max_length_allowed is None else max_length_allowed
-        )
-        input_ids = record["input_ids"]
-        if isinstance(input_ids, torch.Tensor):
-            input_ids = input_ids.detach().tolist()
-        concept_ids = self.tokenizer.decode(input_ids, skip_special_tokens=False)
-        seq_length = len(record["input_ids"])
-
-        # Subtract one for the [END] token when sample_packing is not enabled
-        new_max_length = (
-            max_length_allowed - 1
-            if not sample_packing and self.pretraining
-            else max_length_allowed
-        )
-
-        if self.include_ttv_prediction:
-            record["time_to_visits"] = torch.concat(
-                [self._convert_to_tensor(self._convert_time_to_event(concept_ids))]
-            )
-
-        # If linear token exists, we will use it, otherwise we default to the OOV token
-        linear_token_id = (
-            self.tokenizer.linear_token_id
-            if self.tokenizer.linear_token_id
-            else self.tokenizer.oov_token_id
-        )
-        eos_token = (
-            linear_token_id
-            if self.add_linear_prob_token
-            else self.tokenizer.end_token_id
-        )
-
-        # Return the record directly if the actual sequence length is less than the max sequence
-        if seq_length <= new_max_length:
-            if not sample_packing and self.pretraining:
-                record["input_ids"] = torch.concat(
-                    [
-                        self._convert_to_tensor(record["input_ids"]),
-                        self._convert_to_tensor([eos_token]),
-                    ]
-                )
-                if "epoch_times" in record:
-                    record["epoch_times"] = torch.concat(
-                        [
-                            self._convert_to_tensor(record["epoch_times"]),
-                            self._convert_to_tensor([record["epoch_times"][-1]]),
-                        ]
-                    )
-                if self.include_values:
-                    record["value_indicators"] = torch.concat(
-                        [
-                            self._convert_to_tensor(record["value_indicators"]),
-                            self._convert_to_tensor([False]),
-                        ]
-                    ).to(torch.bool)
-                    record["values"] = torch.concat(
-                        [
-                            self._convert_to_tensor(record["values"]),
-                            self._convert_to_tensor(
-                                [self.tokenizer.pad_value_token_id]
-                            ),
-                        ]
-                    )
-                if self.include_ttv_prediction:
-                    record["time_to_visits"] = torch.concat(
-                        [
-                            record["time_to_visits"],
-                            self._convert_to_tensor([-100.0]),
-                        ]
-                    )
-            return record
-
-        if self.pretraining:
-            # There is a 50% chance we randomly slice out a portion of the patient history and update the demographic
-            # prompt depending on the new starting point
-            if random.random() < 0.5 and not sample_packing:
-                start_index, end_index, demographic_tokens = random_slice_gpt_sequence(
-                    concept_ids, new_max_length
-                )
-                if start_index != end_index:
-                    record["input_ids"] = self._convert_to_tensor(
-                        record["input_ids"][start_index : end_index + 1]
-                    )
-                    if "epoch_times" in record:
-                        record["epoch_times"] = self._convert_to_tensor(
-                            record["epoch_times"][start_index : end_index + 1]
-                        )
-                    if self.include_values:
-                        record["value_indicators"] = self._convert_to_tensor(
-                            record["value_indicators"][start_index : end_index + 1]
-                        ).to(torch.bool)
-                        record["values"] = self._convert_to_tensor(
-                            record["values"][start_index : end_index + 1]
-                        )
-                    if self.include_ttv_prediction:
-                        record["time_to_visits"] = self._convert_to_tensor(
-                            self._convert_time_to_event(
-                                concept_ids[start_index : end_index + 1]
-                            )
-                        )
-                    return record
-
-            # The default employs a right truncation strategy, where the demographic prompt is reserved
-            end_index = new_max_length
-            for i in reversed(list(range(0, end_index))):
-                current_token = record["input_ids"][i]
-                if current_token == self.ve_token_id:
-                    # Plus one because slicing is right exclusive
-                    end_index = i + 1
-                    break
-
-            record["input_ids"] = record["input_ids"][0:end_index]
-
-            # We want to make sure we take the subset of attention_mask in sample packing if this field is available
-            if sample_packing and "attention_mask" in record:
-                record["attention_mask"] = record["attention_mask"][0:end_index]
-
-            if sample_packing and "position_ids" in record:
-                record["position_ids"] = record["position_ids"][0:end_index]
-
-            if "epoch_times" in record:
-                record["epoch_times"] = self._convert_to_tensor(
-                    record["epoch_times"][0:end_index]
-                )
-
-            if self.include_values:
-                record["value_indicators"] = self._convert_to_tensor(
-                    record["value_indicators"][0:end_index]
-                ).to(torch.bool)
-                record["values"] = self._convert_to_tensor(
-                    record["values"][0:end_index]
-                )
-            if self.include_ttv_prediction:
-                record["time_to_visits"] = self._convert_to_tensor(
-                    self._convert_time_to_event(concept_ids[0:end_index])
-                )
-            return record
-        else:
-            if self.include_demographics and not sample_packing:
-                # We employ a left truncation strategy, where the most recent patient history is reserved for fine-tuning
-                demographic_prompts_at_visits = collect_demographic_prompts_at_visits(
-                    concept_ids
-                )
-                for token_index, demographic_prompt in demographic_prompts_at_visits:
-                    if (
-                        seq_length - token_index
-                        <= new_max_length - DEMOGRAPHIC_PROMPT_SIZE
-                    ):
-                        demographic_tokens = self.tokenizer.encode(demographic_prompt)
-                        record["input_ids"] = torch.concat(
-                            [
-                                self._convert_to_tensor(demographic_tokens),
-                                self._convert_to_tensor(
-                                    record["input_ids"][token_index:seq_length]
-                                ),
-                            ]
-                        )
-                        if "epoch_times" in record:
-                            record["epoch_times"] = torch.concat(
-                                [
-                                    torch.zeros(
-                                        [record["epoch_times"][0]], dtype=torch.float32
-                                    ),
-                                    self._convert_to_tensor(
-                                        record["epoch_times"][token_index:seq_length]
-                                    ),
-                                ]
-                            )
-                        if self.include_values:
-                            record["value_indicators"] = torch.concat(
-                                [
-                                    torch.zeros(
-                                        [DEMOGRAPHIC_PROMPT_SIZE], dtype=torch.int32
-                                    ).to(torch.bool),
-                                    self._convert_to_tensor(
-                                        record["value_indicators"][
-                                            token_index:seq_length
-                                        ]
-                                    ),
-                                ]
-                            )
-                            record["values"] = torch.concat(
-                                [
-                                    torch.zeros(
-                                        [DEMOGRAPHIC_PROMPT_SIZE], dtype=torch.int32
-                                    )
-                                    .to(torch.int32)
-                                    .fill_(self.tokenizer.pad_value_token_id),
-                                    self._convert_to_tensor(
-                                        record["values"][token_index:seq_length]
-                                    ),
-                                ]
-                            )
-                        if self.include_ttv_prediction:
-                            record["time_to_visits"] = torch.concat(
-                                [
-                                    torch.zeros(
-                                        [DEMOGRAPHIC_PROMPT_SIZE], dtype=torch.int32
-                                    )
-                                    .to(torch.float32)
-                                    .fill_(-100.0),
-                                    record["time_to_visits"][token_index:seq_length],
-                                ]
-                            )
-                        break
-            else:
-                start_index = max(seq_length - new_max_length, 0)
-                end_index = seq_length
-                for i in range(start_index, end_index):
-                    current_token = record["input_ids"][i]
-                    if current_token == self.vs_token_id:
-                        record["input_ids"] = record["input_ids"][i:end_index]
-                        if sample_packing and "attention_mask" in record:
-                            record["attention_mask"] = record["attention_mask"][
-                                i:end_index
-                            ]
-                        if sample_packing and "position_ids" in record:
-                            record["position_ids"] = record["position_ids"][i:end_index]
-
-                        if "epoch_times" in record:
-                            record["epoch_times"] = self._convert_to_tensor(
-                                record["epoch_times"][i:end_index]
-                            )
-                        if self.include_values:
-                            record["value_indicators"] = record["value_indicators"][
-                                i:end_index
-                            ]
-                            record["values"] = record["values"][i:end_index]
-                        if self.include_ttv_prediction:
-                            record["time_to_visits"] = record["time_to_visits"][
-                                i:end_index
-                            ]
-                        break
-
-                # This could happen when the last visit contains more than new_max_length number of tokens
-                # We simply take the last new_max_length number of tokens from the patient sequence
-                if len(record["input_ids"]) > new_max_length:
-                    record["input_ids"] = record["input_ids"][-new_max_length:]
-                    if sample_packing and "attention_mask" in record:
-                        record["attention_mask"] = record["attention_mask"][
-                            -new_max_length:
-                        ]
-                    if sample_packing and "position_ids" in record:
-                        record["position_ids"] = record["position_ids"][-new_max_length:]
-                    if "epoch_times" in record:
-                        record["epoch_times"] = self._convert_to_tensor(
-                            record["epoch_times"][-new_max_length:]
-                        )
-                    if self.include_values:
-                        record["value_indicators"] = record["value_indicators"][
-                            -new_max_length:
-                        ]
-                        record["values"] = record["values"][-new_max_length:]
-                    if self.include_ttv_prediction:
-                        record["time_to_visits"] = record["time_to_visits"][
-                            -new_max_length:
-                        ]
-
-        return record
-
 
 class SamplePackingCehrGptDataCollator(CehrGptDataCollator):
     def __init__(self, max_tokens, max_position_embeddings, *args, **kwargs):
         self.max_tokens_per_batch = max_tokens
         self.max_position_embeddings = max_position_embeddings
         self.sample_packing = True
-        self.add_end_token_in_sample_packing = kwargs.pop(
-            "add_end_token_in_sample_packing", False
-        )
         super(SamplePackingCehrGptDataCollator, self).__init__(*args, **kwargs)
+        self.cehrgpt_data_processor.max_length = self.max_position_embeddings
 
     def __call__(self, examples):
         current_input_ids = []
         current_attention_mask = []
-        current_position_ids = []
+        current_ages = []
+        current_epoch_times = []
         current_value_indicators = []
         current_values = []
 
+        # MOTOR inputs
+        current_motor_censor_times = []
+        current_motor_row_indices = []
+        current_motor_col_indices = []
+        current_motor_values = []
+        current_motor_tte_task_indicators = []
+
         # Demographics
         current_person_ids = []
         current_index_dates = []
 
         # Binary classification inputs
-        current_ages = []
+        current_prediction_ages = []
         current_labels = []
 
         for idx, example in enumerate(examples):
-
-            # We only add an end token if the patient sequence could fit in the entire context window
-            add_end_token = (
-                len(example["input_ids"]) <= self.max_position_embeddings
-                and self.add_end_token_in_sample_packing
-            )
-            # If the sample length exceeds the model's capacity, truncate this example
-            if len(example["input_ids"]) > self.max_position_embeddings:
-                example = self.generate_start_end_index(
-                    example, False, self.max_position_embeddings
-                )
-
-            add_eos_token = add_end_token | self.add_linear_prob_token
-            additional_tokens = []
-            if add_end_token:
-                additional_tokens.append(self.tokenizer.end_token_id)
-            elif self.add_linear_prob_token:
-                # Backward compatible
-                linear_prob_token_id = (
-                    self.tokenizer.linear_token_id
-                    if self.tokenizer.linear_token_id is not None
-                    else self.tokenizer.oov_token_id
-                )
-                additional_tokens.append(linear_prob_token_id)
-            additional_tokens.append(self.tokenizer.pad_token_id)
+            example = self.cehrgpt_data_processor.transform(example)
             input_ids = example["input_ids"]
             # We add [END] [PAD], we want to attend to [END], adding [END] is important for sequence generation.
             # If the sequence length of the sequence is less than the context window, we add both [END][PAD], otherwise
             # we only add [PAD] token to the end of the sequence because it's not finished
-            current_input_ids.extend(list(input_ids) + additional_tokens)
-            current_attention_mask.extend(
-                np.ones_like(input_ids).tolist() + ([1, 0] if add_eos_token else [0])
+            current_input_ids.extend(list(input_ids) + [self.tokenizer.pad_token_id])
+            current_attention_mask.extend(np.ones_like(input_ids).tolist() + [0])
+
+            ages = (
+                example["ages"].tolist()
+                if isinstance(example["ages"], torch.Tensor)
+                else list(example["ages"])
             )
-            num_tokens_to_pad = 1 + int(add_eos_token)
-            current_position_ids.extend(
-                np.clip(
-                    list(range(len(input_ids) + num_tokens_to_pad)),
-                    0,
-                    self.max_position_embeddings - 1,
-                )
+            current_ages.extend(ages + [max(ages)])
+
+            epoch_times = (
+                example["epoch_times"].tolist()
+                if isinstance(example["epoch_times"], torch.Tensor)
+                else list(example["epoch_times"])
             )
+            current_epoch_times.extend(epoch_times + [max(epoch_times)])
+
             if self.include_values:
                 current_value_indicators.extend(
-                    list(example["value_indicators"]) + [False] * num_tokens_to_pad
+                    (
+                        example["value_indicators"].tolist()
+                        if isinstance(example["value_indicators"], torch.Tensor)
+                        else list(example["value_indicators"])
+                    )
+                    + [False]
                 )
                 current_values.extend(
-                    list(example["values"])
-                    + [self.tokenizer.pad_value_token_id] * num_tokens_to_pad
+                    (
+                        example["values"].tolist()
+                        if isinstance(example["values"], torch.Tensor)
+                        else list(example["values"])
+                    )
+                    + [self.tokenizer.pad_value_token_id]
+                )
+
+            if self.include_motor_time_to_event:
+                current_max_motor_row_index = len(np.unique(current_motor_row_indices))
+                motor_row_indices = (
+                    example["motor_row_indices"].tolist()
+                    if isinstance(example["motor_row_indices"], torch.Tensor)
+                    else list(example["motor_row_indices"])
+                )
+                current_motor_row_indices.extend(
+                    list(
+                        map(
+                            lambda offset: offset + current_max_motor_row_index,
+                            motor_row_indices,
+                        )
+                    )
+                )
+                current_motor_col_indices.extend(
+                    example["motor_col_indices"].tolist()
+                    if isinstance(example["motor_col_indices"], torch.Tensor)
+                    else list(example["motor_col_indices"])
+                )
+                current_motor_values.extend(
+                    example["motor_values"].tolist()
+                    if isinstance(example["motor_values"], torch.Tensor)
+                    else list(example["motor_values"])
+                )
+                current_motor_censor_times.extend(
+                    example["motor_censor_times"].tolist()
+                    if isinstance(example["motor_censor_times"], torch.Tensor)
+                    else list(example["motor_censor_times"])
+                )
+                current_motor_tte_task_indicators.extend(
+                    (
+                        example["motor_tte_task_indicators"].tolist()
+                        if isinstance(
+                            example["motor_tte_task_indicators"], torch.Tensor
+                        )
+                        else list(example["motor_tte_task_indicators"])
+                    )
+                    + [False]
                 )
 
             if "person_id" in example:
@@ -1013,7 +633,7 @@ class SamplePackingCehrGptDataCollator(CehrGptDataCollator):
                 current_index_dates.append(example["index_date"])
 
             if "age_at_index" in example:
-                current_ages.append(example["age_at_index"])
+                current_prediction_ages.append(example["age_at_index"])
 
             if "classifier_label" in example:
                 current_labels.append(example["classifier_label"])
@@ -1025,20 +645,33 @@ class SamplePackingCehrGptDataCollator(CehrGptDataCollator):
         packed_example = {
             "input_ids": current_input_ids,
             "attention_mask": current_attention_mask,
-            "position_ids": current_position_ids,
+            "ages": current_ages,
+            "epoch_times": current_epoch_times,
         }
+
         if self.include_values:
-            packed_example.update({"value_indicators": current_value_indicators})
-            packed_example.update({"values": current_values})
+            packed_example.update(
+                {"value_indicators": current_value_indicators, "values": current_values}
+            )
+        if self.include_motor_time_to_event:
+            packed_example.update(
+                {
+                    "motor_censor_times": current_motor_censor_times,
+                    "motor_row_indices": current_motor_row_indices,
+                    "motor_col_indices": current_motor_col_indices,
+                    "motor_values": current_motor_values,
+                    "motor_tte_task_indicators": current_motor_tte_task_indicators,
+                }
+            )
 
         if current_labels:
             packed_example.update(
                 {
                     "person_id": current_person_ids,
                     "index_date": current_index_dates,
-                    "age_at_index": current_ages,
+                    "age_at_index": current_prediction_ages,
                     "classifier_label": current_labels,
                 }
             )
-
+        # print(f"Packing examples took {time.time() - start} seconds")
         return super().__call__([packed_example])
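For reference, a minimal sketch of the token layout that `SamplePackingCehrGptDataCollator.__call__` builds before handing one packed example to the base collator: every transformed example is followed by a single [PAD] separator whose attention is zeroed (`pad_token=0` below is an arbitrary stand-in, not the real tokenizer id):

```python
def pack(sequences, pad_token=0):
    # Concatenate sequences with a PAD separator and mask the separators out
    input_ids, attention_mask = [], []
    for seq in sequences:
        input_ids.extend(seq + [pad_token])
        attention_mask.extend([1] * len(seq) + [0])
    return input_ids, attention_mask

ids, mask = pack([[11, 12, 13], [21, 22]])
print(ids)   # [11, 12, 13, 0, 21, 22, 0]
print(mask)  # [1, 1, 1, 0, 1, 1, 0]
```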