cehrgpt 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. cehrgpt/analysis/htn_treatment_pathway.py +546 -0
  2. cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
  3. cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
  4. cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
  5. cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
  6. cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
  7. cehrgpt/data/cehrgpt_data_processor.py +549 -0
  8. cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
  9. cehrgpt/data/hf_cehrgpt_dataset_collator.py +286 -629
  10. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +60 -14
  11. cehrgpt/generation/cehrgpt_conditional_generation.py +316 -0
  12. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +35 -15
  13. cehrgpt/generation/omop_converter_batch.py +11 -4
  14. cehrgpt/gpt_utils.py +73 -3
  15. cehrgpt/models/activations.py +27 -0
  16. cehrgpt/models/config.py +6 -2
  17. cehrgpt/models/gpt2.py +560 -0
  18. cehrgpt/models/hf_cehrgpt.py +193 -459
  19. cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
  20. cehrgpt/omop/ontology.py +154 -0
  21. cehrgpt/runners/data_utils.py +17 -6
  22. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +33 -79
  23. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
  24. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +58 -34
  25. cehrgpt/runners/hyperparameter_search_util.py +180 -69
  26. cehrgpt/runners/sample_packing_trainer.py +11 -2
  27. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +27 -31
  28. cehrgpt-0.1.3.dist-info/METADATA +238 -0
  29. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +33 -22
  30. cehrgpt-0.1.1.dist-info/METADATA +0 -115
  31. /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
  32. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
  33. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
  34. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
cehrgpt/omop/ontology.py
@@ -0,0 +1,154 @@
+ from __future__ import annotations
+
+ import collections
+ import os
+ from typing import Any, Dict, Iterable, Optional, Set, Union
+
+ import polars as pl
+ from datasets import Dataset
+
+
+ # Adapted from femr.ontology
+ def _get_all_codes_map(batch: Dict[str, Any]) -> Dict[str, Any]:
+     result = set()
+     for concept_ids in batch["concept_ids"]:
+         for concept_id in concept_ids:
+             if concept_id.isnumeric():
+                 result.add(concept_id)
+     return {"unique_concept_ids": list(result)}
+
+
+ class Ontology:
+     def __init__(self, vocab_path: str):
+         """Create an Ontology from an Athena download and an optional meds Code Metadata structure.
+
+         NOTE: This is an expensive operation.
+         It is recommended to create an ontology once and then save/load it as necessary.
+         """
+         # Load from code metadata
+         self.parents_map: Dict[str, Set[str]] = collections.defaultdict(set)
+         self.concept_vocabulary_map: Dict[str, str] = collections.defaultdict(str)
+         self.concept_domain_map: Dict[str, str] = collections.defaultdict(str)
+
+         # Load from the athena path ...
+         concept = pl.scan_parquet(os.path.join(vocab_path, "concept/*parquet"))
+         vocabulary_id_col = pl.col("vocabulary_id")
+         concept_id_col = pl.col("concept_id").cast(pl.String)
+         domain_id_col = pl.col("domain_id").cast(pl.String)
+
+         processed_concepts = (
+             concept.select(
+                 concept_id_col,
+                 domain_id_col,
+                 vocabulary_id_col,
+                 pl.col("standard_concept").is_null(),
+             )
+             .collect()
+             .rows()
+         )
+
+         non_standard_concepts = set()
+
+         for concept_id, domain_id, vocabulary_id, is_non_standard in processed_concepts:
+             # We don't want to override code metadata
+             if concept_id not in self.concept_vocabulary_map:
+                 self.concept_vocabulary_map[concept_id] = vocabulary_id
+
+             if concept_id not in self.concept_domain_map:
+                 self.concept_domain_map[concept_id] = domain_id
+             # We don't want to override code metadata
+             if is_non_standard:
+                 non_standard_concepts.add(concept_id)
+
+         relationship = pl.scan_parquet(
+             os.path.join(vocab_path, "concept_relationship/*parquet")
+         )
+         relationship_id = pl.col("relationship_id")
+         relationship = relationship.filter(
+             relationship_id == "Maps to",
+             pl.col("concept_id_1") != pl.col("concept_id_2"),
+         )
+         for concept_id_1, concept_id_2 in (
+             relationship.select(
+                 pl.col("concept_id_1").cast(pl.String),
+                 pl.col("concept_id_2").cast(pl.String),
+             )
+             .collect()
+             .rows()
+         ):
+             if concept_id_1 in non_standard_concepts:
+                 self.parents_map[concept_id_1].add(concept_id_2)
+
+         ancestor = pl.scan_parquet(
+             os.path.join(vocab_path, "concept_ancestor/*parquet")
+         )
+         ancestor = ancestor.filter(pl.col("min_levels_of_separation") == 1)
+         for concept_id, parent_concept_id in (
+             ancestor.select(
+                 pl.col("descendant_concept_id").cast(pl.String),
+                 pl.col("ancestor_concept_id").cast(pl.String),
+             )
+             .collect()
+             .rows()
+         ):
+             self.parents_map[concept_id].add(parent_concept_id)
+         self.all_parents_map: Dict[str, Set[str]] = {}
+
+     def get_domain(self, concept_id: Union[str, int]) -> Optional[str]:
+         return self.concept_domain_map.get(str(concept_id), None)
+
+     def prune_to_dataset(
+         self,
+         dataset: Dataset,
+         remove_ontologies: Set[str] = set(),
+         num_proc: int = 4,
+         batch_size: int = 1024,
+     ) -> None:
+         mapped_dataset = dataset.map(
+             _get_all_codes_map,
+             batched=True,
+             batch_size=batch_size,
+             remove_columns=dataset.column_names,
+             num_proc=num_proc,
+         )
+         valid_concept_ids = set(mapped_dataset["unique_concept_ids"])
+         all_parents = set()
+         for concept_id in valid_concept_ids:
+             all_parents |= self.get_all_parents(concept_id)
+
+         def is_valid(c: str):
+             ontology = self.concept_vocabulary_map.get(c, "")
+             return (c in valid_concept_ids) or (
+                 (ontology not in remove_ontologies) and (c in all_parents)
+             )
+
+         concept_ids = set(self.parents_map.keys())
+         for concept_id in concept_ids:
+             m: Any
+             if is_valid(concept_id):
+                 for m in (self.parents_map, self.concept_vocabulary_map):
+                     m[concept_id] = {a for a in m[concept_id] if is_valid(a)}
+             else:
+                 for m in (self.parents_map, self.concept_vocabulary_map):
+                     if concept_id in m:
+                         del m[concept_id]
+
+         self.all_parents_map = {}
+
+         # Prime the pump
+         for concept_id in self.parents_map.keys():
+             self.get_all_parents(concept_id)
+
+     def get_parents(self, code: str) -> Iterable[str]:
+         """Get the parents for a given code."""
+         return self.parents_map.get(code, set())
+
+     def get_all_parents(self, code: str) -> Set[str]:
+         """Get all parents, including through the ontology."""
+         if code not in self.all_parents_map:
+             result = {code}
+             for parent in self.parents_map.get(code, set()):
+                 result |= self.get_all_parents(parent)
+             self.all_parents_map[code] = result
+
+         return self.all_parents_map[code]
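For orientation, a minimal usage sketch of the new Ontology helper; this is an illustrative example and not part of the released package. It assumes an Athena/OMOP vocabulary directory containing concept/, concept_relationship/, and concept_ancestor/ parquet subfolders, and a Hugging Face dataset exposing a concept_ids column; the path and concept IDs below are hypothetical.

    from datasets import Dataset
    from cehrgpt.omop.ontology import Ontology

    # Hypothetical Athena/OMOP vocabulary folder with concept/, concept_relationship/,
    # and concept_ancestor/ parquet subdirectories.
    ontology = Ontology("/data/athena_vocab")  # expensive: scans the vocabulary parquet files once

    # Toy dataset with the "concept_ids" column that prune_to_dataset expects.
    dataset = Dataset.from_dict({"concept_ids": [["320128", "201826"], ["320128"]]})

    # Keep only concepts observed in the dataset (plus their ancestors), skipping
    # ancestors that belong to the listed vocabularies.
    ontology.prune_to_dataset(dataset, remove_ontologies={"SPL"}, num_proc=1, batch_size=2)

    print(ontology.get_domain("320128"))       # domain_id from the concept table, e.g. "Condition"
    print(ontology.get_all_parents("320128"))  # the concept itself plus all of its ancestors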
cehrgpt/runners/data_utils.py
@@ -47,7 +47,7 @@ def prepare_finetune_dataset(
      data_args: DataTrainingArguments,
      training_args: TrainingArguments,
      cehrgpt_args: CehrGPTArguments,
-     cache_file_collector: CacheFileCollector,
+     cache_file_collector: Optional[CacheFileCollector] = None,
  ) -> DatasetDict:
      # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
      if data_args.is_data_in_meds:
@@ -91,8 +91,9 @@ def prepare_finetune_dataset(
              "Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
              stats,
          )
-         # Clean up the files created from the data generator
-         cache_file_collector.remove_cache_files()
+         if cache_file_collector:
+             # Clean up the files created from the data generator
+             cache_file_collector.remove_cache_files()
          dataset = load_from_disk(str(meds_extension_path))

      train_set = dataset["train"]
@@ -271,7 +272,7 @@ def create_dataset_splits(data_args: DataTrainingArguments, seed: int):
  def extract_cohort_sequences(
      data_args: DataTrainingArguments,
      cehrgpt_args: CehrGPTArguments,
-     cache_file_collector: CacheFileCollector,
+     cache_file_collector: Optional[CacheFileCollector] = None,
  ) -> DatasetDict:
      """
      Extracts and processes cohort-specific tokenized sequences from a pre-tokenized dataset,.
@@ -309,9 +310,18 @@ def extract_cohort_sequences(
          mapping={
              "prediction_time": "index_date",
              "subject_id": "person_id",
+             "boolean_value": "label",
          }
      )
      all_person_ids = cohort["person_id"].unique().to_list()
+     # In case the label column does not exist, we add a fake column to the dataframe so subsequent process can work
+     if "label" not in cohort.columns:
+         cohort = cohort.with_columns(
+             pl.Series(
+                 name="label", values=np.zeros_like(cohort["person_id"].to_numpy())
+             )
+         )
+
      # data_args.observation_window
      tokenized_dataset = load_from_disk(cehrgpt_args.tokenized_full_dataset_path)
      filtered_tokenized_dataset = tokenized_dataset.filter(
@@ -353,6 +363,7 @@ def extract_cohort_sequences(
          num_proc=data_args.preprocessing_num_workers,
          remove_columns=filtered_tokenized_dataset["train"].column_names,
      )
-     cache_file_collector.add_cache_files(filtered_tokenized_dataset)
-     cache_file_collector.add_cache_files(processed_dataset)
+     if cache_file_collector:
+         cache_file_collector.add_cache_files(filtered_tokenized_dataset)
+         cache_file_collector.add_cache_files(processed_dataset)
      return processed_dataset
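For orientation, a standalone sketch of the label fallback that extract_cohort_sequences now applies after renaming prediction_time and subject_id, so downstream steps always find a label column even when the cohort parquet has no boolean_value. The column names and the fallback logic come from the hunk above; the toy frame is illustrative and not part of the released package.

    import numpy as np
    import polars as pl

    # Toy cohort; real cohorts are read from parquet with prediction_time/subject_id columns.
    cohort = pl.DataFrame(
        {"subject_id": [1, 2, 3], "prediction_time": ["2020-01-01", "2020-02-01", "2020-03-01"]}
    ).rename({"prediction_time": "index_date", "subject_id": "person_id"})

    # Same fallback as in the diff: backfill an all-zero label column when none exists.
    if "label" not in cohort.columns:
        cohort = cohort.with_columns(
            pl.Series(name="label", values=np.zeros_like(cohort["person_id"].to_numpy()))
        )

    print(cohort)  # columns: person_id, index_date, label (all zeros)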
cehrgpt/runners/hf_cehrgpt_finetune_runner.py
@@ -1,3 +1,4 @@
+ import glob
  import json
  import os
  import random
@@ -175,11 +176,6 @@ def model_init(
      model.config.class_weights = cehrgpt_args.class_weights
      LOG.info(f"Setting class_weights to {model.config.class_weights}")

-     # Enable position embeddings when position embeddings are disabled in pre-training
-     if not model_args.exclude_position_ids and model.cehrgpt.exclude_position_ids:
-         LOG.info(f"Enable the position_embeddings")
-         model.cehrgpt.enable_position_embeddings()
-
      if model.config.max_position_embeddings < model_args.max_position_embeddings:
          LOG.info(
              f"Increase model.config.max_position_embeddings to {model_args.max_position_embeddings}"
@@ -379,7 +375,6 @@ def main():
              SamplePackingCehrGptDataCollator,
              cehrgpt_args.max_tokens_per_batch,
              config.max_position_embeddings,
-             add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
          )
      else:
          trainer_class = Trainer
@@ -406,8 +401,9 @@ def main():
      )

      if training_args.do_train:
+         output_dir = training_args.output_dir
          if cehrgpt_args.hyperparameter_tuning:
-             training_args = perform_hyperparameter_search(
+             training_args, run_id = perform_hyperparameter_search(
                  trainer_class,
                  partial(model_init, model_args, training_args, cehrgpt_args, tokenizer),
                  processed_dataset,
@@ -416,18 +412,28 @@ def main():
                  model_args,
                  cehrgpt_args,
              )
-
-             if cehrgpt_args.retrain_with_full:
-                 # Always retrain with the full set when hyperparameter tuning is set to true
-                 retrain_with_full_set(
-                     trainer_class,
-                     model_args,
-                     training_args,
-                     cehrgpt_args,
-                     tokenizer,
-                     processed_dataset,
-                     data_collator,
+             # We enforce retraining if cehrgpt_args.hyperparameter_tuning_percentage < 1.0
+             cehrgpt_args.retrain_with_full |= (
+                 cehrgpt_args.hyperparameter_tuning_percentage < 1.0
              )
+             output_dir = os.path.join(training_args.output_dir, f"run-{run_id}")
+
+         if cehrgpt_args.hyperparameter_tuning and not cehrgpt_args.retrain_with_full:
+             folders = glob.glob(os.path.join(output_dir, "checkpoint-*"))
+             if len(folders) == 0:
+                 raise RuntimeError(
+                     f"There must be a checkpoint folder under {output_dir}"
+                 )
+             checkpoint_dir = folders[0]
+             LOG.info("Best trial checkpoint folder: %s", checkpoint_dir)
+             for file_name in os.listdir(checkpoint_dir):
+                 try:
+                     full_file_name = os.path.join(checkpoint_dir, file_name)
+                     destination = os.path.join(training_args.output_dir, file_name)
+                     if os.path.isfile(full_file_name):
+                         shutil.copy2(full_file_name, destination)
+                 except Exception as e:
+                     LOG.error("Failed to copy %s: %s", file_name, str(e))
          else:
              # Initialize Trainer for final training on the combined train+val set
              trainer = trainer_class(
@@ -476,63 +482,6 @@ def main():
          do_predict(test_dataloader, model_args, training_args, cehrgpt_args)


- def retrain_with_full_set(
-     trainer_class,
-     model_args: ModelArguments,
-     training_args: TrainingArguments,
-     cehrgpt_args: CehrGPTArguments,
-     tokenizer: CehrGptTokenizer,
-     dataset: DatasetDict,
-     data_collator: CehrGptDataCollator,
- ) -> None:
-     """
-     Retrains a model on the full training and validation dataset for final performance evaluation.
-
-     This function consolidates the training and validation datasets into a single
-     dataset for final model training, updates the output directory for the final model,
-     and disables evaluation during training. It resumes from the latest checkpoint if available,
-     trains the model on the combined dataset, and saves the model along with training metrics
-     and state information.
-
-     Args:
-         trainer_class: Trainer or its subclass
-         model_args (ModelArguments): Model configuration and hyperparameters.
-         training_args (TrainingArguments): Training configuration, including output directory,
-             evaluation strategy, and other training parameters.
-         cehrgpt_args (CehrGPTArguments): CehrGPT specific parameters.
-         tokenizer (CehrGptTokenizer): Tokenizer instance specific to CEHR-GPT.
-         dataset (DatasetDict): A dictionary containing the 'train' and 'validation' datasets.
-         data_collator (CehrGptDataCollator): Data collator for handling data batching and tokenization.
-
-     Returns:
-         None
-     """
-     # Initialize Trainer for final training on the combined train+val set
-     full_dataset = concatenate_datasets([dataset["train"], dataset["validation"]])
-     training_args.output_dir = os.path.join(training_args.output_dir, "full")
-     LOG.info(
-         "Final output_dir for final_training_args.output_dir %s",
-         training_args.output_dir,
-     )
-     Path(training_args.output_dir).mkdir(exist_ok=True)
-     # Disable evaluation
-     training_args.evaluation_strategy = "no"
-     checkpoint = get_last_hf_checkpoint(training_args)
-     final_trainer = trainer_class(
-         model=model_init(model_args, training_args, cehrgpt_args, tokenizer),
-         data_collator=data_collator,
-         args=training_args,
-         train_dataset=full_dataset,
-         tokenizer=tokenizer,
-     )
-     final_train_result = final_trainer.train(resume_from_checkpoint=checkpoint)
-     final_trainer.save_model()  # Saves the tokenizer too for easy upload
-     metrics = final_train_result.metrics
-     final_trainer.log_metrics("train", metrics)
-     final_trainer.save_metrics("train", metrics)
-     final_trainer.save_state()
-
-
  def do_predict(
      test_dataloader: DataLoader,
      model_args: ModelArguments,
@@ -580,7 +529,15 @@ def do_predict(
          index_dates = batch.pop("index_date").numpy().squeeze()
          if index_dates.ndim == 0:
              index_dates = np.asarray([index_dates])
-         index_dates = list(map(datetime.fromtimestamp, index_dates.tolist()))
+
+         index_dates = list(
+             map(
+                 lambda posix_time: datetime.utcfromtimestamp(posix_time).replace(
+                     tzinfo=None
+                 ),
+                 index_dates.tolist(),
+             )
+         )

          batch = {k: v.to(device) for k, v in batch.items()}
          # Forward pass
@@ -644,9 +601,6 @@ def load_lora_model(
      # Enable include_values when include_values is set to be False during pre-training
      if model_args.include_values and not model.cehrgpt.include_values:
          model.cehrgpt.include_values = True
-     # Enable position embeddings when position embeddings are disabled in pre-training
-     if not model_args.exclude_position_ids and model.cehrgpt.exclude_position_ids:
-         model.cehrgpt.exclude_position_ids = False
      if cehrgpt_args.expand_tokenizer:
          tokenizer = CehrGptTokenizer.from_pretrained(training_args.output_dir)
          # Expand tokenizer to adapt to the finetuning dataset
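For orientation, a small sketch of why the do_predict hunk above swaps datetime.fromtimestamp for datetime.utcfromtimestamp: the former interprets a POSIX timestamp in the local timezone of the inference machine, the latter always in UTC, so index dates decode identically everywhere. This is an illustrative example, not part of the released package; the timestamp value is arbitrary.

    from datetime import datetime

    posix_time = 1609459200  # 2021-01-01 00:00:00 UTC

    # Old behaviour: converted using the local timezone of the machine running inference.
    local_dt = datetime.fromtimestamp(posix_time)

    # New behaviour: always interpreted as UTC. utcfromtimestamp already returns a naive
    # datetime, so the .replace(tzinfo=None) in the diff is a defensive no-op.
    utc_dt = datetime.utcfromtimestamp(posix_time).replace(tzinfo=None)

    print(local_dt, utc_dt)  # these differ whenever the local timezone is not UTC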
cehrgpt/runners/hf_cehrgpt_pretrain_runner.py
@@ -2,7 +2,9 @@ import os
  from functools import partial
  from typing import Optional, Union

+ import datasets
  import numpy as np
+ import pandas as pd
  import torch
  import torch.distributed as dist
  from cehrbert.data_generators.hf_data_generator.meds_utils import (
@@ -20,7 +22,7 @@ from cehrbert.runners.runner_util import (
      load_parquet_as_dataset,
  )
  from datasets import Dataset, DatasetDict, IterableDatasetDict, load_from_disk
- from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
+ from transformers import EarlyStoppingCallback, Trainer, set_seed
  from transformers.trainer_utils import is_main_process
  from transformers.utils import is_flash_attn_2_available, logging

@@ -34,6 +36,7 @@ from cehrgpt.models.config import CEHRGPTConfig
  from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
  from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
  from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+ from cehrgpt.omop.ontology import Ontology
  from cehrgpt.runners.data_utils import get_torch_dtype
  from cehrgpt.runners.gpt_runner_util import parse_runner_args
  from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
@@ -70,68 +73,64 @@ def load_and_create_tokenizer(
      data_args: DataTrainingArguments,
      model_args: ModelArguments,
      cehrgpt_args: CehrGPTArguments,
-     dataset: Optional[Union[Dataset, DatasetDict]] = None,
+     dataset: Union[Dataset, DatasetDict],
  ) -> CehrGptTokenizer:

-     concept_name_mapping = {}
-     allowed_motor_codes = list()
-     if cehrgpt_args.concept_dir:
-         import pandas as pd
-         from cehrbert_data.const.artificial_tokens import DEATH_TOKEN
-         from meds.schema import death_code
-
-         LOG.info("Loading concept data from disk at %s", cehrgpt_args.concept_dir)
-         concept_pd = pd.read_parquet(cehrgpt_args.concept_dir)
-         LOG.info(
-             "Creating concept name mapping and motor_time_to_event_codes from disk at %s",
-             cehrgpt_args.concept_dir,
-         )
-         for row in concept_pd.itertuples():
-             concept_name_mapping[str(getattr(row, "concept_id"))] = getattr(
-                 row, "concept_name"
-             )
-             if (
-                 cehrgpt_args.include_motor_time_to_event
-                 and getattr(row, "domain_id")
-                 in ["Condition", "Procedure", "Drug", "Visit"]
-                 and getattr(row, "standard_concept") == "S"
-             ):
-                 allowed_motor_codes.append(str(getattr(row, "concept_id")))
-         LOG.info(
-             "Adding death codes for MOTOR TTE predictions: %s",
-             [DEATH_TOKEN, death_code],
-         )
-         allowed_motor_codes.extend([DEATH_TOKEN, death_code])
      # Try to load the pretrained tokenizer
      tokenizer_abspath = os.path.expanduser(model_args.tokenizer_name_or_path)
-     try:
-         tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_abspath)
-     except Exception as e:
-         LOG.warning(e)
-         if dataset is None:
+     if not tokenizer_exists(tokenizer_abspath):
+         if cehrgpt_args.include_motor_time_to_event and not cehrgpt_args.vocab_dir:
              raise RuntimeError(
-                 f"Failed to load the tokenizer from {tokenizer_abspath} with the error \n{e}\n"
-                 f"Tried to create the tokenizer, however the dataset is not provided."
+                 "motor_vocab_dir must be specified if include_motor_time_to_event is True"
+             )
+         ontology: Optional[Ontology] = None
+         concept_name_mapping = {}
+         if cehrgpt_args.vocab_dir:
+             LOG.info("Loading concept data from disk at %s", cehrgpt_args.vocab_dir)
+             concept_pd = pd.read_parquet(
+                 os.path.join(cehrgpt_args.vocab_dir, "concept")
              )
+             for row in concept_pd.itertuples():
+                 concept_name_mapping[str(getattr(row, "concept_id"))] = getattr(
+                     row, "concept_name"
+                 )
+
+             if cehrgpt_args.motor_use_ontology:
+                 LOG.info("Creating ontology for MOTOR TTE predictions")
+                 ontology = Ontology(cehrgpt_args.vocab_dir)
+                 train_val_dataset = datasets.concatenate_datasets(
+                     [dataset["train"], dataset["validation"]]
+                 )
+                 ontology.prune_to_dataset(
+                     train_val_dataset,
+                     num_proc=data_args.preprocessing_num_workers,
+                     remove_ontologies={"SPL", "HemOnc", "LOINC"},
+                 )
+
          LOG.info("Started training the tokenizer ...")
+         train_val_dataset = datasets.concatenate_datasets(
+             [dataset["train"], dataset["validation"]]
+         )
          tokenizer = CehrGptTokenizer.train_tokenizer(
-             dataset,
+             train_val_dataset,
              concept_name_mapping,
              data_args,
              PretrainedEmbeddings(cehrgpt_args.pretrained_embedding_path),
-             allowed_motor_codes if cehrgpt_args.include_motor_time_to_event else None,
-             (
+             num_motor_tasks=(
                  cehrgpt_args.num_motor_tasks
                  if cehrgpt_args.include_motor_time_to_event
                  else None
              ),
              apply_entropy_filter=cehrgpt_args.apply_entropy_filter,
              min_prevalence=cehrgpt_args.min_prevalence,
+             ontology=ontology,
          )
          LOG.info("Finished training the tokenizer ...")
          tokenizer.save_pretrained(tokenizer_abspath)
          LOG.info("Saved the tokenizer to %s", tokenizer_abspath)
-
+     else:
+         LOG.info("The tokenizer exists and will be loaded from %s", tokenizer_abspath)
+         tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_abspath)
      return tokenizer


@@ -187,7 +186,10 @@ def load_and_create_model(

      model_args_cehrgpt = model_args.as_dict()
      model_args_cehrgpt.pop("attn_implementation")
+     # CEHR-GPT does not support this anymore
+     model_args_cehrgpt.pop("exclude_position_ids")
      model_config = CEHRGPTConfig(
+         activation_function=cehrgpt_args.activation_function,
          vocab_size=tokenizer.vocab_size,
          value_vocab_size=tokenizer.value_vocab_size,
          time_token_vocab_size=tokenizer.time_token_vocab_size,
@@ -207,6 +209,7 @@ def load_and_create_model(
          n_pretrained_embeddings_layers=cehrgpt_args.n_pretrained_embeddings_layers,
          use_pretrained_embeddings=len(tokenizer.pretrained_token_ids) > 0,
          pretrained_embedding_dim=pretrained_embedding_dim,
+         apply_rotary=cehrgpt_args.apply_rotary,
          sample_packing_max_positions=(
              cehrgpt_args.max_tokens_per_batch
              if cehrgpt_args.sample_packing
@@ -217,6 +220,8 @@ def load_and_create_model(
          motor_time_to_event_weight=cehrgpt_args.motor_time_to_event_weight,
          motor_num_time_pieces=cehrgpt_args.motor_num_time_pieces,
          ve_token_id=tokenizer.ve_token_id,
+         n_inner=cehrgpt_args.inner_dim,
+         decoder_mlp=cehrgpt_args.decoder_mlp,
          **model_args_cehrgpt,
      )

@@ -235,7 +240,6 @@ def load_and_create_model(

  def main():
      cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
-
      if cehrgpt_args.sample_packing and data_args.streaming:
          raise RuntimeError(
              f"sample_packing is not supported when streaming is enabled, please set streaming to False"
@@ -530,7 +534,6 @@ def main():
              SamplePackingCehrGptDataCollator,
              cehrgpt_args.max_tokens_per_batch,
              model_args.max_position_embeddings,
-             add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
          )
      else:
          trainer_class = Trainer
@@ -552,6 +555,7 @@ def main():
              include_motor_time_to_event=cehrgpt_args.include_motor_time_to_event,
              motor_tte_vocab_size=model.config.motor_tte_vocab_size,
              motor_num_time_pieces=cehrgpt_args.motor_num_time_pieces,
+             motor_sampling_probability=cehrgpt_args.motor_sampling_probability,
          ),
          train_dataset=processed_dataset["train"],
          eval_dataset=(