cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (44)
  1. cehrgpt/analysis/irregularity.py +36 -0
  2. cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
  3. cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
  4. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
  5. cehrgpt/data/sample_packing_sampler.py +181 -0
  6. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  7. cehrgpt/generation/omop_converter_batch.py +32 -2
  8. cehrgpt/gpt_utils.py +20 -2
  9. cehrgpt/models/config.py +35 -0
  10. cehrgpt/models/hf_cehrgpt.py +470 -106
  11. cehrgpt/models/hf_modeling_outputs.py +1 -0
  12. cehrgpt/models/special_tokens.py +1 -0
  13. cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
  14. cehrgpt/runners/data_utils.py +358 -0
  15. cehrgpt/runners/gpt_runner_util.py +0 -10
  16. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
  17. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
  18. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
  19. cehrgpt/runners/hyperparameter_search_util.py +10 -8
  20. cehrgpt/runners/sample_packing_trainer.py +185 -0
  21. cehrgpt/simulations/generate_plots.py +95 -0
  22. cehrgpt/simulations/run_simulation.sh +24 -0
  23. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  24. cehrgpt/simulations/time_token_simulation.py +177 -0
  25. cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  26. cehrgpt/time_to_event/time_to_event_model.py +2 -13
  27. cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  28. cehrgpt/tools/linear_prob/__init__.py +0 -0
  29. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
  30. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  31. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
  32. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
  33. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
  34. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  35. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  36. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  37. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  38. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  39. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  40. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  41. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  42. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  43. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
  44. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
@@ -1,21 +1,28 @@
+ import copy
  import random
- from typing import Any, Dict
+ from typing import Any, Dict, List, Optional

  import numpy as np
  import torch
  from torch.nn.utils.rnn import pad_sequence
+ from transformers.utils import logging

  from cehrgpt.gpt_utils import (
  DEMOGRAPHIC_PROMPT_SIZE,
  collect_demographic_prompts_at_visits,
  extract_time_interval_in_days,
+ extract_time_interval_in_hours,
  is_att_token,
  is_inpatient_att_token,
+ is_inpatient_hour_token,
+ is_visit_end,
  random_slice_gpt_sequence,
  )
  from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer

+ TIME_TO_EVENT_MAX_TIME = 3650
  INPATIENT_STAY_DURATION_LIMIT = 30
+ LOG = logging.get_logger("transformers")


  class CehrGptDataCollator:
@@ -27,20 +34,18 @@ class CehrGptDataCollator:
  include_values: bool = False,
  include_ttv_prediction: bool = False,
  use_sub_time_tokenization: bool = False,
+ include_motor_time_to_event: bool = False,
+ motor_tte_vocab_size: int = 0,
+ motor_num_time_pieces: int = 8,
  pretraining: bool = True,
  include_demographics: bool = False,
+ add_linear_prob_token: bool = False,
  ):
  self.tokenizer = tokenizer
  self.max_length = max_length
- # Pre-compute these so we can use them later on
- # We used VS for the historical data, currently, we use the new [VS] for the newer data
- # so we need to check both cases.
- self.vs_token_id = tokenizer._convert_token_to_id("VS")
- if self.vs_token_id == tokenizer._oov_token_id:
- self.vs_token_id = tokenizer._convert_token_to_id("[VS]")
- self.ve_token_id = tokenizer._convert_token_to_id("VE")
- if self.ve_token_id == tokenizer._oov_token_id:
- self.ve_token_id = tokenizer._convert_token_to_id("[VE]")
+
+ self.vs_token_id = tokenizer.vs_token_id
+ self.ve_token_id = tokenizer.ve_token_id

  self.shuffle_records = shuffle_records
  self.include_values = include_values
@@ -48,6 +53,20 @@ class CehrGptDataCollator:
  self.use_sub_time_tokenization = use_sub_time_tokenization
  self.pretraining = pretraining
  self.include_demographics = include_demographics
+ self.add_linear_prob_token = add_linear_prob_token
+
+ # MOTOR TTE configuration
+ if include_motor_time_to_event:
+ assert motor_tte_vocab_size > 0, (
+ f"motor_tte_vocab_size must be greater than 0 "
+ f"when include_motor_time_to_event is set to True. "
+ f"But motor_tte_vocab_size: {motor_tte_vocab_size} is provided"
+ )
+
+ self.include_motor_time_to_event = include_motor_time_to_event
+ self.motor_tte_vocab_size = motor_tte_vocab_size
+ self.motor_num_time_pieces = motor_num_time_pieces
+ self.motor_time_interval = TIME_TO_EVENT_MAX_TIME // motor_num_time_pieces

  if self.use_sub_time_tokenization:
  token_to_time_token_mapping = tokenizer.token_to_time_token_mapping
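The piecewise time-to-event setup above carves a fixed horizon into equal bins: with the constants in this hunk, TIME_TO_EVENT_MAX_TIME = 3650 days and motor_num_time_pieces = 8 give a motor_time_interval of 456 days per bin. A minimal sketch of that arithmetic (the helper name bin_index is illustrative, not part of the package):

    TIME_TO_EVENT_MAX_TIME = 3650  # days, roughly a 10-year horizon
    motor_num_time_pieces = 8
    motor_time_interval = TIME_TO_EVENT_MAX_TIME // motor_num_time_pieces  # 456 days per bin

    def bin_index(days_to_event: int) -> int:
        # An event 1,000 days out lands in bin 2 (912 <= 1000 < 1368).
        return min(days_to_event // motor_time_interval, motor_num_time_pieces - 1)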
@@ -88,6 +107,8 @@
  ):
  return -100
  return time_to_visit
+ elif is_inpatient_hour_token(c):
+ return extract_time_interval_in_hours(c) / 24
  return -100
  except ValueError:
  return -100
@@ -95,8 +116,8 @@
  return [float(default_value(_)) for _ in concept_ids]

  def __call__(self, examples):
-
- examples = [self.generate_start_end_index(_) for _ in examples]
+ sample_packing = getattr(self, "sample_packing", False)
+ examples = [self.generate_start_end_index(_, sample_packing) for _ in examples]
  examples = [self.random_sort(_) for _ in examples]
  batch = {}

@@ -105,9 +126,12 @@
  self._try_reverse_tensor(self._convert_to_tensor(example["input_ids"]))
  for example in examples
  ]
+
  batch_attention_mask = [
  self._try_reverse_tensor(
- torch.ones_like(
+ self._convert_to_tensor(example["attention_mask"]).to(torch.float)
+ if "attention_mask" in example
+ else torch.ones_like(
  self._convert_to_tensor(example["input_ids"]), dtype=torch.float
  )
  )
@@ -128,16 +152,40 @@
  )
  assert batch["input_ids"].shape[1] <= self.max_length
  assert batch["attention_mask"].shape[1] <= self.max_length
+ assert batch["attention_mask"].shape[1] == batch["input_ids"].shape[1], (
+ f'batch["attention_mask"].shape[1]: {batch["attention_mask"].shape[1]}, '
+ f'batch["input_ids"].shape[1]: {batch["input_ids"].shape[1]}'
+ )
+ assert batch["input_ids"].max() < self.tokenizer.vocab_size, (
+ f"batch['input_ids'].max(): {batch['input_ids'].max()} must be smaller than "
+ f"self.tokenizer.vocab_size: {self.tokenizer.vocab_size}. "
+ f"batch['input_ids']: {batch['input_ids']} "
+ )

- if self.pretraining:
- batch["labels"] = self._try_reverse_tensor(
+ if "position_ids" in examples[0]:
+ batch_position_ids = [
+ self._try_reverse_tensor(
+ self._convert_to_tensor(example["position_ids"])
+ )
+ for example in examples
+ ]
+ # Pad sequences to the max length in the batch
+ batch["position_ids"] = self._try_reverse_tensor(
  pad_sequence(
- batch_input_ids,
+ batch_position_ids,
  batch_first=True,
- padding_value=-100,
+ padding_value=0,
  ).to(torch.int64)
  )

+ if self.pretraining:
+ batch["labels"] = torch.where(
+ (batch["input_ids"] != self.tokenizer.pad_token_id)
+ & batch["attention_mask"].to(torch.bool),
+ batch["input_ids"],
+ -100,
+ )
+
  if self.use_sub_time_tokenization:
  time_token_indicators = torch.isin(batch["input_ids"], self.time_tokens)
  masked_tokens = batch["input_ids"].clone()
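The new labels construction masks every padded position so the language-modeling loss ignores it. A small standalone illustration of the same torch.where pattern, with made-up token ids and pad_token_id = 0:

    import torch

    pad_token_id = 0
    input_ids = torch.tensor([[5, 7, 9, 0, 0]])
    attention_mask = torch.tensor([[1, 1, 1, 0, 0]])
    labels = torch.where(
        (input_ids != pad_token_id) & attention_mask.to(torch.bool),
        input_ids,
        -100,  # ignored by cross-entropy losses configured with ignore_index=-100
    )
    # labels -> tensor([[5, 7, 9, -100, -100]])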
@@ -167,10 +215,130 @@
  )
  )

+ if self.include_motor_time_to_event:
+ examples_with_motor_tte = [
+ self.create_time_to_event_labels(_) for _ in examples
+ ]
+ batch_motor_time_to_event_vectors = [
+ self._try_reverse_tensor(
+ self._convert_to_tensor(example["time_to_event_vectors"])
+ )
+ for example in examples_with_motor_tte
+ ]
+ batch_motor_event_indicators = [
+ self._try_reverse_tensor(
+ self._convert_to_tensor(example["event_indicators"])
+ )
+ for example in examples_with_motor_tte
+ ]
+ batch_motor_time_to_event_to_include = [
+ self._try_reverse_tensor(
+ self._convert_to_tensor(example["time_to_event_to_include"])
+ )
+ for example in examples_with_motor_tte
+ ]
+ batch_motor_time_indicators = [
+ self._try_reverse_tensor(
+ self._convert_to_tensor(example["time_indicators"])
+ )
+ for example in examples_with_motor_tte
+ ]
+
+ batch_motor_time_to_event_vectors = torch.concat(
+ batch_motor_time_to_event_vectors, dim=0
+ ).to(torch.float32)
+
+ # If every example in the batch only contains one visit, there would be no labels generated for MOTOR TTE
+ # we only create the labels when any example has more than one visit
+ if batch_motor_time_to_event_vectors.dim() <= 1:
+ LOG.warning(
+ "There are no MOTOR TTE labels generated for this batch "
+ "because every example in this batch only contains one visit."
+ )
+ else:
+ batch_size = len(examples)
+ length, num_time_pieces, motor_tte_vocab_size = (
+ batch_motor_time_to_event_vectors.shape
+ )
+ padded_length = batch_size - length % batch_size
+ batch["motor_time_to_event_vectors"] = (
+ torch.concat(
+ [
+ batch_motor_time_to_event_vectors,
+ torch.full(
+ (padded_length, num_time_pieces, motor_tte_vocab_size),
+ 0.0,
+ ),
+ ],
+ dim=0,
+ )
+ .reshape((batch_size, -1, num_time_pieces, motor_tte_vocab_size))
+ .to(torch.float32)
+ )
+
+ # Motor event indicators that indicate there is an event occurred in this time interval
+ batch_motor_event_indicators = torch.concat(
+ batch_motor_event_indicators, dim=0
+ ).to(torch.bool)
+ batch["motor_event_indicators"] = (
+ torch.concat(
+ [
+ batch_motor_event_indicators,
+ torch.full(
+ (padded_length, num_time_pieces, motor_tte_vocab_size),
+ False,
+ ),
+ ],
+ dim=0,
+ )
+ .reshape((batch_size, -1, num_time_pieces, motor_tte_vocab_size))
+ .to(torch.bool)
+ )
+
+ # Input to indicate whether the visit should be included for TTE predictions
+ batch_motor_time_to_event_to_include = torch.concat(
+ batch_motor_time_to_event_to_include, dim=0
+ ).to(torch.bool)
+ batch["motor_time_to_event_to_include"] = (
+ torch.concat(
+ [
+ batch_motor_time_to_event_to_include,
+ torch.full((padded_length,), False),
+ ],
+ dim=0,
+ ).to(torch.bool)
+ ).reshape((batch_size, -1))
+
+ # Motor time indicators that indicate whether there are neither clinical events nor censor events
+ batch_motor_time_indicators = torch.concat(
+ batch_motor_time_indicators, dim=0
+ ).to(torch.bool)
+ batch["motor_time_indicators"] = (
+ torch.concat(
+ [
+ batch_motor_time_indicators,
+ torch.full(
+ (padded_length, num_time_pieces, motor_tte_vocab_size),
+ False,
+ ),
+ ],
+ dim=0,
+ )
+ .reshape((batch_size, -1, num_time_pieces, motor_tte_vocab_size))
+ .to(torch.bool)
+ )
+
+ batch["motor_end_index"] = torch.concat(
+ [
+ torch.full((length, 1), 1, dtype=torch.int32),
+ torch.full((padded_length, 1), 0, dtype=torch.int32),
+ ]
+ ).reshape((batch_size, -1))
+
  if self.include_values:
  batch_value_indicators = [
  self._try_reverse_tensor(
- self._convert_to_tensor(example["value_indicators"])
+ self._convert_to_tensor(example["value_indicators"]).to(torch.bool)
  )
  for example in examples
  ]
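The padding logic above exists because the per-visit label tensors from all examples are concatenated along one axis; padding that axis up to a multiple of the batch size lets it be reshaped into (batch_size, visits_per_example, num_time_pieces, vocab_size), with motor_end_index marking which rows are real. A schematic of the shape bookkeeping, using made-up sizes:

    import torch

    batch_size, num_time_pieces, vocab_size = 4, 8, 3
    length = 10  # total visits collected across the batch
    padded_length = batch_size - length % batch_size  # 2, so 10 + 2 = 12 rows
    vectors = torch.zeros((length, num_time_pieces, vocab_size))
    padded = torch.concat(
        [vectors, torch.full((padded_length, num_time_pieces, vocab_size), 0.0)], dim=0
    )
    per_example = padded.reshape((batch_size, -1, num_time_pieces, vocab_size))
    # per_example.shape -> torch.Size([4, 3, 8, 3]): three visit slots per example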
@@ -178,7 +346,6 @@
  self._try_reverse_tensor(self._convert_to_tensor(example["values"]))
  for example in examples
  ]
-
  batch["value_indicators"] = self._try_reverse_tensor(
  pad_sequence(
  batch_value_indicators, batch_first=True, padding_value=False
@@ -200,44 +367,248 @@
  batch["value_indicators"], batch["values"].clone(), -100
  )

+ bz = len(examples)
  if "person_id" in examples[0]:
- batch["person_id"] = torch.cat(
- [
- self._convert_to_tensor(example["person_id"]).reshape(-1, 1)
- for example in examples
- ],
- dim=0,
- ).to(torch.int32)
+ batch["person_id"] = (
+ torch.cat(
+ [
+ self._convert_to_tensor(example["person_id"]).reshape(-1, 1)
+ for example in examples
+ ],
+ dim=0,
+ )
+ .to(torch.int32)
+ .reshape(bz, -1)
+ )

  if "index_date" in examples[0]:
  batch["index_date"] = torch.cat(
  [
- self._convert_to_tensor(example["index_date"]).reshape(-1, 1)
+ torch.tensor(example["index_date"], dtype=torch.float64).reshape(
+ -1, 1
+ )
  for example in examples
  ],
  dim=0,
- ).to(torch.float32)
+ ).reshape(bz, -1)

  if "age_at_index" in examples[0]:
- batch["age_at_index"] = torch.cat(
- [
- self._convert_to_tensor(example["age_at_index"]).reshape(-1, 1)
- for example in examples
- ],
- dim=0,
- ).to(torch.float32)
+ batch["age_at_index"] = (
+ torch.cat(
+ [
+ self._convert_to_tensor(example["age_at_index"]).reshape(-1, 1)
+ for example in examples
+ ],
+ dim=0,
+ )
+ .to(torch.float32)
+ .reshape(bz, -1)
+ )

  if "classifier_label" in examples[0]:
- batch["classifier_label"] = torch.cat(
- [
- self._convert_to_tensor(example["classifier_label"]).reshape(-1, 1)
- for example in examples
- ],
- dim=0,
- ).to(torch.float32)
+ batch["classifier_label"] = (
+ torch.cat(
+ [
+ self._convert_to_tensor(example["classifier_label"]).reshape(
+ -1, 1
+ )
+ for example in examples
+ ],
+ dim=0,
+ )
+ .to(torch.float32)
+ .reshape(bz, -1)
+ )

  return batch

+ def create_time_to_event_labels(self, record: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Generates time-to-event (TTE) labels and censoring indicators for each visit in a patient's timeline.
+
+ Processes the input sequence in reverse to compute the number of days from each visit (marked by [VE])
+ to the occurrence of future motor-related events.
+
+ Args:
+ record (Dict[str, Any]): A dictionary containing the encoded patient sequence with the key "input_ids".
+ This sequence includes [VS], [VE], time delta tokens, and motor TTE concept codes.
+
+ Returns:
+ Dict[str, Any]: The updated input record with added keys:
+ - "time_to_event_vectors": np.ndarray of shape [num_visits, motor_vocab_size], containing time-to-event values
+ - "event_indicators": np.ndarray of shape [num_visits, motor_vocab_size], where 0 = event occurred, 1 = censored
+ """
+ input_ids = record["input_ids"]
+ sample_packing = getattr(self, "sample_packing", False)
+
+ if isinstance(input_ids, torch.Tensor):
+ input_ids = input_ids.detach().tolist()
+
+ # This potentially contains packed samples, we need to handle that
+ packed_concept_ids = self.tokenizer.decode(input_ids, skip_special_tokens=False)
+ pad_indices = []
+ if sample_packing:
+ # We start from the first index
+ for i in range(len(packed_concept_ids)):
+ if packed_concept_ids[i] == self.tokenizer.pad_token:
+ # If we encounter consecutive pads, we should break out of the loop
+ if pad_indices and pad_indices[-1] == self.tokenizer.pad_token:
+ break
+ pad_indices.append(i)
+
+ # If we did not find a pad, that means the whole sequence belongs to one sample
+ if len(pad_indices) == 0:
+ pad_indices.append(len(packed_concept_ids))
+
+ timepiece_time_to_event_vectors = []
+ timepiece_event_indicators = []
+ timepiece_indicators = []
+ time_to_event_to_includes = []
+
+ for start_index, end_index in zip([0] + pad_indices[:-1], pad_indices):
+ concept_ids = packed_concept_ids[start_index:end_index]
+ if concept_ids[0] == self.tokenizer.pad_token:
+ concept_ids.pop(0)
+ time_to_event_vectors = []
+ global_event_indicators = []
+
+ # First collect TTE data in reverse chronological order
+ censor_times = []
+ time_to_event_data: List[Dict[str, int]] = []
+ time_to_event_dict: Dict[str, int] = {}
+ time_to_event_to_include: List[bool] = []
+ next_future_visit_concepts = set()
+ time_interval = 0
+
+ # Reverse walk through concept_ids to calculate TTE from each [VE] point
+ for concept_id in reversed(concept_ids):
+ if is_visit_end(concept_id):
+ # Update TTE for existing concepts, or add new ones seen in this visit
+ for existing_concept_id in list(time_to_event_dict.keys()):
+ if existing_concept_id in next_future_visit_concepts:
+ time_to_event_dict[existing_concept_id] = time_interval
+ else:
+ time_to_event_dict[existing_concept_id] += time_interval
+
+ for next_concept_id in next_future_visit_concepts:
+ if next_concept_id not in time_to_event_dict:
+ time_to_event_dict[next_concept_id] = time_interval
+
+ # If the next visit occurs on the same day as the previous one, we don't want to do TTE for the
+ # previous visit
+ time_to_event_to_include.append(time_interval > 0)
+ time_to_event_data.append(copy.deepcopy(time_to_event_dict))
+ # Record the censor time at the end of the visit
+ if censor_times:
+ censor_times.append(censor_times[-1] + time_interval)
+ else:
+ censor_times.append(time_interval)
+ time_interval = 0
+ next_future_visit_concepts.clear()
+
+ elif is_att_token(concept_id):
+ time_interval += extract_time_interval_in_days(concept_id)
+
+ elif self.tokenizer.is_motor_time_to_event_code(concept_id):
+ next_future_visit_concepts.add(concept_id)
+
+ if len(time_to_event_data) == 0:
+ LOG.info(
+ "Vist end event is not detected for this sample, and is skipped for MOTOR tasks."
+ "It's likely this sample contains a long admission. length: %s, concept_ids[-10:] %s",
+ len(concept_ids),
+ concept_ids[-10:],
+ )
+ continue
+
+ # Reverse back to chronological order for final labels
+ time_to_event_data.reverse()
+ censor_times.reverse()
+ time_to_event_to_include.reverse()
+
+ for censor_time, visit_tte_data in zip(censor_times, time_to_event_data):
+ time_to_event_vector = np.full(
+ self.tokenizer.motor_tte_vocab_size,
+ fill_value=censor_time,
+ dtype=np.int32,
+ )
+ event_indicator = np.zeros(
+ self.tokenizer.motor_tte_vocab_size,
+ dtype=np.int32,
+ )
+ visit_token_ids = [
+ self.tokenizer.get_motor_token_id(concept_id)
+ for concept_id in visit_tte_data.keys()
+ ]
+ visit_tte_values = list(visit_tte_data.values())
+
+ time_to_event_vector[visit_token_ids] = visit_tte_values
+ event_indicator[visit_token_ids] = 1  # not censored (event occurred)
+
+ time_to_event_vectors.append(time_to_event_vector)
+ global_event_indicators.append(event_indicator)
+
+ time_to_event_vectors = np.asarray(time_to_event_vectors)
+ global_event_indicators = np.asarray(global_event_indicators).astype(bool)
+ n_visits = len(time_to_event_vectors)
+
+ timepiece_time_to_event_vector = np.full(
+ (
+ self.motor_num_time_pieces,
+ n_visits,
+ self.tokenizer.motor_tte_vocab_size,
+ ),
+ fill_value=0,
+ dtype=np.int32,
+ )
+ timepiece_event_indicator = np.zeros(
+ (
+ self.motor_num_time_pieces,
+ n_visits,
+ self.tokenizer.motor_tte_vocab_size,
+ ),
+ dtype=bool,
+ )
+ timepiece_indicator = np.zeros(
+ (
+ self.motor_num_time_pieces,
+ n_visits,
+ self.tokenizer.motor_tte_vocab_size,
+ ),
+ dtype=bool,
+ )
+
+ # Putting the event time and censor time into the corresponding time bins
+ for bin_num in range(self.motor_num_time_pieces):
+ start = self.motor_time_interval * bin_num
+ end = self.motor_time_interval * (bin_num + 1)
+ time_in_bin = np.clip(time_to_event_vectors - start, 0, end - start)
+ timepiece_time_to_event_vector[bin_num] = time_in_bin
+ event_indicator = (
+ global_event_indicators
+ & (start <= time_to_event_vectors)
+ & (time_to_event_vectors < end)
+ )
+ timepiece_event_indicator[bin_num] = event_indicator
+ timepiece_indicator[bin_num] = time_in_bin > 0 | event_indicator
+
+ timepiece_time_to_event_vectors.append(
+ timepiece_time_to_event_vector.swapaxes(0, 1)
+ )
+ timepiece_event_indicators.append(timepiece_event_indicator.swapaxes(0, 1))
+ timepiece_indicators.append(timepiece_indicator.swapaxes(0, 1))
+ time_to_event_to_includes.append(np.asarray(time_to_event_to_include))
+
+ record["time_to_event_vectors"] = np.concatenate(
+ timepiece_time_to_event_vectors, axis=0
+ )
+ record["event_indicators"] = np.concatenate(timepiece_event_indicators, axis=0)
+ record["time_indicators"] = np.concatenate(timepiece_indicators, axis=0)
+ record["time_to_event_to_include"] = np.concatenate(
+ time_to_event_to_includes, axis=0
+ )
+ return record
+
  def random_sort(self, record: Dict[str, Any]) -> Dict[str, Any]:

  if not self.shuffle_records:
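The binning loop near the end of create_time_to_event_labels spreads each visit's time-to-event (or censor time) across the time pieces: every fully survived bin contributes its full width, and the bin containing the event contributes the remainder and sets the event indicator. A small numeric sketch of that clipping, using the default 456-day bins and invented values:

    import numpy as np

    motor_time_interval, num_pieces = 456, 8
    tte = np.array([[1000]])    # one visit, one concept, event 1,000 days out
    event = np.array([[True]])
    for bin_num in range(num_pieces):
        start, end = bin_num * motor_time_interval, (bin_num + 1) * motor_time_interval
        time_in_bin = np.clip(tte - start, 0, end - start)
        event_in_bin = event & (start <= tte) & (tte < end)
        # bins 0-1: 456 days survived each; bin 2: 88 days, then the event fires;
        # bins 3-7: zero exposure and no event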
@@ -273,53 +644,82 @@
  record["input_ids"] = self._convert_to_tensor(sorted_input_ids)
  return record

- def generate_start_end_index(self, record: Dict[str, Any]) -> Dict[str, Any]:
+ def generate_start_end_index(
+ self,
+ record: Dict[str, Any],
+ sample_packing: bool,
+ max_length_allowed: Optional[int] = None,
+ ) -> Dict[str, Any]:
  """Adding the start and end indices to extract a portion of the patient sequence."""
  # concept_ids will be used to for time to event predictions and identifying the visit starts
+ max_length_allowed = (
+ self.max_length if max_length_allowed is None else max_length_allowed
+ )
  input_ids = record["input_ids"]
  if isinstance(input_ids, torch.Tensor):
  input_ids = input_ids.detach().tolist()
  concept_ids = self.tokenizer.decode(input_ids, skip_special_tokens=False)
  seq_length = len(record["input_ids"])
- new_max_length = self.max_length - 1 # Subtract one for the [END] token
+
+ # Subtract one for the [END] token when sample_packing is not enabled
+ new_max_length = (
+ max_length_allowed if sample_packing else max_length_allowed - 1
+ )
+
+ if self.include_ttv_prediction:
+ record["time_to_visits"] = torch.concat(
+ [self._convert_to_tensor(self._convert_time_to_event(concept_ids))]
+ )
+
+ # If linear token exists, we will use it, otherwise we default to the OOV token
+ linear_token_id = (
+ self.tokenizer.linear_token_id
+ if self.tokenizer.linear_token_id
+ else self.tokenizer.oov_token_id
+ )
+ eos_token = (
+ linear_token_id
+ if self.add_linear_prob_token
+ else self.tokenizer.end_token_id
+ )

  # Return the record directly if the actual sequence length is less than the max sequence
  if seq_length <= new_max_length:
- record["input_ids"] = torch.concat(
- [
- self._convert_to_tensor(record["input_ids"]),
- self._convert_to_tensor([self.tokenizer.end_token_id]),
- ]
- )
- if self.include_values:
- record["value_indicators"] = torch.concat(
- [
- self._convert_to_tensor(record["value_indicators"]),
- self._convert_to_tensor([False]),
- ]
- ).to(torch.bool)
- record["values"] = torch.concat(
- [
- self._convert_to_tensor(record["values"]),
- self._convert_to_tensor([self.tokenizer.pad_value_token_id]),
- ]
- )
- if self.include_ttv_prediction:
- record["time_to_visits"] = torch.concat(
+ if not sample_packing:
+ record["input_ids"] = torch.concat(
  [
- self._convert_to_tensor(
- self._convert_time_to_event(concept_ids)
- ),
- self._convert_to_tensor([-100.0]),
+ self._convert_to_tensor(record["input_ids"]),
+ self._convert_to_tensor([eos_token]),
  ]
  )
-
+ if self.include_values:
+ record["value_indicators"] = torch.concat(
+ [
+ self._convert_to_tensor(record["value_indicators"]),
+ self._convert_to_tensor([False]),
+ ]
+ ).to(torch.bool)
+ record["values"] = torch.concat(
+ [
+ self._convert_to_tensor(record["values"]),
+ self._convert_to_tensor(
+ [self.tokenizer.pad_value_token_id]
+ ),
+ ]
+ )
+ if self.include_ttv_prediction:
+ record["time_to_visits"] = torch.concat(
+ [
+ record["time_to_visits"],
+ self._convert_to_tensor([-100.0]),
+ ]
+ )
  return record

  if self.pretraining:
  # There is a 50% chance we randomly slice out a portion of the patient history and update the demographic
  # prompt depending on the new starting point
- if random.random() < 0.5:
+ if random.random() < 0.5 and not sample_packing:
  start_index, end_index, demographic_tokens = random_slice_gpt_sequence(
  concept_ids, new_max_length
  )
@@ -347,10 +747,19 @@
  for i in reversed(list(range(0, end_index))):
  current_token = record["input_ids"][i]
  if current_token == self.ve_token_id:
- end_index = i
+ # Plus one because slicing is right exclusive
+ end_index = i + 1
  break

  record["input_ids"] = record["input_ids"][0:end_index]
+
+ # We want to make sure we take the subset of attention_mask in sample packing if this field is available
+ if sample_packing and "attention_mask" in record:
+ record["attention_mask"] = record["attention_mask"][0:end_index]
+
+ if sample_packing and "position_ids" in record:
+ record["position_ids"] = record["position_ids"][0:end_index]
+
  if self.include_values:
  record["value_indicators"] = self._convert_to_tensor(
  record["value_indicators"][0:end_index]
@@ -364,7 +773,7 @@
  )
  return record
  else:
- if self.include_demographics:
+ if self.include_demographics and not sample_packing:
  # We employ a left truncation strategy, where the most recent patient history is reserved for fine-tuning
  demographic_prompts_at_visits = collect_demographic_prompts_at_visits(
  concept_ids
@@ -427,6 +836,12 @@
  current_token = record["input_ids"][i]
  if current_token == self.vs_token_id:
  record["input_ids"] = record["input_ids"][i:end_index]
+ if sample_packing and "attention_mask" in record:
+ record["attention_mask"] = record["attention_mask"][
+ i:end_index
+ ]
+ if sample_packing and "position_ids" in record:
+ record["position_ids"] = record["position_ids"][i:end_index]
  if self.include_values:
  record["value_indicators"] = record["value_indicators"][
  i:end_index
@@ -442,6 +857,12 @@
  # We simply take the last new_max_length number of tokens from the patient sequence
  if len(record["input_ids"]) > new_max_length:
  record["input_ids"] = record["input_ids"][-new_max_length:]
+ if sample_packing and "attention_mask" in record:
+ record["attention_mask"] = record["attention_mask"][
+ -new_max_length:
+ ]
+ if sample_packing and "position_ids" in record:
+ record["position_ids"] = record["position_ids"][-new_max_length:]
  if self.include_values:
  record["value_indicators"] = record["value_indicators"][
  -new_max_length:
@@ -452,31 +873,148 @@
  -new_max_length:
  ]

- # Finally we add the end token to the end of the sequence
- record["input_ids"] = torch.concat(
- [
- self._convert_to_tensor(record["input_ids"]),
- self._convert_to_tensor([self.tokenizer.end_token_id]),
- ]
- )
- if self.include_values:
- record["value_indicators"] = torch.concat(
+ if not sample_packing:
+ # Finally we add the end token to the end of the sequence
+ record["input_ids"] = torch.concat(
  [
- self._convert_to_tensor(record["value_indicators"]),
- self._convert_to_tensor([False]),
- ]
- ).to(torch.bool)
- record["values"] = torch.concat(
- [
- self._convert_to_tensor(record["values"]),
- self._convert_to_tensor([self.tokenizer.pad_value_token_id]),
- ]
- )
- if self.include_ttv_prediction:
- record["time_to_visits"] = torch.concat(
- [
- record["time_to_visits"],
- self._convert_to_tensor([-100.0]),
+ self._convert_to_tensor(record["input_ids"]),
+ self._convert_to_tensor([eos_token]),
  ]
  )
+ if self.include_values:
+ record["value_indicators"] = torch.concat(
+ [
+ self._convert_to_tensor(record["value_indicators"]),
+ self._convert_to_tensor([False]),
+ ]
+ ).to(torch.bool)
+ record["values"] = torch.concat(
+ [
+ self._convert_to_tensor(record["values"]),
+ self._convert_to_tensor(
+ [self.tokenizer.pad_value_token_id]
+ ),
+ ]
+ )
+ if self.include_ttv_prediction:
+ record["time_to_visits"] = torch.concat(
+ [
+ record["time_to_visits"],
+ self._convert_to_tensor([-100.0]),
+ ]
+ )
  return record
+
+
+ class SamplePackingCehrGptDataCollator(CehrGptDataCollator):
+ def __init__(self, max_tokens, max_position_embeddings, *args, **kwargs):
+ self.max_tokens_per_batch = max_tokens
+ self.max_position_embeddings = max_position_embeddings
+ self.sample_packing = True
+ self.add_end_token_in_sample_packing = kwargs.pop(
+ "add_end_token_in_sample_packing", False
+ )
+ super(SamplePackingCehrGptDataCollator, self).__init__(*args, **kwargs)
+
+ def __call__(self, examples):
+ current_input_ids = []
+ current_attention_mask = []
+ current_position_ids = []
+ current_value_indicators = []
+ current_values = []
+
+ # Demographics
+ current_person_ids = []
+ current_index_dates = []
+
+ # Binary classification inputs
+ current_ages = []
+ current_labels = []
+
+ for idx, example in enumerate(examples):
+
+ # We only add an end token if the patient sequence could fit in the entire context window
+ add_end_token = (
+ len(example["input_ids"]) <= self.max_position_embeddings
+ and self.add_end_token_in_sample_packing
+ )
+ # If the sample length exceeds the model's capacity, truncate this example
+ if len(example["input_ids"]) > self.max_position_embeddings:
+ example = self.generate_start_end_index(
+ example, False, self.max_position_embeddings
+ )
+
+ add_eos_token = add_end_token | self.add_linear_prob_token
+ additional_tokens = []
+ if add_end_token:
+ additional_tokens.append(self.tokenizer.end_token_id)
+ elif self.add_linear_prob_token:
+ # Backward compatible
+ linear_prob_token_id = (
+ self.tokenizer.linear_token_id
+ if self.tokenizer.linear_token_id is not None
+ else self.tokenizer.oov_token_id
+ )
+ additional_tokens.append(linear_prob_token_id)
+ additional_tokens.append(self.tokenizer.pad_token_id)
+ input_ids = example["input_ids"]
+ # We add [END] [PAD], we want to attend to [END], adding [END] is important for sequence generation.
+ # If the sequence length of the sequence is less than the context window, we add both [END][PAD], otherwise
+ # we only add [PAD] token to the end of the sequence because it's not finished
+ current_input_ids.extend(list(input_ids) + additional_tokens)
+ current_attention_mask.extend(
+ np.ones_like(input_ids).tolist() + ([1, 0] if add_eos_token else [0])
+ )
+ num_tokens_to_pad = 1 + int(add_eos_token)
+ current_position_ids.extend(
+ np.clip(
+ list(range(len(input_ids) + num_tokens_to_pad)),
+ 0,
+ self.max_position_embeddings - 1,
+ )
+ )
+ if self.include_values:
+ current_value_indicators.extend(
+ list(example["value_indicators"]) + [False] * num_tokens_to_pad
+ )
+ current_values.extend(
+ list(example["values"])
+ + [self.tokenizer.pad_value_token_id] * num_tokens_to_pad
+ )
+
+ if "person_id" in example:
+ current_person_ids.append(example["person_id"])
+
+ if "index_date" in example:
+ current_index_dates.append(example["index_date"])
+
+ if "age_at_index" in example:
+ current_ages.append(example["age_at_index"])
+
+ if "classifier_label" in example:
+ current_labels.append(example["classifier_label"])
+
+ assert len(current_input_ids) <= self.max_tokens_per_batch, (
+ f"The total number of tokens in the packed sequence should be less than {self.max_tokens_per_batch}\n"
+ f"But the total number of tokens is: {len(current_input_ids)}"
+ )
+ packed_example = {
+ "input_ids": current_input_ids,
+ "attention_mask": current_attention_mask,
+ "position_ids": current_position_ids,
+ }
+ if self.include_values:
+ packed_example.update({"value_indicators": current_value_indicators})
+ packed_example.update({"values": current_values})
+
+ if current_labels:
+ packed_example.update(
+ {
+ "person_id": current_person_ids,
+ "index_date": current_index_dates,
+ "age_at_index": current_ages,
+ "classifier_label": current_labels,
+ }
+ )
+
+ return super().__call__([packed_example])
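A minimal sketch of how the packing collator might be wired up; the numeric values are placeholders, the tokenizer is assumed to be an already-trained CehrGptTokenizer, and any keyword not shown keeps its default from CehrGptDataCollator:

    # tokenizer: a trained CehrGptTokenizer instance (loading omitted here)
    collator = SamplePackingCehrGptDataCollator(
        max_tokens=16384,              # token budget for one packed batch (illustrative)
        max_position_embeddings=1024,  # per-example truncation limit (illustrative)
        tokenizer=tokenizer,
        max_length=16384,
        pretraining=True,
    )
    # packed_batch = collator(examples)  # examples as grouped by the batch sampler,
    #                                    # e.g. cehrgpt/data/sample_packing_sampler.py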