cehrgpt 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. cehrgpt/analysis/irregularity.py +36 -0
  2. cehrgpt/data/hf_cehrgpt_dataset.py +1 -0
  3. cehrgpt/data/hf_cehrgpt_dataset_collator.py +454 -68
  4. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +232 -17
  5. cehrgpt/data/sample_packing_sampler.py +36 -6
  6. cehrgpt/generation/cehrgpt_conditional_generation.py +314 -0
  7. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +15 -3
  8. cehrgpt/generation/omop_converter_batch.py +32 -2
  9. cehrgpt/gpt_utils.py +20 -2
  10. cehrgpt/models/config.py +25 -0
  11. cehrgpt/models/hf_cehrgpt.py +244 -39
  12. cehrgpt/models/hf_modeling_outputs.py +1 -0
  13. cehrgpt/models/special_tokens.py +1 -0
  14. cehrgpt/models/tokenization_hf_cehrgpt.py +354 -71
  15. cehrgpt/runners/data_utils.py +131 -5
  16. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +84 -51
  17. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +59 -7
  18. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +60 -0
  19. cehrgpt/runners/hyperparameter_search_util.py +6 -7
  20. cehrgpt/runners/sample_packing_trainer.py +17 -0
  21. cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  22. cehrgpt/time_to_event/time_to_event_model.py +2 -13
  23. cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  24. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +80 -62
  25. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/METADATA +102 -7
  26. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/RECORD +29 -26
  27. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/WHEEL +1 -1
  28. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/licenses/LICENSE +0 -0
  29. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/top_level.txt +0 -0
@@ -1,21 +1,28 @@
+ import copy
  import random
- from typing import Any, Dict, Optional
+ from typing import Any, Dict, List, Optional

  import numpy as np
  import torch
  from torch.nn.utils.rnn import pad_sequence
+ from transformers.utils import logging

  from cehrgpt.gpt_utils import (
      DEMOGRAPHIC_PROMPT_SIZE,
      collect_demographic_prompts_at_visits,
      extract_time_interval_in_days,
+     extract_time_interval_in_hours,
      is_att_token,
      is_inpatient_att_token,
+     is_inpatient_hour_token,
+     is_visit_end,
      random_slice_gpt_sequence,
  )
  from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer

+ TIME_TO_EVENT_MAX_TIME = 3650
  INPATIENT_STAY_DURATION_LIMIT = 30
+ LOG = logging.get_logger("transformers")


  class CehrGptDataCollator:
@@ -27,20 +34,18 @@ class CehrGptDataCollator:
          include_values: bool = False,
          include_ttv_prediction: bool = False,
          use_sub_time_tokenization: bool = False,
+         include_motor_time_to_event: bool = False,
+         motor_tte_vocab_size: int = 0,
+         motor_num_time_pieces: int = 8,
          pretraining: bool = True,
          include_demographics: bool = False,
+         add_linear_prob_token: bool = False,
      ):
          self.tokenizer = tokenizer
          self.max_length = max_length
-         # Pre-compute these so we can use them later on
-         # We used VS for the historical data, currently, we use the new [VS] for the newer data
-         # so we need to check both cases.
-         self.vs_token_id = tokenizer._convert_token_to_id("VS")
-         if self.vs_token_id == tokenizer._oov_token_id:
-             self.vs_token_id = tokenizer._convert_token_to_id("[VS]")
-         self.ve_token_id = tokenizer._convert_token_to_id("VE")
-         if self.ve_token_id == tokenizer._oov_token_id:
-             self.ve_token_id = tokenizer._convert_token_to_id("[VE]")
+
+         self.vs_token_id = tokenizer.vs_token_id
+         self.ve_token_id = tokenizer.ve_token_id

          self.shuffle_records = shuffle_records
          self.include_values = include_values
@@ -48,6 +53,20 @@ class CehrGptDataCollator:
          self.use_sub_time_tokenization = use_sub_time_tokenization
          self.pretraining = pretraining
          self.include_demographics = include_demographics
+         self.add_linear_prob_token = add_linear_prob_token
+
+         # MOTOR TTE configuration
+         if include_motor_time_to_event:
+             assert motor_tte_vocab_size > 0, (
+                 f"motor_tte_vocab_size must be greater than 0 "
+                 f"when include_motor_time_to_event is set to True. "
+                 f"But motor_tte_vocab_size: {motor_tte_vocab_size} is provided"
+             )
+
+         self.include_motor_time_to_event = include_motor_time_to_event
+         self.motor_tte_vocab_size = motor_tte_vocab_size
+         self.motor_num_time_pieces = motor_num_time_pieces
+         self.motor_time_interval = TIME_TO_EVENT_MAX_TIME // motor_num_time_pieces

          if self.use_sub_time_tokenization:
              token_to_time_token_mapping = tokenizer.token_to_time_token_mapping
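Note (illustrative, not part of the diff): a minimal sketch of wiring the new MOTOR time-to-event options into the collator, assuming a trained CehrGptTokenizer that exposes a from_pretrained loader and the motor_tte_vocab_size attribute referenced later in this diff; the path and max_length are hypothetical.

from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer

# Sketch only; "path/to/tokenizer" is a placeholder.
tokenizer = CehrGptTokenizer.from_pretrained("path/to/tokenizer")
collator = CehrGptDataCollator(
    tokenizer=tokenizer,
    max_length=1024,
    include_motor_time_to_event=True,
    motor_tte_vocab_size=tokenizer.motor_tte_vocab_size,  # must be > 0 when MOTOR TTE is enabled
    motor_num_time_pieces=8,  # each piece spans TIME_TO_EVENT_MAX_TIME // 8 = 456 days
)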
@@ -88,6 +107,8 @@
                      ):
                          return -100
                      return time_to_visit
+                 elif is_inpatient_hour_token(c):
+                     return extract_time_interval_in_hours(c) / 24
                  return -100
              except ValueError:
                  return -100
@@ -95,8 +116,8 @@
          return [float(default_value(_)) for _ in concept_ids]

      def __call__(self, examples):
-
-         examples = [self.generate_start_end_index(_) for _ in examples]
+         sample_packing = getattr(self, "sample_packing", False)
+         examples = [self.generate_start_end_index(_, sample_packing) for _ in examples]
          examples = [self.random_sort(_) for _ in examples]
          batch = {}

@@ -141,6 +162,22 @@
              f"batch['input_ids']: {batch['input_ids']} "
          )

+         if "epoch_times" in examples[0]:
+             batch_epoch_times = [
+                 self._try_reverse_tensor(
+                     self._convert_to_tensor(example["epoch_times"])
+                 )
+                 for example in examples
+             ]
+             # Pad sequences to the max length in the batch
+             batch["epoch_times"] = self._try_reverse_tensor(
+                 pad_sequence(
+                     batch_epoch_times,
+                     batch_first=True,
+                     padding_value=0,
+                 ).to(torch.float32)
+             )
+
          if "position_ids" in examples[0]:
              batch_position_ids = [
                  self._try_reverse_tensor(
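Note (illustrative, not part of the diff): the epoch_times batching above relies on torch's pad_sequence; a tiny self-contained sketch of its behavior with made-up values.

import torch
from torch.nn.utils.rnn import pad_sequence

# Two examples of different lengths, as the collator would receive them.
batch_epoch_times = [torch.tensor([0.0, 1.5, 3.0]), torch.tensor([0.0, 2.0])]
padded = pad_sequence(batch_epoch_times, batch_first=True, padding_value=0).to(torch.float32)
print(padded)  # tensor([[0.0000, 1.5000, 3.0000], [0.0000, 2.0000, 0.0000]])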
@@ -153,7 +190,7 @@
              pad_sequence(
                  batch_position_ids,
                  batch_first=True,
-                 padding_value=self.max_length,
+                 padding_value=0,
              ).to(torch.int64)
          )

@@ -194,6 +231,126 @@
              )
          )

+         if self.include_motor_time_to_event:
+             examples_with_motor_tte = [
+                 self.create_time_to_event_labels(_) for _ in examples
+             ]
+             batch_motor_time_to_event_vectors = [
+                 self._try_reverse_tensor(
+                     self._convert_to_tensor(example["time_to_event_vectors"])
+                 )
+                 for example in examples_with_motor_tte
+             ]
+             batch_motor_event_indicators = [
+                 self._try_reverse_tensor(
+                     self._convert_to_tensor(example["event_indicators"])
+                 )
+                 for example in examples_with_motor_tte
+             ]
+             batch_motor_time_to_event_to_include = [
+                 self._try_reverse_tensor(
+                     self._convert_to_tensor(example["time_to_event_to_include"])
+                 )
+                 for example in examples_with_motor_tte
+             ]
+             batch_motor_time_indicators = [
+                 self._try_reverse_tensor(
+                     self._convert_to_tensor(example["time_indicators"])
+                 )
+                 for example in examples_with_motor_tte
+             ]
+
+             batch_motor_time_to_event_vectors = torch.concat(
+                 batch_motor_time_to_event_vectors, dim=0
+             ).to(torch.float32)
+
+             # If every example in the batch only contains one visit, there would be no labels generated for MOTOR TTE
+             # we only create the labels when any example has more than one visit
+             if batch_motor_time_to_event_vectors.dim() <= 1:
+                 LOG.warning(
+                     "There are no MOTOR TTE labels generated for this batch "
+                     "because every example in this batch only contains one visit."
+                 )
+             else:
+                 batch_size = len(examples)
+                 length, num_time_pieces, motor_tte_vocab_size = (
+                     batch_motor_time_to_event_vectors.shape
+                 )
+                 padded_length = batch_size - length % batch_size
+                 batch["motor_time_to_event_vectors"] = (
+                     torch.concat(
+                         [
+                             batch_motor_time_to_event_vectors,
+                             torch.full(
+                                 (padded_length, num_time_pieces, motor_tte_vocab_size),
+                                 0.0,
+                             ),
+                         ],
+                         dim=0,
+                     )
+                     .reshape((batch_size, -1, num_time_pieces, motor_tte_vocab_size))
+                     .to(torch.float32)
+                 )
+
+                 # Motor event indicators that indicate there is an event occurred in this time interval
+                 batch_motor_event_indicators = torch.concat(
+                     batch_motor_event_indicators, dim=0
+                 ).to(torch.bool)
+                 batch["motor_event_indicators"] = (
+                     torch.concat(
+                         [
+                             batch_motor_event_indicators,
+                             torch.full(
+                                 (padded_length, num_time_pieces, motor_tte_vocab_size),
+                                 False,
+                             ),
+                         ],
+                         dim=0,
+                     )
+                     .reshape((batch_size, -1, num_time_pieces, motor_tte_vocab_size))
+                     .to(torch.bool)
+                 )
+
+                 # Input to indicate whether the visit should be included for TTE predictions
+                 batch_motor_time_to_event_to_include = torch.concat(
+                     batch_motor_time_to_event_to_include, dim=0
+                 ).to(torch.bool)
+                 batch["motor_time_to_event_to_include"] = (
+                     torch.concat(
+                         [
+                             batch_motor_time_to_event_to_include,
+                             torch.full((padded_length,), False),
+                         ],
+                         dim=0,
+                     ).to(torch.bool)
+                 ).reshape((batch_size, -1))
+
+                 # Motor time indicators that indicate whether there are neither clinical events nor censor events
+                 batch_motor_time_indicators = torch.concat(
+                     batch_motor_time_indicators, dim=0
+                 ).to(torch.bool)
+                 batch["motor_time_indicators"] = (
+                     torch.concat(
+                         [
+                             batch_motor_time_indicators,
+                             torch.full(
+                                 (padded_length, num_time_pieces, motor_tte_vocab_size),
+                                 False,
+                             ),
+                         ],
+                         dim=0,
+                     )
+                     .reshape((batch_size, -1, num_time_pieces, motor_tte_vocab_size))
+                     .to(torch.bool)
+                 )
+
+                 batch["motor_end_index"] = torch.concat(
+                     [
+                         torch.full((length, 1), 1, dtype=torch.int32),
+                         torch.full((padded_length, 1), 0, dtype=torch.int32),
+                     ]
+                 ).reshape((batch_size, -1))
+
          if self.include_values:
              batch_value_indicators = [
                  self._try_reverse_tensor(
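Note (illustrative, not part of the diff): the block above flattens the per-visit MOTOR tensors from all examples, pads them up to a multiple of the batch size, and reshapes them back to a per-example layout; a small numeric sketch of that arithmetic with made-up shapes.

import torch

batch_size, num_time_pieces, motor_tte_vocab_size = 4, 8, 10
length = 13  # total visits with MOTOR labels collected across the whole batch
flat = torch.zeros(length, num_time_pieces, motor_tte_vocab_size)

padded_length = batch_size - length % batch_size  # 4 - 13 % 4 = 3 padded rows
padded = torch.concat(
    [flat, torch.full((padded_length, num_time_pieces, motor_tte_vocab_size), 0.0)],
    dim=0,
)
reshaped = padded.reshape((batch_size, -1, num_time_pieces, motor_tte_vocab_size))
print(reshaped.shape)  # torch.Size([4, 4, 8, 10]); motor_end_index flags the padded rows with 0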
@@ -281,6 +438,193 @@

          return batch

+     def create_time_to_event_labels(self, record: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Generates time-to-event (TTE) labels and censoring indicators for each visit in a patient's timeline.
+
+         Processes the input sequence in reverse to compute the number of days from each visit (marked by [VE])
+         to the occurrence of future motor-related events.
+
+         Args:
+             record (Dict[str, Any]): A dictionary containing the encoded patient sequence with the key "input_ids".
+                 This sequence includes [VS], [VE], time delta tokens, and motor TTE concept codes.
+
+         Returns:
+             Dict[str, Any]: The updated input record with added keys:
+                 - "time_to_event_vectors": np.ndarray of shape [num_visits, motor_vocab_size], containing time-to-event values
+                 - "event_indicators": np.ndarray of shape [num_visits, motor_vocab_size], where 0 = event occurred, 1 = censored
+         """
+         input_ids = record["input_ids"]
+         sample_packing = getattr(self, "sample_packing", False)
+
+         if isinstance(input_ids, torch.Tensor):
+             input_ids = input_ids.detach().tolist()
+
+         # This potentially contains packed samples, we need to handle that
+         packed_concept_ids = self.tokenizer.decode(input_ids, skip_special_tokens=False)
+         pad_indices = []
+         if sample_packing:
+             # We start from the first index
+             for i in range(len(packed_concept_ids)):
+                 if packed_concept_ids[i] == self.tokenizer.pad_token:
+                     # If we encounter consecutive pads, we should break out of the loop
+                     if pad_indices and pad_indices[-1] == self.tokenizer.pad_token:
+                         break
+                     pad_indices.append(i)
+
+         # If we did not find a pad, that means the whole sequence belongs to one sample
+         if len(pad_indices) == 0:
+             pad_indices.append(len(packed_concept_ids))
+
+         timepiece_time_to_event_vectors = []
+         timepiece_event_indicators = []
+         timepiece_indicators = []
+         time_to_event_to_includes = []
+
+         for start_index, end_index in zip([0] + pad_indices[:-1], pad_indices):
+             concept_ids = packed_concept_ids[start_index:end_index]
+             if concept_ids[0] == self.tokenizer.pad_token:
+                 concept_ids.pop(0)
+             time_to_event_vectors = []
+             global_event_indicators = []
+
+             # First collect TTE data in reverse chronological order
+             censor_times = []
+             time_to_event_data: List[Dict[str, int]] = []
+             time_to_event_dict: Dict[str, int] = {}
+             time_to_event_to_include: List[bool] = []
+             next_future_visit_concepts = set()
+             time_interval = 0
+
+             # Reverse walk through concept_ids to calculate TTE from each [VE] point
+             for concept_id in reversed(concept_ids):
+                 if is_visit_end(concept_id):
+                     # Update TTE for existing concepts, or add new ones seen in this visit
+                     for existing_concept_id in list(time_to_event_dict.keys()):
+                         if existing_concept_id in next_future_visit_concepts:
+                             time_to_event_dict[existing_concept_id] = time_interval
+                         else:
+                             time_to_event_dict[existing_concept_id] += time_interval
+
+                     for next_concept_id in next_future_visit_concepts:
+                         if next_concept_id not in time_to_event_dict:
+                             time_to_event_dict[next_concept_id] = time_interval
+
+                     # If the next visit occurs on the same day as the previous one, we don't want to do TTE for the
+                     # previous visit
+                     time_to_event_to_include.append(time_interval > 0)
+                     time_to_event_data.append(copy.deepcopy(time_to_event_dict))
+                     # Record the censor time at the end of the visit
+                     if censor_times:
+                         censor_times.append(censor_times[-1] + time_interval)
+                     else:
+                         censor_times.append(time_interval)
+                     time_interval = 0
+                     next_future_visit_concepts.clear()
+
+                 elif is_att_token(concept_id):
+                     time_interval += extract_time_interval_in_days(concept_id)
+
+                 elif self.tokenizer.is_motor_time_to_event_code(concept_id):
+                     next_future_visit_concepts.add(concept_id)
+
+             if len(time_to_event_data) == 0:
+                 LOG.info(
+                     "Vist end event is not detected for this sample, and is skipped for MOTOR tasks."
+                     "It's likely this sample contains a long admission. length: %s, concept_ids[-10:] %s",
+                     len(concept_ids),
+                     concept_ids[-10:],
+                 )
+                 continue
+
+             # Reverse back to chronological order for final labels
+             time_to_event_data.reverse()
+             censor_times.reverse()
+             time_to_event_to_include.reverse()
+
+             for censor_time, visit_tte_data in zip(censor_times, time_to_event_data):
+                 time_to_event_vector = np.full(
+                     self.tokenizer.motor_tte_vocab_size,
+                     fill_value=censor_time,
+                     dtype=np.int32,
+                 )
+                 event_indicator = np.zeros(
+                     self.tokenizer.motor_tte_vocab_size,
+                     dtype=np.int32,
+                 )
+                 visit_token_ids = [
+                     self.tokenizer.get_motor_token_id(concept_id)
+                     for concept_id in visit_tte_data.keys()
+                 ]
+                 visit_tte_values = list(visit_tte_data.values())
+
+                 time_to_event_vector[visit_token_ids] = visit_tte_values
+                 event_indicator[visit_token_ids] = 1  # not censored (event occurred)
+
+                 time_to_event_vectors.append(time_to_event_vector)
+                 global_event_indicators.append(event_indicator)
+
+             time_to_event_vectors = np.asarray(time_to_event_vectors)
+             global_event_indicators = np.asarray(global_event_indicators).astype(bool)
+             n_visits = len(time_to_event_vectors)
+
+             timepiece_time_to_event_vector = np.full(
+                 (
+                     self.motor_num_time_pieces,
+                     n_visits,
+                     self.tokenizer.motor_tte_vocab_size,
+                 ),
+                 fill_value=0,
+                 dtype=np.int32,
+             )
+             timepiece_event_indicator = np.zeros(
+                 (
+                     self.motor_num_time_pieces,
+                     n_visits,
+                     self.tokenizer.motor_tte_vocab_size,
+                 ),
+                 dtype=bool,
+             )
+             timepiece_indicator = np.zeros(
+                 (
+                     self.motor_num_time_pieces,
+                     n_visits,
+                     self.tokenizer.motor_tte_vocab_size,
+                 ),
+                 dtype=bool,
+             )
+
+             # Putting the event time and censor time into the corresponding time bins
+             for bin_num in range(self.motor_num_time_pieces):
+                 start = self.motor_time_interval * bin_num
+                 end = self.motor_time_interval * (bin_num + 1)
+                 time_in_bin = np.clip(time_to_event_vectors - start, 0, end - start)
+                 timepiece_time_to_event_vector[bin_num] = time_in_bin
+                 event_indicator = (
+                     global_event_indicators
+                     & (start <= time_to_event_vectors)
+                     & (time_to_event_vectors < end)
+                 )
+                 timepiece_event_indicator[bin_num] = event_indicator
+                 timepiece_indicator[bin_num] = time_in_bin > 0 | event_indicator
+
+             timepiece_time_to_event_vectors.append(
+                 timepiece_time_to_event_vector.swapaxes(0, 1)
+             )
+             timepiece_event_indicators.append(timepiece_event_indicator.swapaxes(0, 1))
+             timepiece_indicators.append(timepiece_indicator.swapaxes(0, 1))
+             time_to_event_to_includes.append(np.asarray(time_to_event_to_include))
+
+         record["time_to_event_vectors"] = np.concatenate(
+             timepiece_time_to_event_vectors, axis=0
+         )
+         record["event_indicators"] = np.concatenate(timepiece_event_indicators, axis=0)
+         record["time_indicators"] = np.concatenate(timepiece_indicators, axis=0)
+         record["time_to_event_to_include"] = np.concatenate(
+             time_to_event_to_includes, axis=0
+         )
+         return record
+
      def random_sort(self, record: Dict[str, Any]) -> Dict[str, Any]:

          if not self.shuffle_records:
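Note (illustrative, not part of the diff): create_time_to_event_labels buckets each time-to-event value into fixed-width time pieces of TIME_TO_EVENT_MAX_TIME // motor_num_time_pieces days (3650 // 8 = 456 by default); a numeric sketch of the np.clip-based binning with made-up values. In the collator itself the in-bin flag is additionally AND-ed with the per-concept event indicator.

import numpy as np

motor_time_interval = 3650 // 8  # 456 days per time piece
time_to_event = np.array([100, 500, 4000])  # days to event or censoring for three concepts

for bin_num in range(3):  # first three of the eight bins, for brevity
    start, end = motor_time_interval * bin_num, motor_time_interval * (bin_num + 1)
    time_in_bin = np.clip(time_to_event - start, 0, end - start)
    falls_in_bin = (start <= time_to_event) & (time_to_event < end)
    print(bin_num, time_in_bin, falls_in_bin)
# 0 [100 456 456] [ True False False]
# 1 [  0  44 456] [False  True False]
# 2 [  0   0 456] [False False False]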
@@ -317,14 +661,16 @@
          return record

      def generate_start_end_index(
-         self, record: Dict[str, Any], max_length_allowed: Optional[int] = None
+         self,
+         record: Dict[str, Any],
+         sample_packing: bool,
+         max_length_allowed: Optional[int] = None,
      ) -> Dict[str, Any]:
          """Adding the start and end indices to extract a portion of the patient sequence."""
          # concept_ids will be used to for time to event predictions and identifying the visit starts
          max_length_allowed = (
              self.max_length if max_length_allowed is None else max_length_allowed
          )
-         sample_packing = getattr(self, "sample_packing", False)
          input_ids = record["input_ids"]
          if isinstance(input_ids, torch.Tensor):
              input_ids = input_ids.detach().tolist()
@@ -333,7 +679,9 @@
 
          # Subtract one for the [END] token when sample_packing is not enabled
          new_max_length = (
-             max_length_allowed if sample_packing else max_length_allowed - 1
+             max_length_allowed - 1
+             if not sample_packing and self.pretraining
+             else max_length_allowed
          )

          if self.include_ttv_prediction:
@@ -341,15 +689,34 @@
              [self._convert_to_tensor(self._convert_time_to_event(concept_ids))]
          )

+         # If linear token exists, we will use it, otherwise we default to the OOV token
+         linear_token_id = (
+             self.tokenizer.linear_token_id
+             if self.tokenizer.linear_token_id
+             else self.tokenizer.oov_token_id
+         )
+         eos_token = (
+             linear_token_id
+             if self.add_linear_prob_token
+             else self.tokenizer.end_token_id
+         )
+
          # Return the record directly if the actual sequence length is less than the max sequence
          if seq_length <= new_max_length:
-             if not sample_packing:
+             if not sample_packing and self.pretraining:
                  record["input_ids"] = torch.concat(
                      [
                          self._convert_to_tensor(record["input_ids"]),
-                         self._convert_to_tensor([self.tokenizer.end_token_id]),
+                         self._convert_to_tensor([eos_token]),
                      ]
                  )
+                 if "epoch_times" in record:
+                     record["epoch_times"] = torch.concat(
+                         [
+                             self._convert_to_tensor(record["epoch_times"]),
+                             self._convert_to_tensor([record["epoch_times"][-1]]),
+                         ]
+                     )
              if self.include_values:
                  record["value_indicators"] = torch.concat(
                      [
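Note (illustrative, not part of the diff): the end-of-sequence token selection above prefers the tokenizer's linear-probe token when add_linear_prob_token is set, falling back to the OOV token for older tokenizers that lack one, and otherwise uses [END]; a sketch with hypothetical token ids.

# Hypothetical token ids; mirrors the fallback logic in the hunk above.
def pick_eos_token(linear_token_id, oov_token_id, end_token_id, add_linear_prob_token):
    linear_token_id = linear_token_id if linear_token_id else oov_token_id
    return linear_token_id if add_linear_prob_token else end_token_id

print(pick_eos_token(None, 1, 2, add_linear_prob_token=True))   # 1 -> falls back to the OOV id
print(pick_eos_token(50, 1, 2, add_linear_prob_token=True))     # 50 -> linear-probe token
print(pick_eos_token(50, 1, 2, add_linear_prob_token=False))    # 2 -> [END]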
@@ -372,7 +739,6 @@
                          self._convert_to_tensor([-100.0]),
                      ]
                  )
-
              return record

          if self.pretraining:
@@ -386,6 +752,10 @@
                  record["input_ids"] = self._convert_to_tensor(
                      record["input_ids"][start_index : end_index + 1]
                  )
+                 if "epoch_times" in record:
+                     record["epoch_times"] = self._convert_to_tensor(
+                         record["epoch_times"][start_index : end_index + 1]
+                     )
                  if self.include_values:
                      record["value_indicators"] = self._convert_to_tensor(
                          record["value_indicators"][start_index : end_index + 1]
@@ -406,7 +776,8 @@
              for i in reversed(list(range(0, end_index))):
                  current_token = record["input_ids"][i]
                  if current_token == self.ve_token_id:
-                     end_index = i
+                     # Plus one because slicing is right exclusive
+                     end_index = i + 1
                      break

              record["input_ids"] = record["input_ids"][0:end_index]
@@ -415,6 +786,14 @@
              if sample_packing and "attention_mask" in record:
                  record["attention_mask"] = record["attention_mask"][0:end_index]

+             if sample_packing and "position_ids" in record:
+                 record["position_ids"] = record["position_ids"][0:end_index]
+
+             if "epoch_times" in record:
+                 record["epoch_times"] = self._convert_to_tensor(
+                     record["epoch_times"][0:end_index]
+                 )
+
              if self.include_values:
                  record["value_indicators"] = self._convert_to_tensor(
                      record["value_indicators"][0:end_index]
@@ -447,6 +826,17 @@
                              ),
                          ]
                      )
+                     if "epoch_times" in record:
+                         record["epoch_times"] = torch.concat(
+                             [
+                                 torch.zeros(
+                                     [record["epoch_times"][0]], dtype=torch.float32
+                                 ),
+                                 self._convert_to_tensor(
+                                     record["epoch_times"][token_index:seq_length]
+                                 ),
+                             ]
+                         )
                      if self.include_values:
                          record["value_indicators"] = torch.concat(
                              [
@@ -485,7 +875,7 @@
                      )
                      break
              else:
-                 start_index = seq_length - new_max_length
+                 start_index = max(seq_length - new_max_length, 0)
                  end_index = seq_length
                  for i in range(start_index, end_index):
                      current_token = record["input_ids"][i]
@@ -495,6 +885,13 @@
                          record["attention_mask"] = record["attention_mask"][
                              i:end_index
                          ]
+                         if sample_packing and "position_ids" in record:
+                             record["position_ids"] = record["position_ids"][i:end_index]
+
+                         if "epoch_times" in record:
+                             record["epoch_times"] = self._convert_to_tensor(
+                                 record["epoch_times"][i:end_index]
+                             )
                          if self.include_values:
                              record["value_indicators"] = record["value_indicators"][
                                  i:end_index
@@ -514,6 +911,12 @@
                      record["attention_mask"] = record["attention_mask"][
                          -new_max_length:
                      ]
+                     if sample_packing and "position_ids" in record:
+                         record["position_ids"] = record["position_ids"][-new_max_length:]
+                     if "epoch_times" in record:
+                         record["epoch_times"] = self._convert_to_tensor(
+                             record["epoch_times"][-new_max_length:]
+                         )
                      if self.include_values:
                          record["value_indicators"] = record["value_indicators"][
                              -new_max_length:
@@ -524,36 +927,6 @@
                              -new_max_length:
                          ]

-             if not sample_packing:
-                 # Finally we add the end token to the end of the sequence
-                 record["input_ids"] = torch.concat(
-                     [
-                         self._convert_to_tensor(record["input_ids"]),
-                         self._convert_to_tensor([self.tokenizer.end_token_id]),
-                     ]
-                 )
-                 if self.include_values:
-                     record["value_indicators"] = torch.concat(
-                         [
-                             self._convert_to_tensor(record["value_indicators"]),
-                             self._convert_to_tensor([False]),
-                         ]
-                     ).to(torch.bool)
-                     record["values"] = torch.concat(
-                         [
-                             self._convert_to_tensor(record["values"]),
-                             self._convert_to_tensor(
-                                 [self.tokenizer.pad_value_token_id]
-                             ),
-                         ]
-                     )
-                 if self.include_ttv_prediction:
-                     record["time_to_visits"] = torch.concat(
-                         [
-                             record["time_to_visits"],
-                             self._convert_to_tensor([-100.0]),
-                         ]
-                     )
          return record


@@ -584,34 +957,46 @@ class SamplePackingCehrGptDataCollator(CehrGptDataCollator):

          for idx, example in enumerate(examples):

-             # If the sample length exceeds the model's capacity, truncate this example
+             # We only add an end token if the patient sequence could fit in the entire context window
              add_end_token = (
                  len(example["input_ids"]) <= self.max_position_embeddings
                  and self.add_end_token_in_sample_packing
              )
-
+             # If the sample length exceeds the model's capacity, truncate this example
              if len(example["input_ids"]) > self.max_position_embeddings:
                  example = self.generate_start_end_index(
-                     example, self.max_position_embeddings
+                     example, False, self.max_position_embeddings
                  )

+             add_eos_token = add_end_token | self.add_linear_prob_token
+             additional_tokens = []
+             if add_end_token:
+                 additional_tokens.append(self.tokenizer.end_token_id)
+             elif self.add_linear_prob_token:
+                 # Backward compatible
+                 linear_prob_token_id = (
+                     self.tokenizer.linear_token_id
+                     if self.tokenizer.linear_token_id is not None
+                     else self.tokenizer.oov_token_id
+                 )
+                 additional_tokens.append(linear_prob_token_id)
+             additional_tokens.append(self.tokenizer.pad_token_id)
              input_ids = example["input_ids"]
              # We add [END] [PAD], we want to attend to [END], adding [END] is important for sequence generation.
              # If the sequence length of the sequence is less than the context window, we add both [END][PAD], otherwise
              # we only add [PAD] token to the end of the sequence because it's not finished
-             current_input_ids.extend(
-                 list(input_ids)
-                 + (
-                     [self.tokenizer.end_token_id, self.tokenizer.pad_token_id]
-                     if add_end_token
-                     else [self.tokenizer.pad_token_id]
-                 )
-             )
+             current_input_ids.extend(list(input_ids) + additional_tokens)
              current_attention_mask.extend(
-                 np.ones_like(input_ids).tolist() + ([1, 0] if add_end_token else [0])
+                 np.ones_like(input_ids).tolist() + ([1, 0] if add_eos_token else [0])
+             )
+             num_tokens_to_pad = 1 + int(add_eos_token)
+             current_position_ids.extend(
+                 np.clip(
+                     list(range(len(input_ids) + num_tokens_to_pad)),
+                     0,
+                     self.max_position_embeddings - 1,
+                 )
              )
-             num_tokens_to_pad = 1 + int(add_end_token)
-             current_position_ids.extend(list(range(len(input_ids) + num_tokens_to_pad)))
              if self.include_values:
                  current_value_indicators.extend(
                      list(example["value_indicators"]) + [False] * num_tokens_to_pad
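Note (illustrative, not part of the diff): putting the packing bookkeeping above together, each example contributes its tokens, an optional [END] (or linear-probe) token, and a [PAD] separator, with position ids now clipped to the model's context size; a sketch with hypothetical token ids.

import numpy as np

max_position_embeddings = 8  # illustrative context size
END_ID, PAD_ID = 2, 0        # hypothetical special-token ids
input_ids = [11, 12, 13, 14, 15]

add_eos_token = True  # the sequence fits in the context window
additional_tokens = ([END_ID] if add_eos_token else []) + [PAD_ID]
packed_ids = list(input_ids) + additional_tokens
attention = np.ones_like(input_ids).tolist() + ([1, 0] if add_eos_token else [0])
num_tokens_to_pad = 1 + int(add_eos_token)
position_ids = np.clip(
    list(range(len(input_ids) + num_tokens_to_pad)), 0, max_position_embeddings - 1
).tolist()
print(packed_ids)    # [11, 12, 13, 14, 15, 2, 0]
print(attention)     # [1, 1, 1, 1, 1, 1, 0] -- attend to [END] but not to [PAD]
print(position_ids)  # [0, 1, 2, 3, 4, 5, 6]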
@@ -633,9 +1018,10 @@
              if "classifier_label" in example:
                  current_labels.append(example["classifier_label"])

-         assert (
-             len(current_input_ids) <= self.max_tokens_per_batch
-         ), f"the total number of tokens in the packed sequence should be less than { self.max_tokens_per_batch}"
+         assert len(current_input_ids) <= self.max_tokens_per_batch, (
+             f"The total number of tokens in the packed sequence should be less than {self.max_tokens_per_batch}\n"
+             f"But the total number of tokens is: {len(current_input_ids)}"
+         )
          packed_example = {
              "input_ids": current_input_ids,
              "attention_mask": current_attention_mask,