cehrgpt 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. cehrgpt/analysis/htn_treatment_pathway.py +546 -0
  2. cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
  3. cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
  4. cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
  5. cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
  6. cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
  7. cehrgpt/data/cehrgpt_data_processor.py +549 -0
  8. cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
  9. cehrgpt/data/hf_cehrgpt_dataset_collator.py +285 -652
  10. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +38 -5
  11. cehrgpt/generation/cehrgpt_conditional_generation.py +2 -0
  12. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +20 -12
  13. cehrgpt/generation/omop_converter_batch.py +11 -4
  14. cehrgpt/gpt_utils.py +73 -3
  15. cehrgpt/models/activations.py +27 -0
  16. cehrgpt/models/config.py +6 -2
  17. cehrgpt/models/gpt2.py +560 -0
  18. cehrgpt/models/hf_cehrgpt.py +183 -460
  19. cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
  20. cehrgpt/omop/ontology.py +154 -0
  21. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -78
  22. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
  23. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -34
  24. cehrgpt/runners/hyperparameter_search_util.py +180 -69
  25. cehrgpt/runners/sample_packing_trainer.py +11 -2
  26. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +8 -2
  27. cehrgpt-0.1.3.dist-info/METADATA +238 -0
  28. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +32 -22
  29. cehrgpt-0.1.2.dist-info/METADATA +0 -209
  30. /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
  31. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
  32. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
  33. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
cehrgpt/omop/ontology.py
@@ -0,0 +1,154 @@
+from __future__ import annotations
+
+import collections
+import os
+from typing import Any, Dict, Iterable, Optional, Set, Union
+
+import polars as pl
+from datasets import Dataset
+
+
+# Adapted from femr.ontology
+def _get_all_codes_map(batch: Dict[str, Any]) -> Dict[str, Any]:
+    result = set()
+    for concept_ids in batch["concept_ids"]:
+        for concept_id in concept_ids:
+            if concept_id.isnumeric():
+                result.add(concept_id)
+    return {"unique_concept_ids": list(result)}
+
+
+class Ontology:
+    def __init__(self, vocab_path: str):
+        """Create an Ontology from an Athena download and an optional meds Code Metadata structure.
+
+        NOTE: This is an expensive operation.
+        It is recommended to create an ontology once and then save/load it as necessary.
+        """
+        # Load from code metadata
+        self.parents_map: Dict[str, Set[str]] = collections.defaultdict(set)
+        self.concept_vocabulary_map: Dict[str, str] = collections.defaultdict(str)
+        self.concept_domain_map: Dict[str, str] = collections.defaultdict(str)
+
+        # Load from the athena path ...
+        concept = pl.scan_parquet(os.path.join(vocab_path, "concept/*parquet"))
+        vocabulary_id_col = pl.col("vocabulary_id")
+        concept_id_col = pl.col("concept_id").cast(pl.String)
+        domain_id_col = pl.col("domain_id").cast(pl.String)
+
+        processed_concepts = (
+            concept.select(
+                concept_id_col,
+                domain_id_col,
+                vocabulary_id_col,
+                pl.col("standard_concept").is_null(),
+            )
+            .collect()
+            .rows()
+        )
+
+        non_standard_concepts = set()
+
+        for concept_id, domain_id, vocabulary_id, is_non_standard in processed_concepts:
+            # We don't want to override code metadata
+            if concept_id not in self.concept_vocabulary_map:
+                self.concept_vocabulary_map[concept_id] = vocabulary_id
+
+            if concept_id not in self.concept_domain_map:
+                self.concept_domain_map[concept_id] = domain_id
+            # We don't want to override code metadata
+            if is_non_standard:
+                non_standard_concepts.add(concept_id)
+
+        relationship = pl.scan_parquet(
+            os.path.join(vocab_path, "concept_relationship/*parquet")
+        )
+        relationship_id = pl.col("relationship_id")
+        relationship = relationship.filter(
+            relationship_id == "Maps to",
+            pl.col("concept_id_1") != pl.col("concept_id_2"),
+        )
+        for concept_id_1, concept_id_2 in (
+            relationship.select(
+                pl.col("concept_id_1").cast(pl.String),
+                pl.col("concept_id_2").cast(pl.String),
+            )
+            .collect()
+            .rows()
+        ):
+            if concept_id_1 in non_standard_concepts:
+                self.parents_map[concept_id_1].add(concept_id_2)
+
+        ancestor = pl.scan_parquet(
+            os.path.join(vocab_path, "concept_ancestor/*parquet")
+        )
+        ancestor = ancestor.filter(pl.col("min_levels_of_separation") == 1)
+        for concept_id, parent_concept_id in (
+            ancestor.select(
+                pl.col("descendant_concept_id").cast(pl.String),
+                pl.col("ancestor_concept_id").cast(pl.String),
+            )
+            .collect()
+            .rows()
+        ):
+            self.parents_map[concept_id].add(parent_concept_id)
+        self.all_parents_map: Dict[str, Set[str]] = {}
+
+    def get_domain(self, concept_id: Union[str, int]) -> Optional[str]:
+        return self.concept_domain_map.get(str(concept_id), None)
+
+    def prune_to_dataset(
+        self,
+        dataset: Dataset,
+        remove_ontologies: Set[str] = set(),
+        num_proc: int = 4,
+        batch_size: int = 1024,
+    ) -> None:
+        mapped_dataset = dataset.map(
+            _get_all_codes_map,
+            batched=True,
+            batch_size=batch_size,
+            remove_columns=dataset.column_names,
+            num_proc=num_proc,
+        )
+        valid_concept_ids = set(mapped_dataset["unique_concept_ids"])
+        all_parents = set()
+        for concept_id in valid_concept_ids:
+            all_parents |= self.get_all_parents(concept_id)
+
+        def is_valid(c: str):
+            ontology = self.concept_vocabulary_map.get(c, "")
+            return (c in valid_concept_ids) or (
+                (ontology not in remove_ontologies) and (c in all_parents)
+            )
+
+        concept_ids = set(self.parents_map.keys())
+        for concept_id in concept_ids:
+            m: Any
+            if is_valid(concept_id):
+                for m in (self.parents_map, self.concept_vocabulary_map):
+                    m[concept_id] = {a for a in m[concept_id] if is_valid(a)}
+            else:
+                for m in (self.parents_map, self.concept_vocabulary_map):
+                    if concept_id in m:
+                        del m[concept_id]
+
+        self.all_parents_map = {}
+
+        # Prime the pump
+        for concept_id in self.parents_map.keys():
+            self.get_all_parents(concept_id)
+
+    def get_parents(self, code: str) -> Iterable[str]:
+        """Get the parents for a given code."""
+        return self.parents_map.get(code, set())
+
+    def get_all_parents(self, code: str) -> Set[str]:
+        """Get all parents, including through the ontology."""
+        if code not in self.all_parents_map:
+            result = {code}
+            for parent in self.parents_map.get(code, set()):
+                result |= self.get_all_parents(parent)
+            self.all_parents_map[code] = result
+
+        return self.all_parents_map[code]
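For orientation, here is a minimal usage sketch of the new Ontology class added above. The vocabulary path, dataset, and concept ids are illustrative placeholders; only methods defined in this hunk are called.

    from datasets import Dataset

    from cehrgpt.omop.ontology import Ontology

    # Illustrative inputs only: vocab_path must point at an OMOP/Athena folder that
    # contains the concept/, concept_relationship/ and concept_ancestor/ parquet files.
    vocab_path = "/data/omop_vocab"
    train_dataset = Dataset.from_dict({"concept_ids": [["316866", "1308216"]]})

    ontology = Ontology(vocab_path)  # expensive: scans the vocabulary parquet files
    ontology.prune_to_dataset(train_dataset, remove_ontologies={"LOINC"}, num_proc=1)

    print(ontology.get_domain("316866"))       # domain_id of the concept, if known
    print(ontology.get_all_parents("316866"))  # the code itself plus its transitive ancestors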
cehrgpt/runners/hf_cehrgpt_finetune_runner.py
@@ -1,3 +1,4 @@
+import glob
 import json
 import os
 import random
@@ -175,11 +176,6 @@ def model_init(
     model.config.class_weights = cehrgpt_args.class_weights
     LOG.info(f"Setting class_weights to {model.config.class_weights}")

-    # Enable position embeddings when position embeddings are disabled in pre-training
-    if not model_args.exclude_position_ids and model.cehrgpt.exclude_position_ids:
-        LOG.info(f"Enable the position_embeddings")
-        model.cehrgpt.enable_position_embeddings()
-
     if model.config.max_position_embeddings < model_args.max_position_embeddings:
         LOG.info(
             f"Increase model.config.max_position_embeddings to {model_args.max_position_embeddings}"
@@ -379,7 +375,6 @@ def main():
             SamplePackingCehrGptDataCollator,
             cehrgpt_args.max_tokens_per_batch,
             config.max_position_embeddings,
-            add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
         )
     else:
         trainer_class = Trainer
@@ -406,8 +401,9 @@ def main():
     )

     if training_args.do_train:
+        output_dir = training_args.output_dir
         if cehrgpt_args.hyperparameter_tuning:
-            training_args = perform_hyperparameter_search(
+            training_args, run_id = perform_hyperparameter_search(
                 trainer_class,
                 partial(model_init, model_args, training_args, cehrgpt_args, tokenizer),
                 processed_dataset,
@@ -416,18 +412,28 @@ def main():
                 model_args,
                 cehrgpt_args,
             )
-
-            if cehrgpt_args.retrain_with_full:
-                # Always retrain with the full set when hyperparameter tuning is set to true
-                retrain_with_full_set(
-                    trainer_class,
-                    model_args,
-                    training_args,
-                    cehrgpt_args,
-                    tokenizer,
-                    processed_dataset,
-                    data_collator,
+            # We enforce retraining if cehrgpt_args.hyperparameter_tuning_percentage < 1.0
+            cehrgpt_args.retrain_with_full |= (
+                cehrgpt_args.hyperparameter_tuning_percentage < 1.0
             )
+            output_dir = os.path.join(training_args.output_dir, f"run-{run_id}")
+
+        if cehrgpt_args.hyperparameter_tuning and not cehrgpt_args.retrain_with_full:
+            folders = glob.glob(os.path.join(output_dir, "checkpoint-*"))
+            if len(folders) == 0:
+                raise RuntimeError(
+                    f"There must be a checkpoint folder under {output_dir}"
+                )
+            checkpoint_dir = folders[0]
+            LOG.info("Best trial checkpoint folder: %s", checkpoint_dir)
+            for file_name in os.listdir(checkpoint_dir):
+                try:
+                    full_file_name = os.path.join(checkpoint_dir, file_name)
+                    destination = os.path.join(training_args.output_dir, file_name)
+                    if os.path.isfile(full_file_name):
+                        shutil.copy2(full_file_name, destination)
+                except Exception as e:
+                    LOG.error("Failed to copy %s: %s", file_name, str(e))
         else:
             # Initialize Trainer for final training on the combined train+val set
             trainer = trainer_class(
@@ -476,63 +482,6 @@ def main():
         do_predict(test_dataloader, model_args, training_args, cehrgpt_args)


-def retrain_with_full_set(
-    trainer_class,
-    model_args: ModelArguments,
-    training_args: TrainingArguments,
-    cehrgpt_args: CehrGPTArguments,
-    tokenizer: CehrGptTokenizer,
-    dataset: DatasetDict,
-    data_collator: CehrGptDataCollator,
-) -> None:
-    """
-    Retrains a model on the full training and validation dataset for final performance evaluation.
-
-    This function consolidates the training and validation datasets into a single
-    dataset for final model training, updates the output directory for the final model,
-    and disables evaluation during training. It resumes from the latest checkpoint if available,
-    trains the model on the combined dataset, and saves the model along with training metrics
-    and state information.
-
-    Args:
-        trainer_class: Trainer or its subclass
-        model_args (ModelArguments): Model configuration and hyperparameters.
-        training_args (TrainingArguments): Training configuration, including output directory,
-            evaluation strategy, and other training parameters.
-        cehrgpt_args (CehrGPTArguments): CehrGPT specific parameters.
-        tokenizer (CehrGptTokenizer): Tokenizer instance specific to CEHR-GPT.
-        dataset (DatasetDict): A dictionary containing the 'train' and 'validation' datasets.
-        data_collator (CehrGptDataCollator): Data collator for handling data batching and tokenization.
-
-    Returns:
-        None
-    """
-    # Initialize Trainer for final training on the combined train+val set
-    full_dataset = concatenate_datasets([dataset["train"], dataset["validation"]])
-    training_args.output_dir = os.path.join(training_args.output_dir, "full")
-    LOG.info(
-        "Final output_dir for final_training_args.output_dir %s",
-        training_args.output_dir,
-    )
-    Path(training_args.output_dir).mkdir(exist_ok=True)
-    # Disable evaluation
-    training_args.evaluation_strategy = "no"
-    checkpoint = get_last_hf_checkpoint(training_args)
-    final_trainer = trainer_class(
-        model=model_init(model_args, training_args, cehrgpt_args, tokenizer),
-        data_collator=data_collator,
-        args=training_args,
-        train_dataset=full_dataset,
-        tokenizer=tokenizer,
-    )
-    final_train_result = final_trainer.train(resume_from_checkpoint=checkpoint)
-    final_trainer.save_model()  # Saves the tokenizer too for easy upload
-    metrics = final_train_result.metrics
-    final_trainer.log_metrics("train", metrics)
-    final_trainer.save_metrics("train", metrics)
-    final_trainer.save_state()
-
-
 def do_predict(
     test_dataloader: DataLoader,
     model_args: ModelArguments,
@@ -652,9 +601,6 @@ def load_lora_model(
     # Enable include_values when include_values is set to be False during pre-training
     if model_args.include_values and not model.cehrgpt.include_values:
         model.cehrgpt.include_values = True
-    # Enable position embeddings when position embeddings are disabled in pre-training
-    if not model_args.exclude_position_ids and model.cehrgpt.exclude_position_ids:
-        model.cehrgpt.exclude_position_ids = False
     if cehrgpt_args.expand_tokenizer:
         tokenizer = CehrGptTokenizer.from_pretrained(training_args.output_dir)
         # Expand tokenizer to adapt to the finetuning dataset
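As a reading aid, the checkpoint-promotion step introduced in the do_train branch above can be written as a standalone sketch. The helper name is hypothetical; the logic mirrors the added lines (find the single checkpoint-* folder under the best trial's run directory, then copy its files into the top-level output directory), assuming glob, os, and shutil are available as in the runner.

    import glob
    import os
    import shutil


    def promote_best_trial_checkpoint(run_dir: str, output_dir: str) -> None:
        """Copy the files of the checkpoint-* folder under run_dir into output_dir (illustrative)."""
        folders = glob.glob(os.path.join(run_dir, "checkpoint-*"))
        if len(folders) == 0:
            raise RuntimeError(f"There must be a checkpoint folder under {run_dir}")
        checkpoint_dir = folders[0]
        for file_name in os.listdir(checkpoint_dir):
            source = os.path.join(checkpoint_dir, file_name)
            if os.path.isfile(source):
                shutil.copy2(source, os.path.join(output_dir, file_name))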
cehrgpt/runners/hf_cehrgpt_pretrain_runner.py
@@ -2,7 +2,9 @@ import os
 from functools import partial
 from typing import Optional, Union

+import datasets
 import numpy as np
+import pandas as pd
 import torch
 import torch.distributed as dist
 from cehrbert.data_generators.hf_data_generator.meds_utils import (
@@ -20,7 +22,7 @@ from cehrbert.runners.runner_util import (
     load_parquet_as_dataset,
 )
 from datasets import Dataset, DatasetDict, IterableDatasetDict, load_from_disk
-from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
+from transformers import EarlyStoppingCallback, Trainer, set_seed
 from transformers.trainer_utils import is_main_process
 from transformers.utils import is_flash_attn_2_available, logging

@@ -34,6 +36,7 @@ from cehrgpt.models.config import CEHRGPTConfig
 from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
 from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+from cehrgpt.omop.ontology import Ontology
 from cehrgpt.runners.data_utils import get_torch_dtype
 from cehrgpt.runners.gpt_runner_util import parse_runner_args
 from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
@@ -70,68 +73,64 @@ def load_and_create_tokenizer(
     data_args: DataTrainingArguments,
     model_args: ModelArguments,
     cehrgpt_args: CehrGPTArguments,
-    dataset: Optional[Union[Dataset, DatasetDict]] = None,
+    dataset: Union[Dataset, DatasetDict],
 ) -> CehrGptTokenizer:

-    concept_name_mapping = {}
-    allowed_motor_codes = list()
-    if cehrgpt_args.concept_dir:
-        import pandas as pd
-        from cehrbert_data.const.artificial_tokens import DEATH_TOKEN
-        from meds.schema import death_code
-
-        LOG.info("Loading concept data from disk at %s", cehrgpt_args.concept_dir)
-        concept_pd = pd.read_parquet(cehrgpt_args.concept_dir)
-        LOG.info(
-            "Creating concept name mapping and motor_time_to_event_codes from disk at %s",
-            cehrgpt_args.concept_dir,
-        )
-        for row in concept_pd.itertuples():
-            concept_name_mapping[str(getattr(row, "concept_id"))] = getattr(
-                row, "concept_name"
-            )
-            if (
-                cehrgpt_args.include_motor_time_to_event
-                and getattr(row, "domain_id")
-                in ["Condition", "Procedure", "Drug", "Visit"]
-                and getattr(row, "standard_concept") == "S"
-            ):
-                allowed_motor_codes.append(str(getattr(row, "concept_id")))
-        LOG.info(
-            "Adding death codes for MOTOR TTE predictions: %s",
-            [DEATH_TOKEN, death_code],
-        )
-        allowed_motor_codes.extend([DEATH_TOKEN, death_code])
     # Try to load the pretrained tokenizer
     tokenizer_abspath = os.path.expanduser(model_args.tokenizer_name_or_path)
-    try:
-        tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_abspath)
-    except Exception as e:
-        LOG.warning(e)
-        if dataset is None:
+    if not tokenizer_exists(tokenizer_abspath):
+        if cehrgpt_args.include_motor_time_to_event and not cehrgpt_args.vocab_dir:
             raise RuntimeError(
-                f"Failed to load the tokenizer from {tokenizer_abspath} with the error \n{e}\n"
-                f"Tried to create the tokenizer, however the dataset is not provided."
+                "motor_vocab_dir must be specified if include_motor_time_to_event is True"
+            )
+        ontology: Optional[Ontology] = None
+        concept_name_mapping = {}
+        if cehrgpt_args.vocab_dir:
+            LOG.info("Loading concept data from disk at %s", cehrgpt_args.vocab_dir)
+            concept_pd = pd.read_parquet(
+                os.path.join(cehrgpt_args.vocab_dir, "concept")
             )
+            for row in concept_pd.itertuples():
+                concept_name_mapping[str(getattr(row, "concept_id"))] = getattr(
+                    row, "concept_name"
+                )
+
+            if cehrgpt_args.motor_use_ontology:
+                LOG.info("Creating ontology for MOTOR TTE predictions")
+                ontology = Ontology(cehrgpt_args.vocab_dir)
+                train_val_dataset = datasets.concatenate_datasets(
+                    [dataset["train"], dataset["validation"]]
+                )
+                ontology.prune_to_dataset(
+                    train_val_dataset,
+                    num_proc=data_args.preprocessing_num_workers,
+                    remove_ontologies={"SPL", "HemOnc", "LOINC"},
+                )
+
         LOG.info("Started training the tokenizer ...")
+        train_val_dataset = datasets.concatenate_datasets(
+            [dataset["train"], dataset["validation"]]
+        )
         tokenizer = CehrGptTokenizer.train_tokenizer(
-            dataset,
+            train_val_dataset,
             concept_name_mapping,
             data_args,
             PretrainedEmbeddings(cehrgpt_args.pretrained_embedding_path),
-            allowed_motor_codes if cehrgpt_args.include_motor_time_to_event else None,
-            (
+            num_motor_tasks=(
                 cehrgpt_args.num_motor_tasks
                 if cehrgpt_args.include_motor_time_to_event
                 else None
            ),
            apply_entropy_filter=cehrgpt_args.apply_entropy_filter,
            min_prevalence=cehrgpt_args.min_prevalence,
+            ontology=ontology,
        )
        LOG.info("Finished training the tokenizer ...")
        tokenizer.save_pretrained(tokenizer_abspath)
        LOG.info("Saved the tokenizer to %s", tokenizer_abspath)
-
+    else:
+        LOG.info("The tokenizer exists and will be loaded from %s", tokenizer_abspath)
+        tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_abspath)
     return tokenizer


@@ -187,7 +186,10 @@ def load_and_create_model(

     model_args_cehrgpt = model_args.as_dict()
     model_args_cehrgpt.pop("attn_implementation")
+    # CEHR-GPT does not support this anymore
+    model_args_cehrgpt.pop("exclude_position_ids")
     model_config = CEHRGPTConfig(
+        activation_function=cehrgpt_args.activation_function,
         vocab_size=tokenizer.vocab_size,
         value_vocab_size=tokenizer.value_vocab_size,
         time_token_vocab_size=tokenizer.time_token_vocab_size,
@@ -207,6 +209,7 @@
         n_pretrained_embeddings_layers=cehrgpt_args.n_pretrained_embeddings_layers,
         use_pretrained_embeddings=len(tokenizer.pretrained_token_ids) > 0,
         pretrained_embedding_dim=pretrained_embedding_dim,
+        apply_rotary=cehrgpt_args.apply_rotary,
         sample_packing_max_positions=(
             cehrgpt_args.max_tokens_per_batch
             if cehrgpt_args.sample_packing
@@ -217,6 +220,8 @@
         motor_time_to_event_weight=cehrgpt_args.motor_time_to_event_weight,
         motor_num_time_pieces=cehrgpt_args.motor_num_time_pieces,
         ve_token_id=tokenizer.ve_token_id,
+        n_inner=cehrgpt_args.inner_dim,
+        decoder_mlp=cehrgpt_args.decoder_mlp,
         **model_args_cehrgpt,
     )

@@ -235,7 +240,6 @@

 def main():
     cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
-
     if cehrgpt_args.sample_packing and data_args.streaming:
         raise RuntimeError(
             f"sample_packing is not supported when streaming is enabled, please set streaming to False"
@@ -530,7 +534,6 @@ def main():
             SamplePackingCehrGptDataCollator,
             cehrgpt_args.max_tokens_per_batch,
             model_args.max_position_embeddings,
-            add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
         )
     else:
         trainer_class = Trainer
@@ -552,6 +555,7 @@
             include_motor_time_to_event=cehrgpt_args.include_motor_time_to_event,
             motor_tte_vocab_size=model.config.motor_tte_vocab_size,
             motor_num_time_pieces=cehrgpt_args.motor_num_time_pieces,
+            motor_sampling_probability=cehrgpt_args.motor_sampling_probability,
         ),
         train_dataset=processed_dataset["train"],
         eval_dataset=(
cehrgpt/runners/hf_gpt_runner_argument_dataclass.py
@@ -1,5 +1,7 @@
 import dataclasses
-from typing import List, Optional
+from typing import List, Literal, Optional
+
+from cehrgpt.models.gpt2 import ACT2FN


 @dataclasses.dataclass
@@ -12,6 +14,14 @@ class CehrGPTArguments:
             "help": "The path to the tokenized dataset created for the full population"
         },
     )
+    activation_function: Literal[tuple(ACT2FN.keys())] = dataclasses.field(
+        default="gelu_new",
+        metadata={"help": "The activation function to use"},
+    )
+    decoder_mlp: Literal["GPT2MLP", "LlamaMLP"] = dataclasses.field(
+        default="GPT2MLP",
+        metadata={"help": "The decoder MLP architecture"},
+    )
     include_inpatient_hour_token: Optional[bool] = dataclasses.field(
         default=True,
         metadata={"help": "Include inpatient hour token"},
@@ -54,6 +64,14 @@ class CehrGPTArguments:
         default=128,
         metadata={"help": "The number of examples from the training set."},
     )
+    hyperparameter_tuning: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={"help": "A flag to indicate if we want to do hyperparameter tuning."},
+    )
+    hyperparameter_tuning_is_grid: Optional[bool] = dataclasses.field(
+        default=True,
+        metadata={"help": "A flag to indicate if we want to do hyperparameter tuning."},
+    )
     hyperparameter_tuning_percentage: Optional[float] = dataclasses.field(
         default=0.1,
         metadata={
@@ -66,10 +84,6 @@ class CehrGPTArguments:
             "help": "The number of trails will be use for hyperparameter tuning."
         },
     )
-    hyperparameter_tuning: Optional[bool] = dataclasses.field(
-        default=False,
-        metadata={"help": "A flag to indicate if we want to do hyperparameter tuning."},
-    )
     hyperparameter_batch_sizes: Optional[List[int]] = dataclasses.field(
         default_factory=lambda: [4, 8, 16],
         metadata={"help": "Hyperparameter search batch sizes"},
@@ -78,29 +92,13 @@ class CehrGPTArguments:
         default_factory=lambda: [10],
         metadata={"help": "Hyperparameter search num_train_epochs"},
     )
-    lr_low: Optional[float] = dataclasses.field(
-        default=1e-5,
-        metadata={
-            "help": "The lower bound of the learning rate range for hyperparameter tuning."
-        },
-    )
-    lr_high: Optional[float] = dataclasses.field(
-        default=5e-5,
-        metadata={
-            "help": "The upper bound of the learning rate range for hyperparameter tuning."
-        },
-    )
-    weight_decays_low: Optional[float] = dataclasses.field(
-        default=1e-3,
-        metadata={
-            "help": "The lower bound of the weight decays range for hyperparameter tuning."
-        },
+    hyperparameter_learning_rates: Optional[List[int]] = dataclasses.field(
+        default_factory=lambda: [1e-5],
+        metadata={"help": "Hyperparameter search learning rates"},
     )
-    weight_decays_high: Optional[float] = dataclasses.field(
-        default=1e-2,
-        metadata={
-            "help": "The upper bound of the weight decays range for hyperparameter tuning."
-        },
+    hyperparameter_weight_decays: Optional[List[int]] = dataclasses.field(
+        default_factory=lambda: [1e-2],
+        metadata={"help": "Hyperparameter search learning rates"},
     )
     causal_sfm: Optional[bool] = dataclasses.field(
         default=False,
@@ -168,6 +166,16 @@ class CehrGPTArguments:
             "help": "A threshold to denote how much the specified metric must improve to satisfy early stopping conditions."
         },
     )
+    inner_dim: Optional[int] = dataclasses.field(
+        default=None,
+        metadata={"help": "The dimensionality of the hidden layer"},
+    )
+    apply_rotary: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to use rotary encoder layers"
+        },
+    )
     sample_packing: Optional[bool] = dataclasses.field(
         default=False,
         metadata={
@@ -177,12 +185,6 @@ class CehrGPTArguments:
     max_tokens_per_batch: int = dataclasses.field(
         default=16384, metadata={"help": "Maximum number of tokens in each batch"}
     )
-    add_end_token_in_sample_packing: Optional[bool] = dataclasses.field(
-        default=False,
-        metadata={
-            "help": "A flag to indicate whether we want to add end token in sample packing"
-        },
-    )
     include_motor_time_to_event: Optional[bool] = dataclasses.field(
         default=False,
         metadata={
@@ -203,7 +205,17 @@ class CehrGPTArguments:
             "help": "The number of times each motor_num_time_pieces piece has to be"
         },
     )
-    concept_dir: Optional[str] = dataclasses.field(
+    motor_use_ontology: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to use motor_use_ontology"
+        },
+    )
+    motor_sampling_probability: Optional[float] = dataclasses.field(
+        default=0.0,
+        metadata={"help": "A flag to indicate whether we want to use sample packing"},
+    )
+    vocab_dir: Optional[str] = dataclasses.field(
         default=None,
         metadata={"help": "The directory where the concept data is stored."},
     )
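For reference, a hedged sketch that sets the renamed and newly added arguments in one place. The values are arbitrary, only fields visible in this diff are used, and it assumes, as is typical for these HF-style argument dataclasses, that every other field carries a default.

    from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments

    # Illustrative values only; every field set here appears in the 0.1.3 dataclass above.
    cehrgpt_args = CehrGPTArguments(
        activation_function="gelu_new",
        decoder_mlp="LlamaMLP",
        apply_rotary=True,
        inner_dim=None,
        vocab_dir="/data/omop_vocab",                # replaces the former concept_dir
        motor_use_ontology=True,
        motor_sampling_probability=0.5,
        hyperparameter_tuning=True,
        hyperparameter_tuning_is_grid=True,
        hyperparameter_learning_rates=[1e-5, 5e-5],  # replaces lr_low / lr_high
        hyperparameter_weight_decays=[1e-3, 1e-2],   # replaces weight_decays_low / weight_decays_high
    )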