cehrgpt 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,36 @@
+import os
+
+import polars as pl
+
+from cehrgpt.gpt_utils import extract_time_interval_in_days, is_att_token
+
+
+def main(args):
+    dataset = pl.read_parquet(os.path.join(args.input_dir, "*.parquet"))
+    time_token_frequency_df = (
+        dataset.select(pl.col("concept_ids").explode().alias("concept_id"))
+        .filter(pl.col("concept_id").map_elements(is_att_token))
+        .with_columns(
+            pl.col("concept_id")
+            .map_elements(extract_time_interval_in_days)
+            .alias("time_interval")
+        )
+    )
+    results = time_token_frequency_df.select(
+        pl.mean("time_interval").alias("mean"), pl.std("time_interval").alias("std")
+    ).to_dicts()[0]
+    print(results)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="EHR Irregularity analysis")
+    parser.add_argument(
+        "--input_dir",
+        dest="input_dir",
+        action="store",
+        help="The path to the input data folder",
+        required=True,
+    )
+    main(parser.parse_args())
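
Note: as a rough, hypothetical sketch of what the new script computes, the snippet below runs the same polars aggregation on toy data. The `is_att_token` and `extract_time_interval_in_days` stand-ins are simplified assumptions, not the real cehrgpt.gpt_utils helpers.

    import polars as pl

    def is_att_token(token: str) -> bool:
        # Hypothetical simplification: ATT (artificial time) tokens look like "D<number-of-days>".
        return token.startswith("D") and token[1:].isdigit()

    def extract_time_interval_in_days(token: str) -> int:
        return int(token[1:])

    # Two toy patient sequences mixing concept tokens and day-interval tokens.
    dataset = pl.DataFrame({"concept_ids": [["C1", "D7", "C2", "D30"], ["C3", "D365"]]})
    stats = (
        dataset.select(pl.col("concept_ids").explode().alias("concept_id"))
        .filter(pl.col("concept_id").map_elements(is_att_token, return_dtype=pl.Boolean))
        .with_columns(
            pl.col("concept_id")
            .map_elements(extract_time_interval_in_days, return_dtype=pl.Int64)
            .alias("time_interval")
        )
        .select(
            pl.mean("time_interval").alias("mean"), pl.std("time_interval").alias("std")
        )
        .to_dicts()[0]
    )
    print(stats)  # {'mean': 134.0, 'std': ...}
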
@@ -23,6 +23,7 @@ CEHRGPT_COLUMNS = [
     "num_of_visits",
     "values",
     "value_indicators",
+    "epoch_times",
 ]

 TRANSFORMER_COLUMNS = ["input_ids"]
@@ -1,21 +1,28 @@
+import copy
 import random
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional

 import numpy as np
 import torch
 from torch.nn.utils.rnn import pad_sequence
+from transformers.utils import logging

 from cehrgpt.gpt_utils import (
     DEMOGRAPHIC_PROMPT_SIZE,
     collect_demographic_prompts_at_visits,
     extract_time_interval_in_days,
+    extract_time_interval_in_hours,
     is_att_token,
     is_inpatient_att_token,
+    is_inpatient_hour_token,
+    is_visit_end,
     random_slice_gpt_sequence,
 )
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer

+TIME_TO_EVENT_MAX_TIME = 3650
 INPATIENT_STAY_DURATION_LIMIT = 30
+LOG = logging.get_logger("transformers")


 class CehrGptDataCollator:
@@ -27,20 +34,18 @@ class CehrGptDataCollator:
         include_values: bool = False,
         include_ttv_prediction: bool = False,
         use_sub_time_tokenization: bool = False,
+        include_motor_time_to_event: bool = False,
+        motor_tte_vocab_size: int = 0,
+        motor_num_time_pieces: int = 8,
         pretraining: bool = True,
         include_demographics: bool = False,
+        add_linear_prob_token: bool = False,
     ):
         self.tokenizer = tokenizer
         self.max_length = max_length
-        # Pre-compute these so we can use them later on
-        # We used VS for the historical data, currently, we use the new [VS] for the newer data
-        # so we need to check both cases.
-        self.vs_token_id = tokenizer._convert_token_to_id("VS")
-        if self.vs_token_id == tokenizer._oov_token_id:
-            self.vs_token_id = tokenizer._convert_token_to_id("[VS]")
-        self.ve_token_id = tokenizer._convert_token_to_id("VE")
-        if self.ve_token_id == tokenizer._oov_token_id:
-            self.ve_token_id = tokenizer._convert_token_to_id("[VE]")
+
+        self.vs_token_id = tokenizer.vs_token_id
+        self.ve_token_id = tokenizer.ve_token_id

         self.shuffle_records = shuffle_records
         self.include_values = include_values
@@ -48,6 +53,20 @@
         self.use_sub_time_tokenization = use_sub_time_tokenization
         self.pretraining = pretraining
         self.include_demographics = include_demographics
+        self.add_linear_prob_token = add_linear_prob_token
+
+        # MOTOR TTE configuration
+        if include_motor_time_to_event:
+            assert motor_tte_vocab_size > 0, (
+                f"motor_tte_vocab_size must be greater than 0 "
+                f"when include_motor_time_to_event is set to True. "
+                f"But motor_tte_vocab_size: {motor_tte_vocab_size} is provided"
+            )
+
+        self.include_motor_time_to_event = include_motor_time_to_event
+        self.motor_tte_vocab_size = motor_tte_vocab_size
+        self.motor_num_time_pieces = motor_num_time_pieces
+        self.motor_time_interval = TIME_TO_EVENT_MAX_TIME // motor_num_time_pieces

         if self.use_sub_time_tokenization:
             token_to_time_token_mapping = tokenizer.token_to_time_token_mapping
@@ -88,6 +107,8 @@
                     ):
                         return -100
                     return time_to_visit
+                elif is_inpatient_hour_token(c):
+                    return extract_time_interval_in_hours(c) / 24
                 return -100
             except ValueError:
                 return -100
@@ -95,8 +116,8 @@
         return [float(default_value(_)) for _ in concept_ids]

     def __call__(self, examples):
-
-        examples = [self.generate_start_end_index(_) for _ in examples]
+        sample_packing = getattr(self, "sample_packing", False)
+        examples = [self.generate_start_end_index(_, sample_packing) for _ in examples]
         examples = [self.random_sort(_) for _ in examples]
         batch = {}

@@ -153,7 +174,7 @@
                 pad_sequence(
                     batch_position_ids,
                     batch_first=True,
-                    padding_value=self.max_length,
+                    padding_value=0,
                 ).to(torch.int64)
             )

@@ -194,6 +215,126 @@
                 )
             )

+        if self.include_motor_time_to_event:
+            examples_with_motor_tte = [
+                self.create_time_to_event_labels(_) for _ in examples
+            ]
+            batch_motor_time_to_event_vectors = [
+                self._try_reverse_tensor(
+                    self._convert_to_tensor(example["time_to_event_vectors"])
+                )
+                for example in examples_with_motor_tte
+            ]
+            batch_motor_event_indicators = [
+                self._try_reverse_tensor(
+                    self._convert_to_tensor(example["event_indicators"])
+                )
+                for example in examples_with_motor_tte
+            ]
+            batch_motor_time_to_event_to_include = [
+                self._try_reverse_tensor(
+                    self._convert_to_tensor(example["time_to_event_to_include"])
+                )
+                for example in examples_with_motor_tte
+            ]
+            batch_motor_time_indicators = [
+                self._try_reverse_tensor(
+                    self._convert_to_tensor(example["time_indicators"])
+                )
+                for example in examples_with_motor_tte
+            ]
+
+            batch_motor_time_to_event_vectors = torch.concat(
+                batch_motor_time_to_event_vectors, dim=0
+            ).to(torch.float32)
+
+            # If every example in the batch contains only one visit, no MOTOR TTE labels are generated;
+            # we only create the labels when at least one example has more than one visit
+            if batch_motor_time_to_event_vectors.dim() <= 1:
+                LOG.warning(
+                    "There are no MOTOR TTE labels generated for this batch "
+                    "because every example in this batch only contains one visit."
+                )
+            else:
+                batch_size = len(examples)
+                length, num_time_pieces, motor_tte_vocab_size = (
+                    batch_motor_time_to_event_vectors.shape
+                )
+                padded_length = batch_size - length % batch_size
+                batch["motor_time_to_event_vectors"] = (
+                    torch.concat(
+                        [
+                            batch_motor_time_to_event_vectors,
+                            torch.full(
+                                (padded_length, num_time_pieces, motor_tte_vocab_size),
+                                0.0,
+                            ),
+                        ],
+                        dim=0,
+                    )
+                    .reshape((batch_size, -1, num_time_pieces, motor_tte_vocab_size))
+                    .to(torch.float32)
+                )
+
+                # Motor event indicators that indicate whether an event occurred in this time interval
+                batch_motor_event_indicators = torch.concat(
+                    batch_motor_event_indicators, dim=0
+                ).to(torch.bool)
+                batch["motor_event_indicators"] = (
+                    torch.concat(
+                        [
+                            batch_motor_event_indicators,
+                            torch.full(
+                                (padded_length, num_time_pieces, motor_tte_vocab_size),
+                                False,
+                            ),
+                        ],
+                        dim=0,
+                    )
+                    .reshape((batch_size, -1, num_time_pieces, motor_tte_vocab_size))
+                    .to(torch.bool)
+                )
+
+                # Input to indicate whether the visit should be included for TTE predictions
+                batch_motor_time_to_event_to_include = torch.concat(
+                    batch_motor_time_to_event_to_include, dim=0
+                ).to(torch.bool)
+                batch["motor_time_to_event_to_include"] = (
+                    torch.concat(
+                        [
+                            batch_motor_time_to_event_to_include,
+                            torch.full((padded_length,), False),
+                        ],
+                        dim=0,
+                    ).to(torch.bool)
+                ).reshape((batch_size, -1))
+
+                # Motor time indicators that mark whether a time bin contains neither clinical events nor censoring events
+                batch_motor_time_indicators = torch.concat(
+                    batch_motor_time_indicators, dim=0
+                ).to(torch.bool)
+                batch["motor_time_indicators"] = (
+                    torch.concat(
+                        [
+                            batch_motor_time_indicators,
+                            torch.full(
+                                (padded_length, num_time_pieces, motor_tte_vocab_size),
+                                False,
+                            ),
+                        ],
+                        dim=0,
+                    )
+                    .reshape((batch_size, -1, num_time_pieces, motor_tte_vocab_size))
+                    .to(torch.bool)
+                )
+
+                batch["motor_end_index"] = torch.concat(
+                    [
+                        torch.full((length, 1), 1, dtype=torch.int32),
+                        torch.full((padded_length, 1), 0, dtype=torch.int32),
+                    ]
+                ).reshape((batch_size, -1))
+
         if self.include_values:
             batch_value_indicators = [
                 self._try_reverse_tensor(
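
Note: the MOTOR label tensors above are collated with a pad-then-reshape trick: per-visit label tensors from every example are concatenated along the first dimension, padded so the total length divides by the batch size, then reshaped to [batch_size, -1, num_time_pieces, vocab_size], with motor_end_index recording which reshaped slots are real versus padding. The following is a minimal sketch with toy, assumed shapes, not code from the package.

    import torch

    batch_size, num_time_pieces, vocab = 2, 2, 3
    # Suppose the two examples in the batch contribute 3 and 2 labeled visits.
    flat = torch.rand(5, num_time_pieces, vocab)

    length = flat.shape[0]
    padded_length = batch_size - length % batch_size  # 2 - 5 % 2 = 1
    packed = torch.concat(
        [flat, torch.full((padded_length, num_time_pieces, vocab), 0.0)], dim=0
    ).reshape((batch_size, -1, num_time_pieces, vocab))
    print(packed.shape)  # torch.Size([2, 3, 2, 3])

    # 1 marks a real visit slot, 0 marks padding introduced by the reshape.
    motor_end_index = torch.concat(
        [
            torch.full((length, 1), 1, dtype=torch.int32),
            torch.full((padded_length, 1), 0, dtype=torch.int32),
        ]
    ).reshape((batch_size, -1))
    print(motor_end_index)  # tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.int32)
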
@@ -281,6 +422,193 @@

         return batch

+    def create_time_to_event_labels(self, record: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Generates time-to-event (TTE) labels and censoring indicators for each visit in a patient's timeline.
+
+        Processes the input sequence in reverse to compute the number of days from each visit (marked by [VE])
+        to the occurrence of future motor-related events.
+
+        Args:
+            record (Dict[str, Any]): A dictionary containing the encoded patient sequence with the key "input_ids".
+                This sequence includes [VS], [VE], time delta tokens, and motor TTE concept codes.
+
+        Returns:
+            Dict[str, Any]: The updated input record with added keys:
+                - "time_to_event_vectors": np.ndarray of shape [num_visits, motor_num_time_pieces, motor_tte_vocab_size], containing time-to-event values
+                - "event_indicators": np.ndarray of shape [num_visits, motor_num_time_pieces, motor_tte_vocab_size], where 1 = event occurred, 0 = censored
+        """
+        input_ids = record["input_ids"]
+        sample_packing = getattr(self, "sample_packing", False)
+
+        if isinstance(input_ids, torch.Tensor):
+            input_ids = input_ids.detach().tolist()
+
+        # This potentially contains packed samples, so we need to handle that
+        packed_concept_ids = self.tokenizer.decode(input_ids, skip_special_tokens=False)
+        pad_indices = []
+        if sample_packing:
+            # We start from the first index
+            for i in range(len(packed_concept_ids)):
+                if packed_concept_ids[i] == self.tokenizer.pad_token:
+                    # If we encounter consecutive pads, we should break out of the loop
+                    if pad_indices and pad_indices[-1] == i - 1:
+                        break
+                    pad_indices.append(i)
+
+        # If we did not find a pad, that means the whole sequence belongs to one sample
+        if len(pad_indices) == 0:
+            pad_indices.append(len(packed_concept_ids))
+
+        timepiece_time_to_event_vectors = []
+        timepiece_event_indicators = []
+        timepiece_indicators = []
+        time_to_event_to_includes = []
+
+        for start_index, end_index in zip([0] + pad_indices[:-1], pad_indices):
+            concept_ids = packed_concept_ids[start_index:end_index]
+            if concept_ids[0] == self.tokenizer.pad_token:
+                concept_ids.pop(0)
+            time_to_event_vectors = []
+            global_event_indicators = []
+
+            # First collect TTE data in reverse chronological order
+            censor_times = []
+            time_to_event_data: List[Dict[str, int]] = []
+            time_to_event_dict: Dict[str, int] = {}
+            time_to_event_to_include: List[bool] = []
+            next_future_visit_concepts = set()
+            time_interval = 0
+
+            # Reverse walk through concept_ids to calculate TTE from each [VE] point
+            for concept_id in reversed(concept_ids):
+                if is_visit_end(concept_id):
+                    # Update TTE for existing concepts, or add new ones seen in this visit
+                    for existing_concept_id in list(time_to_event_dict.keys()):
+                        if existing_concept_id in next_future_visit_concepts:
+                            time_to_event_dict[existing_concept_id] = time_interval
+                        else:
+                            time_to_event_dict[existing_concept_id] += time_interval
+
+                    for next_concept_id in next_future_visit_concepts:
+                        if next_concept_id not in time_to_event_dict:
+                            time_to_event_dict[next_concept_id] = time_interval
+
+                    # If the next visit occurs on the same day as the previous one, we don't want to do TTE for the
+                    # previous visit
+                    time_to_event_to_include.append(time_interval > 0)
+                    time_to_event_data.append(copy.deepcopy(time_to_event_dict))
+                    # Record the censor time at the end of the visit
+                    if censor_times:
+                        censor_times.append(censor_times[-1] + time_interval)
+                    else:
+                        censor_times.append(time_interval)
+                    time_interval = 0
+                    next_future_visit_concepts.clear()
+
+                elif is_att_token(concept_id):
+                    time_interval += extract_time_interval_in_days(concept_id)
+
+                elif self.tokenizer.is_motor_time_to_event_code(concept_id):
+                    next_future_visit_concepts.add(concept_id)
+
+            if len(time_to_event_data) == 0:
+                LOG.info(
+                    "Visit end event is not detected for this sample, so it is skipped for MOTOR tasks. "
+                    "It's likely this sample contains a long admission. length: %s, concept_ids[-10:] %s",
+                    len(concept_ids),
+                    concept_ids[-10:],
+                )
+                continue
+
+            # Reverse back to chronological order for final labels
+            time_to_event_data.reverse()
+            censor_times.reverse()
+            time_to_event_to_include.reverse()
+
+            for censor_time, visit_tte_data in zip(censor_times, time_to_event_data):
+                time_to_event_vector = np.full(
+                    self.tokenizer.motor_tte_vocab_size,
+                    fill_value=censor_time,
+                    dtype=np.int32,
+                )
+                event_indicator = np.zeros(
+                    self.tokenizer.motor_tte_vocab_size,
+                    dtype=np.int32,
+                )
+                visit_token_ids = [
+                    self.tokenizer.get_motor_token_id(concept_id)
+                    for concept_id in visit_tte_data.keys()
+                ]
+                visit_tte_values = list(visit_tte_data.values())
+
+                time_to_event_vector[visit_token_ids] = visit_tte_values
+                event_indicator[visit_token_ids] = 1  # not censored (event occurred)
+
+                time_to_event_vectors.append(time_to_event_vector)
+                global_event_indicators.append(event_indicator)
+
+            time_to_event_vectors = np.asarray(time_to_event_vectors)
+            global_event_indicators = np.asarray(global_event_indicators).astype(bool)
+            n_visits = len(time_to_event_vectors)
+
+            timepiece_time_to_event_vector = np.full(
+                (
+                    self.motor_num_time_pieces,
+                    n_visits,
+                    self.tokenizer.motor_tte_vocab_size,
+                ),
+                fill_value=0,
+                dtype=np.int32,
+            )
+            timepiece_event_indicator = np.zeros(
+                (
+                    self.motor_num_time_pieces,
+                    n_visits,
+                    self.tokenizer.motor_tte_vocab_size,
+                ),
+                dtype=bool,
+            )
+            timepiece_indicator = np.zeros(
+                (
+                    self.motor_num_time_pieces,
+                    n_visits,
+                    self.tokenizer.motor_tte_vocab_size,
+                ),
+                dtype=bool,
+            )
+
+            # Putting the event time and censor time into the corresponding time bins
+            for bin_num in range(self.motor_num_time_pieces):
+                start = self.motor_time_interval * bin_num
+                end = self.motor_time_interval * (bin_num + 1)
+                time_in_bin = np.clip(time_to_event_vectors - start, 0, end - start)
+                timepiece_time_to_event_vector[bin_num] = time_in_bin
+                event_indicator = (
+                    global_event_indicators
+                    & (start <= time_to_event_vectors)
+                    & (time_to_event_vectors < end)
+                )
+                timepiece_event_indicator[bin_num] = event_indicator
+                timepiece_indicator[bin_num] = (time_in_bin > 0) | event_indicator
+
+            timepiece_time_to_event_vectors.append(
+                timepiece_time_to_event_vector.swapaxes(0, 1)
+            )
+            timepiece_event_indicators.append(timepiece_event_indicator.swapaxes(0, 1))
+            timepiece_indicators.append(timepiece_indicator.swapaxes(0, 1))
+            time_to_event_to_includes.append(np.asarray(time_to_event_to_include))
+
+        record["time_to_event_vectors"] = np.concatenate(
+            timepiece_time_to_event_vectors, axis=0
+        )
+        record["event_indicators"] = np.concatenate(timepiece_event_indicators, axis=0)
+        record["time_indicators"] = np.concatenate(timepiece_indicators, axis=0)
+        record["time_to_event_to_include"] = np.concatenate(
+            time_to_event_to_includes, axis=0
+        )
+        return record
+
     def random_sort(self, record: Dict[str, Any]) -> Dict[str, Any]:

         if not self.shuffle_records:
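
Note: create_time_to_event_labels spreads each visit's time-to-event values over motor_num_time_pieces piecewise bins of width TIME_TO_EVENT_MAX_TIME // motor_num_time_pieces (3650 // 8 = 456 days). The sketch below uses toy, assumed values, not code from the package, to show how a single event at day 500 is binned: earlier bins accumulate their full width, the bin containing the event holds the remainder and receives the event indicator, and later bins stay empty.

    import numpy as np

    num_time_pieces = 8
    bin_width = 3650 // num_time_pieces  # 456 days per bin

    # One visit, a toy vocabulary of two motor codes: code 0 has an event at day 500,
    # code 1 is censored at day 500 (the censor time fills the vector by default).
    time_to_event = np.array([[500, 500]], dtype=np.int32)
    event_occurred = np.array([[True, False]])

    for bin_num in range(num_time_pieces):
        start, end = bin_width * bin_num, bin_width * (bin_num + 1)
        time_in_bin = np.clip(time_to_event - start, 0, end - start)
        event_in_bin = event_occurred & (start <= time_to_event) & (time_to_event < end)
        print(bin_num, time_in_bin[0], event_in_bin[0])
    # 0 [456 456] [False False]   <- full bin elapsed, no event yet
    # 1 [ 44  44] [ True False]   <- the event for code 0 falls in this bin
    # 2 [  0   0] [False False]   ... and so on for the remaining bins
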
@@ -317,14 +645,16 @@
         return record

     def generate_start_end_index(
-        self, record: Dict[str, Any], max_length_allowed: Optional[int] = None
+        self,
+        record: Dict[str, Any],
+        sample_packing: bool,
+        max_length_allowed: Optional[int] = None,
     ) -> Dict[str, Any]:
         """Adding the start and end indices to extract a portion of the patient sequence."""
         # concept_ids will be used for time-to-event predictions and identifying the visit starts
         max_length_allowed = (
             self.max_length if max_length_allowed is None else max_length_allowed
         )
-        sample_packing = getattr(self, "sample_packing", False)
         input_ids = record["input_ids"]
         if isinstance(input_ids, torch.Tensor):
             input_ids = input_ids.detach().tolist()
@@ -341,13 +671,25 @@
                 [self._convert_to_tensor(self._convert_time_to_event(concept_ids))]
             )

+        # If the linear token exists, we use it; otherwise we default to the OOV token
+        linear_token_id = (
+            self.tokenizer.linear_token_id
+            if self.tokenizer.linear_token_id is not None
+            else self.tokenizer.oov_token_id
+        )
+        eos_token = (
+            linear_token_id
+            if self.add_linear_prob_token
+            else self.tokenizer.end_token_id
+        )
+
         # Return the record directly if the actual sequence length is less than the max sequence
         if seq_length <= new_max_length:
             if not sample_packing:
                 record["input_ids"] = torch.concat(
                     [
                         self._convert_to_tensor(record["input_ids"]),
-                        self._convert_to_tensor([self.tokenizer.end_token_id]),
+                        self._convert_to_tensor([eos_token]),
                     ]
                 )
             if self.include_values:
@@ -372,7 +714,6 @@
                             self._convert_to_tensor([-100.0]),
                         ]
                     )
-
             return record

         if self.pretraining:
@@ -406,7 +747,8 @@
             for i in reversed(list(range(0, end_index))):
                 current_token = record["input_ids"][i]
                 if current_token == self.ve_token_id:
-                    end_index = i
+                    # Plus one because slicing is right exclusive
+                    end_index = i + 1
                     break

             record["input_ids"] = record["input_ids"][0:end_index]
@@ -415,6 +757,9 @@
             if sample_packing and "attention_mask" in record:
                 record["attention_mask"] = record["attention_mask"][0:end_index]

+            if sample_packing and "position_ids" in record:
+                record["position_ids"] = record["position_ids"][0:end_index]
+
             if self.include_values:
                 record["value_indicators"] = self._convert_to_tensor(
                     record["value_indicators"][0:end_index]
@@ -495,6 +840,8 @@
                         record["attention_mask"] = record["attention_mask"][
                             i:end_index
                         ]
+                    if sample_packing and "position_ids" in record:
+                        record["position_ids"] = record["position_ids"][i:end_index]
                     if self.include_values:
                         record["value_indicators"] = record["value_indicators"][
                             i:end_index
@@ -514,6 +861,8 @@
                 record["attention_mask"] = record["attention_mask"][
                     -new_max_length:
                 ]
+            if sample_packing and "position_ids" in record:
+                record["position_ids"] = record["position_ids"][-new_max_length:]
             if self.include_values:
                 record["value_indicators"] = record["value_indicators"][
                     -new_max_length:
@@ -529,7 +878,7 @@
             record["input_ids"] = torch.concat(
                 [
                     self._convert_to_tensor(record["input_ids"]),
-                    self._convert_to_tensor([self.tokenizer.end_token_id]),
+                    self._convert_to_tensor([eos_token]),
                 ]
             )
             if self.include_values:
@@ -584,34 +933,46 @@ class SamplePackingCehrGptDataCollator(CehrGptDataCollator):

         for idx, example in enumerate(examples):

-            # If the sample length exceeds the model's capacity, truncate this example
+            # We only add an end token if the patient sequence fits within the entire context window
             add_end_token = (
                 len(example["input_ids"]) <= self.max_position_embeddings
                 and self.add_end_token_in_sample_packing
             )
-
+            # If the sample length exceeds the model's capacity, truncate this example
             if len(example["input_ids"]) > self.max_position_embeddings:
                 example = self.generate_start_end_index(
-                    example, self.max_position_embeddings
+                    example, False, self.max_position_embeddings
                 )

+            add_eos_token = add_end_token | self.add_linear_prob_token
+            additional_tokens = []
+            if add_end_token:
+                additional_tokens.append(self.tokenizer.end_token_id)
+            elif self.add_linear_prob_token:
+                # Backward compatible
+                linear_prob_token_id = (
+                    self.tokenizer.linear_token_id
+                    if self.tokenizer.linear_token_id is not None
+                    else self.tokenizer.oov_token_id
+                )
+                additional_tokens.append(linear_prob_token_id)
+            additional_tokens.append(self.tokenizer.pad_token_id)
             input_ids = example["input_ids"]
             # We add [END] [PAD]; we want to attend to [END], and adding [END] is important for sequence generation.
             # If the sequence is shorter than the context window, we add both [END][PAD]; otherwise
             # we only add the [PAD] token to the end of the sequence because it is not finished
-            current_input_ids.extend(
-                list(input_ids)
-                + (
-                    [self.tokenizer.end_token_id, self.tokenizer.pad_token_id]
-                    if add_end_token
-                    else [self.tokenizer.pad_token_id]
-                )
-            )
+            current_input_ids.extend(list(input_ids) + additional_tokens)
             current_attention_mask.extend(
-                np.ones_like(input_ids).tolist() + ([1, 0] if add_end_token else [0])
+                np.ones_like(input_ids).tolist() + ([1, 0] if add_eos_token else [0])
+            )
+            num_tokens_to_pad = 1 + int(add_eos_token)
+            current_position_ids.extend(
+                np.clip(
+                    list(range(len(input_ids) + num_tokens_to_pad)),
+                    0,
+                    self.max_position_embeddings - 1,
+                )
             )
-            num_tokens_to_pad = 1 + int(add_end_token)
-            current_position_ids.extend(list(range(len(input_ids) + num_tokens_to_pad)))
             if self.include_values:
                 current_value_indicators.extend(
                     list(example["value_indicators"]) + [False] * num_tokens_to_pad
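
Note: the new position-id handling clips position ids so that a sample which already fills the context window does not produce a position id equal to max_position_embeddings once the trailing [END]/[PAD] tokens are appended. A toy, assumed illustration (not code from the package):

    import numpy as np

    max_position_embeddings = 8
    input_ids = list(range(8))  # a sample that exactly fills the context window
    num_tokens_to_pad = 1       # only [PAD] is appended, no [END]

    position_ids = np.clip(
        list(range(len(input_ids) + num_tokens_to_pad)), 0, max_position_embeddings - 1
    )
    print(position_ids.tolist())  # [0, 1, 2, 3, 4, 5, 6, 7, 7] (unclipped would end in 8)
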
@@ -633,9 +994,10 @@
             if "classifier_label" in example:
                 current_labels.append(example["classifier_label"])

-        assert (
-            len(current_input_ids) <= self.max_tokens_per_batch
-        ), f"the total number of tokens in the packed sequence should be less than { self.max_tokens_per_batch}"
+        assert len(current_input_ids) <= self.max_tokens_per_batch, (
+            f"The total number of tokens in the packed sequence should be at most {self.max_tokens_per_batch}\n"
+            f"But the total number of tokens is: {len(current_input_ids)}"
+        )
         packed_example = {
             "input_ids": current_input_ids,
             "attention_mask": current_attention_mask,