cehrgpt 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/htn_treatment_pathway.py +546 -0
- cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
- cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
- cehrgpt/data/cehrgpt_data_processor.py +549 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +285 -652
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +38 -5
- cehrgpt/generation/cehrgpt_conditional_generation.py +2 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +20 -12
- cehrgpt/generation/omop_converter_batch.py +11 -4
- cehrgpt/gpt_utils.py +73 -3
- cehrgpt/models/activations.py +27 -0
- cehrgpt/models/config.py +6 -2
- cehrgpt/models/gpt2.py +560 -0
- cehrgpt/models/hf_cehrgpt.py +183 -460
- cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
- cehrgpt/omop/ontology.py +154 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -78
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -34
- cehrgpt/runners/hyperparameter_search_util.py +180 -69
- cehrgpt/runners/sample_packing_trainer.py +11 -2
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +8 -2
- cehrgpt-0.1.4.dist-info/METADATA +238 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/RECORD +32 -22
- cehrgpt-0.1.2.dist-info/METADATA +0 -209
- /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/WHEEL +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/top_level.txt +0 -0
cehrgpt/omop/ontology.py
ADDED
```diff
@@ -0,0 +1,154 @@
+from __future__ import annotations
+
+import collections
+import os
+from typing import Any, Dict, Iterable, Optional, Set, Union
+
+import polars as pl
+from datasets import Dataset
+
+
+# Adapted from femr.ontology
+def _get_all_codes_map(batch: Dict[str, Any]) -> Dict[str, Any]:
+    result = set()
+    for concept_ids in batch["concept_ids"]:
+        for concept_id in concept_ids:
+            if concept_id.isnumeric():
+                result.add(concept_id)
+    return {"unique_concept_ids": list(result)}
+
+
+class Ontology:
+    def __init__(self, vocab_path: str):
+        """Create an Ontology from an Athena download and an optional meds Code Metadata structure.
+
+        NOTE: This is an expensive operation.
+        It is recommended to create an ontology once and then save/load it as necessary.
+        """
+        # Load from code metadata
+        self.parents_map: Dict[str, Set[str]] = collections.defaultdict(set)
+        self.concept_vocabulary_map: Dict[str, str] = collections.defaultdict(str)
+        self.concept_domain_map: Dict[str, str] = collections.defaultdict(str)
+
+        # Load from the athena path ...
+        concept = pl.scan_parquet(os.path.join(vocab_path, "concept/*parquet"))
+        vocabulary_id_col = pl.col("vocabulary_id")
+        concept_id_col = pl.col("concept_id").cast(pl.String)
+        domain_id_col = pl.col("domain_id").cast(pl.String)
+
+        processed_concepts = (
+            concept.select(
+                concept_id_col,
+                domain_id_col,
+                vocabulary_id_col,
+                pl.col("standard_concept").is_null(),
+            )
+            .collect()
+            .rows()
+        )
+
+        non_standard_concepts = set()
+
+        for concept_id, domain_id, vocabulary_id, is_non_standard in processed_concepts:
+            # We don't want to override code metadata
+            if concept_id not in self.concept_vocabulary_map:
+                self.concept_vocabulary_map[concept_id] = vocabulary_id
+
+            if concept_id not in self.concept_domain_map:
+                self.concept_domain_map[concept_id] = domain_id
+            # We don't want to override code metadata
+            if is_non_standard:
+                non_standard_concepts.add(concept_id)
+
+        relationship = pl.scan_parquet(
+            os.path.join(vocab_path, "concept_relationship/*parquet")
+        )
+        relationship_id = pl.col("relationship_id")
+        relationship = relationship.filter(
+            relationship_id == "Maps to",
+            pl.col("concept_id_1") != pl.col("concept_id_2"),
+        )
+        for concept_id_1, concept_id_2 in (
+            relationship.select(
+                pl.col("concept_id_1").cast(pl.String),
+                pl.col("concept_id_2").cast(pl.String),
+            )
+            .collect()
+            .rows()
+        ):
+            if concept_id_1 in non_standard_concepts:
+                self.parents_map[concept_id_1].add(concept_id_2)
+
+        ancestor = pl.scan_parquet(
+            os.path.join(vocab_path, "concept_ancestor/*parquet")
+        )
+        ancestor = ancestor.filter(pl.col("min_levels_of_separation") == 1)
+        for concept_id, parent_concept_id in (
+            ancestor.select(
+                pl.col("descendant_concept_id").cast(pl.String),
+                pl.col("ancestor_concept_id").cast(pl.String),
+            )
+            .collect()
+            .rows()
+        ):
+            self.parents_map[concept_id].add(parent_concept_id)
+        self.all_parents_map: Dict[str, Set[str]] = {}
+
+    def get_domain(self, concept_id: Union[str, int]) -> Optional[str]:
+        return self.concept_domain_map.get(str(concept_id), None)
+
+    def prune_to_dataset(
+        self,
+        dataset: Dataset,
+        remove_ontologies: Set[str] = set(),
+        num_proc: int = 4,
+        batch_size: int = 1024,
+    ) -> None:
+        mapped_dataset = dataset.map(
+            _get_all_codes_map,
+            batched=True,
+            batch_size=batch_size,
+            remove_columns=dataset.column_names,
+            num_proc=num_proc,
+        )
+        valid_concept_ids = set(mapped_dataset["unique_concept_ids"])
+        all_parents = set()
+        for concept_id in valid_concept_ids:
+            all_parents |= self.get_all_parents(concept_id)
+
+        def is_valid(c: str):
+            ontology = self.concept_vocabulary_map.get(c, "")
+            return (c in valid_concept_ids) or (
+                (ontology not in remove_ontologies) and (c in all_parents)
+            )
+
+        concept_ids = set(self.parents_map.keys())
+        for concept_id in concept_ids:
+            m: Any
+            if is_valid(concept_id):
+                for m in (self.parents_map, self.concept_vocabulary_map):
+                    m[concept_id] = {a for a in m[concept_id] if is_valid(a)}
+            else:
+                for m in (self.parents_map, self.concept_vocabulary_map):
+                    if concept_id in m:
+                        del m[concept_id]
+
+        self.all_parents_map = {}
+
+        # Prime the pump
+        for concept_id in self.parents_map.keys():
+            self.get_all_parents(concept_id)
+
+    def get_parents(self, code: str) -> Iterable[str]:
+        """Get the parents for a given code."""
+        return self.parents_map.get(code, set())
+
+    def get_all_parents(self, code: str) -> Set[str]:
+        """Get all parents, including through the ontology."""
+        if code not in self.all_parents_map:
+            result = {code}
+            for parent in self.parents_map.get(code, set()):
+                result |= self.get_all_parents(parent)
+            self.all_parents_map[code] = result
+
+        return self.all_parents_map[code]
```
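For orientation, here is a minimal usage sketch of the new `Ontology` class. It is not part of the diff; the vocabulary path and concept IDs are placeholders, and it assumes an Athena/OMOP download whose `concept`, `concept_relationship`, and `concept_ancestor` folders contain parquet files.

```python
# Hypothetical usage sketch of cehrgpt.omop.ontology.Ontology (not in the diff).
from datasets import Dataset

from cehrgpt.omop.ontology import Ontology

# Assumes an Athena/OMOP vocabulary download with concept/, concept_relationship/,
# and concept_ancestor/ parquet directories underneath.
ontology = Ontology("/path/to/athena_vocab")

# prune_to_dataset expects a "concept_ids" column; the IDs here are placeholders.
dataset = Dataset.from_dict({"concept_ids": [["320128", "316866"], ["201826"]]})

# Keep only concepts observed in the dataset plus their ancestors, dropping
# ancestors that belong to the excluded vocabularies.
ontology.prune_to_dataset(dataset, remove_ontologies={"SPL"}, num_proc=1)

print(ontology.get_domain("320128"))       # domain_id of the concept, if known
print(ontology.get_all_parents("320128"))  # the concept plus all of its ancestors
```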
cehrgpt/runners/hf_cehrgpt_finetune_runner.py

```diff
@@ -1,3 +1,4 @@
+import glob
 import json
 import os
 import random
```
```diff
@@ -175,11 +176,6 @@ def model_init(
     model.config.class_weights = cehrgpt_args.class_weights
     LOG.info(f"Setting class_weights to {model.config.class_weights}")
 
-    # Enable position embeddings when position embeddings are disabled in pre-training
-    if not model_args.exclude_position_ids and model.cehrgpt.exclude_position_ids:
-        LOG.info(f"Enable the position_embeddings")
-        model.cehrgpt.enable_position_embeddings()
-
     if model.config.max_position_embeddings < model_args.max_position_embeddings:
         LOG.info(
             f"Increase model.config.max_position_embeddings to {model_args.max_position_embeddings}"
```
```diff
@@ -379,7 +375,6 @@ def main():
             SamplePackingCehrGptDataCollator,
             cehrgpt_args.max_tokens_per_batch,
             config.max_position_embeddings,
-            add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
         )
     else:
         trainer_class = Trainer
```
```diff
@@ -406,8 +401,9 @@
     )
 
     if training_args.do_train:
+        output_dir = training_args.output_dir
         if cehrgpt_args.hyperparameter_tuning:
-            training_args = perform_hyperparameter_search(
+            training_args, run_id = perform_hyperparameter_search(
                 trainer_class,
                 partial(model_init, model_args, training_args, cehrgpt_args, tokenizer),
                 processed_dataset,
```
```diff
@@ -416,18 +412,28 @@
                 model_args,
                 cehrgpt_args,
             )
-
-
-
-            retrain_with_full_set(
-                trainer_class,
-                model_args,
-                training_args,
-                cehrgpt_args,
-                tokenizer,
-                processed_dataset,
-                data_collator,
+            # We enforce retraining if cehrgpt_args.hyperparameter_tuning_percentage < 1.0
+            cehrgpt_args.retrain_with_full |= (
+                cehrgpt_args.hyperparameter_tuning_percentage < 1.0
             )
+            output_dir = os.path.join(training_args.output_dir, f"run-{run_id}")
+
+        if cehrgpt_args.hyperparameter_tuning and not cehrgpt_args.retrain_with_full:
+            folders = glob.glob(os.path.join(output_dir, "checkpoint-*"))
+            if len(folders) == 0:
+                raise RuntimeError(
+                    f"There must be a checkpoint folder under {output_dir}"
+                )
+            checkpoint_dir = folders[0]
+            LOG.info("Best trial checkpoint folder: %s", checkpoint_dir)
+            for file_name in os.listdir(checkpoint_dir):
+                try:
+                    full_file_name = os.path.join(checkpoint_dir, file_name)
+                    destination = os.path.join(training_args.output_dir, file_name)
+                    if os.path.isfile(full_file_name):
+                        shutil.copy2(full_file_name, destination)
+                except Exception as e:
+                    LOG.error("Failed to copy %s: %s", file_name, str(e))
         else:
             # Initialize Trainer for final training on the combined train+val set
             trainer = trainer_class(
```
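In place of the removed `retrain_with_full_set` call (deleted below), the runner now promotes the best trial's checkpoint. A self-contained sketch of that promotion step, pulled out into a hypothetical helper (`promote_best_checkpoint` is not a name from the diff):

```python
# Hypothetical helper (not in the diff) mirroring the new promotion step:
# locate the best trial's checkpoint under run-{run_id} and copy its files
# into the top-level output directory.
import glob
import os
import shutil


def promote_best_checkpoint(output_dir: str, run_id: str) -> None:
    run_dir = os.path.join(output_dir, f"run-{run_id}")
    folders = glob.glob(os.path.join(run_dir, "checkpoint-*"))
    if not folders:
        raise RuntimeError(f"There must be a checkpoint folder under {run_dir}")
    checkpoint_dir = folders[0]
    for file_name in os.listdir(checkpoint_dir):
        full_file_name = os.path.join(checkpoint_dir, file_name)
        if os.path.isfile(full_file_name):
            shutil.copy2(full_file_name, os.path.join(output_dir, file_name))
```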
```diff
@@ -476,63 +482,6 @@
         do_predict(test_dataloader, model_args, training_args, cehrgpt_args)
 
 
-def retrain_with_full_set(
-    trainer_class,
-    model_args: ModelArguments,
-    training_args: TrainingArguments,
-    cehrgpt_args: CehrGPTArguments,
-    tokenizer: CehrGptTokenizer,
-    dataset: DatasetDict,
-    data_collator: CehrGptDataCollator,
-) -> None:
-    """
-    Retrains a model on the full training and validation dataset for final performance evaluation.
-
-    This function consolidates the training and validation datasets into a single
-    dataset for final model training, updates the output directory for the final model,
-    and disables evaluation during training. It resumes from the latest checkpoint if available,
-    trains the model on the combined dataset, and saves the model along with training metrics
-    and state information.
-
-    Args:
-        trainer_class: Trainer or its subclass
-        model_args (ModelArguments): Model configuration and hyperparameters.
-        training_args (TrainingArguments): Training configuration, including output directory,
-            evaluation strategy, and other training parameters.
-        cehrgpt_args (CehrGPTArguments): CehrGPT specific parameters.
-        tokenizer (CehrGptTokenizer): Tokenizer instance specific to CEHR-GPT.
-        dataset (DatasetDict): A dictionary containing the 'train' and 'validation' datasets.
-        data_collator (CehrGptDataCollator): Data collator for handling data batching and tokenization.
-
-    Returns:
-        None
-    """
-    # Initialize Trainer for final training on the combined train+val set
-    full_dataset = concatenate_datasets([dataset["train"], dataset["validation"]])
-    training_args.output_dir = os.path.join(training_args.output_dir, "full")
-    LOG.info(
-        "Final output_dir for final_training_args.output_dir %s",
-        training_args.output_dir,
-    )
-    Path(training_args.output_dir).mkdir(exist_ok=True)
-    # Disable evaluation
-    training_args.evaluation_strategy = "no"
-    checkpoint = get_last_hf_checkpoint(training_args)
-    final_trainer = trainer_class(
-        model=model_init(model_args, training_args, cehrgpt_args, tokenizer),
-        data_collator=data_collator,
-        args=training_args,
-        train_dataset=full_dataset,
-        tokenizer=tokenizer,
-    )
-    final_train_result = final_trainer.train(resume_from_checkpoint=checkpoint)
-    final_trainer.save_model()  # Saves the tokenizer too for easy upload
-    metrics = final_train_result.metrics
-    final_trainer.log_metrics("train", metrics)
-    final_trainer.save_metrics("train", metrics)
-    final_trainer.save_state()
-
-
 def do_predict(
     test_dataloader: DataLoader,
     model_args: ModelArguments,
```
```diff
@@ -652,9 +601,6 @@ def load_lora_model(
     # Enable include_values when include_values is set to be False during pre-training
     if model_args.include_values and not model.cehrgpt.include_values:
         model.cehrgpt.include_values = True
-    # Enable position embeddings when position embeddings are disabled in pre-training
-    if not model_args.exclude_position_ids and model.cehrgpt.exclude_position_ids:
-        model.cehrgpt.exclude_position_ids = False
     if cehrgpt_args.expand_tokenizer:
         tokenizer = CehrGptTokenizer.from_pretrained(training_args.output_dir)
         # Expand tokenizer to adapt to the finetuning dataset
```
cehrgpt/runners/hf_cehrgpt_pretrain_runner.py

```diff
@@ -2,7 +2,9 @@ import os
 from functools import partial
 from typing import Optional, Union
 
+import datasets
 import numpy as np
+import pandas as pd
 import torch
 import torch.distributed as dist
 from cehrbert.data_generators.hf_data_generator.meds_utils import (
```
```diff
@@ -20,7 +22,7 @@ from cehrbert.runners.runner_util import (
     load_parquet_as_dataset,
 )
 from datasets import Dataset, DatasetDict, IterableDatasetDict, load_from_disk
-from transformers import EarlyStoppingCallback, Trainer,
+from transformers import EarlyStoppingCallback, Trainer, set_seed
 from transformers.trainer_utils import is_main_process
 from transformers.utils import is_flash_attn_2_available, logging
 
```
```diff
@@ -34,6 +36,7 @@ from cehrgpt.models.config import CEHRGPTConfig
 from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
 from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+from cehrgpt.omop.ontology import Ontology
 from cehrgpt.runners.data_utils import get_torch_dtype
 from cehrgpt.runners.gpt_runner_util import parse_runner_args
 from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
```
```diff
@@ -70,68 +73,64 @@ def load_and_create_tokenizer(
     data_args: DataTrainingArguments,
     model_args: ModelArguments,
     cehrgpt_args: CehrGPTArguments,
-    dataset:
+    dataset: Union[Dataset, DatasetDict],
 ) -> CehrGptTokenizer:
-    concept_name_mapping = {}
-    allowed_motor_codes = list()
-    if cehrgpt_args.concept_dir:
-        import pandas as pd
-        from cehrbert_data.const.artificial_tokens import DEATH_TOKEN
-        from meds.schema import death_code
-
-        LOG.info("Loading concept data from disk at %s", cehrgpt_args.concept_dir)
-        concept_pd = pd.read_parquet(cehrgpt_args.concept_dir)
-        LOG.info(
-            "Creating concept name mapping and motor_time_to_event_codes from disk at %s",
-            cehrgpt_args.concept_dir,
-        )
-        for row in concept_pd.itertuples():
-            concept_name_mapping[str(getattr(row, "concept_id"))] = getattr(
-                row, "concept_name"
-            )
-            if (
-                cehrgpt_args.include_motor_time_to_event
-                and getattr(row, "domain_id")
-                in ["Condition", "Procedure", "Drug", "Visit"]
-                and getattr(row, "standard_concept") == "S"
-            ):
-                allowed_motor_codes.append(str(getattr(row, "concept_id")))
-        LOG.info(
-            "Adding death codes for MOTOR TTE predictions: %s",
-            [DEATH_TOKEN, death_code],
-        )
-        allowed_motor_codes.extend([DEATH_TOKEN, death_code])
     # Try to load the pretrained tokenizer
     tokenizer_abspath = os.path.expanduser(model_args.tokenizer_name_or_path)
-    try:
-        tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_abspath)
-    except Exception as e:
-        LOG.warning(e)
-        if dataset is None:
+    if not tokenizer_exists(tokenizer_abspath):
+        if cehrgpt_args.include_motor_time_to_event and not cehrgpt_args.vocab_dir:
             raise RuntimeError(
-
-
+                "motor_vocab_dir must be specified if include_motor_time_to_event is True"
+            )
+        ontology: Optional[Ontology] = None
+        concept_name_mapping = {}
+        if cehrgpt_args.vocab_dir:
+            LOG.info("Loading concept data from disk at %s", cehrgpt_args.vocab_dir)
+            concept_pd = pd.read_parquet(
+                os.path.join(cehrgpt_args.vocab_dir, "concept")
             )
+            for row in concept_pd.itertuples():
+                concept_name_mapping[str(getattr(row, "concept_id"))] = getattr(
+                    row, "concept_name"
+                )
+
+        if cehrgpt_args.motor_use_ontology:
+            LOG.info("Creating ontology for MOTOR TTE predictions")
+            ontology = Ontology(cehrgpt_args.vocab_dir)
+            train_val_dataset = datasets.concatenate_datasets(
+                [dataset["train"], dataset["validation"]]
+            )
+            ontology.prune_to_dataset(
+                train_val_dataset,
+                num_proc=data_args.preprocessing_num_workers,
+                remove_ontologies={"SPL", "HemOnc", "LOINC"},
+            )
+
         LOG.info("Started training the tokenizer ...")
+        train_val_dataset = datasets.concatenate_datasets(
+            [dataset["train"], dataset["validation"]]
+        )
         tokenizer = CehrGptTokenizer.train_tokenizer(
-
+            train_val_dataset,
             concept_name_mapping,
             data_args,
             PretrainedEmbeddings(cehrgpt_args.pretrained_embedding_path),
-
-            (
+            num_motor_tasks=(
                 cehrgpt_args.num_motor_tasks
                 if cehrgpt_args.include_motor_time_to_event
                 else None
             ),
             apply_entropy_filter=cehrgpt_args.apply_entropy_filter,
             min_prevalence=cehrgpt_args.min_prevalence,
+            ontology=ontology,
         )
         LOG.info("Finished training the tokenizer ...")
         tokenizer.save_pretrained(tokenizer_abspath)
         LOG.info("Saved the tokenizer to %s", tokenizer_abspath)
-
+    else:
+        LOG.info("The tokenizer exists and will be loaded from %s", tokenizer_abspath)
+        tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_abspath)
     return tokenizer
 
 
```
```diff
@@ -187,7 +186,10 @@ def load_and_create_model(
 
     model_args_cehrgpt = model_args.as_dict()
     model_args_cehrgpt.pop("attn_implementation")
+    # CEHR-GPT does not support this anymore
+    model_args_cehrgpt.pop("exclude_position_ids")
     model_config = CEHRGPTConfig(
+        activation_function=cehrgpt_args.activation_function,
         vocab_size=tokenizer.vocab_size,
         value_vocab_size=tokenizer.value_vocab_size,
         time_token_vocab_size=tokenizer.time_token_vocab_size,
```
```diff
@@ -207,6 +209,7 @@ def load_and_create_model(
         n_pretrained_embeddings_layers=cehrgpt_args.n_pretrained_embeddings_layers,
         use_pretrained_embeddings=len(tokenizer.pretrained_token_ids) > 0,
         pretrained_embedding_dim=pretrained_embedding_dim,
+        apply_rotary=cehrgpt_args.apply_rotary,
         sample_packing_max_positions=(
             cehrgpt_args.max_tokens_per_batch
             if cehrgpt_args.sample_packing
```
```diff
@@ -217,6 +220,8 @@ def load_and_create_model(
         motor_time_to_event_weight=cehrgpt_args.motor_time_to_event_weight,
         motor_num_time_pieces=cehrgpt_args.motor_num_time_pieces,
         ve_token_id=tokenizer.ve_token_id,
+        n_inner=cehrgpt_args.inner_dim,
+        decoder_mlp=cehrgpt_args.decoder_mlp,
         **model_args_cehrgpt,
     )
 
```
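Taken together, the hunks above wire several new architecture knobs into `CEHRGPTConfig`. A hedged sketch of just those arguments (the config takes many more parameters in practice; the values below are illustrative, and the assumption is that omitted fields keep their defaults):

```python
# Sketch of the new 0.1.4 config knobs only (values are illustrative).
from cehrgpt.models.config import CEHRGPTConfig

config = CEHRGPTConfig(
    activation_function="gelu_new",  # any key of ACT2FN
    decoder_mlp="LlamaMLP",          # or "GPT2MLP"
    apply_rotary=True,               # enable the rotary encoder layers
    n_inner=None,                    # inner feed-forward width; None keeps the default
)
```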
```diff
@@ -235,7 +240,6 @@ def load_and_create_model(
 
 def main():
     cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
-
     if cehrgpt_args.sample_packing and data_args.streaming:
         raise RuntimeError(
             f"sample_packing is not supported when streaming is enabled, please set streaming to False"
```
```diff
@@ -530,7 +534,6 @@ def main():
             SamplePackingCehrGptDataCollator,
             cehrgpt_args.max_tokens_per_batch,
             model_args.max_position_embeddings,
-            add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
         )
     else:
         trainer_class = Trainer
```
```diff
@@ -552,6 +555,7 @@ def main():
             include_motor_time_to_event=cehrgpt_args.include_motor_time_to_event,
             motor_tte_vocab_size=model.config.motor_tte_vocab_size,
             motor_num_time_pieces=cehrgpt_args.motor_num_time_pieces,
+            motor_sampling_probability=cehrgpt_args.motor_sampling_probability,
         ),
         train_dataset=processed_dataset["train"],
         eval_dataset=(
```
cehrgpt/runners/hf_gpt_runner_argument_dataclass.py

```diff
@@ -1,5 +1,7 @@
 import dataclasses
-from typing import List, Optional
+from typing import List, Literal, Optional
+
+from cehrgpt.models.gpt2 import ACT2FN
 
 
 @dataclasses.dataclass
```
```diff
@@ -12,6 +14,14 @@ class CehrGPTArguments:
             "help": "The path to the tokenized dataset created for the full population"
         },
     )
+    activation_function: Literal[tuple(ACT2FN.keys())] = dataclasses.field(
+        default="gelu_new",
+        metadata={"help": "The activation function to use"},
+    )
+    decoder_mlp: Literal["GPT2MLP", "LlamaMLP"] = dataclasses.field(
+        default="GPT2MLP",
+        metadata={"help": "The decoder MLP architecture"},
+    )
     include_inpatient_hour_token: Optional[bool] = dataclasses.field(
         default=True,
         metadata={"help": "Include inpatient hour token"},
```
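The `activation_function` annotation relies on the fact that subscripting `typing.Literal` with a tuple unpacks it, so the permitted values automatically track whatever keys `ACT2FN` defines. A toy, self-contained illustration (the stand-in `ACT2FN` below is not the real mapping, which lives in `cehrgpt.models.gpt2`):

```python
from typing import Literal

# Toy stand-in for the real activation-name -> callable mapping.
ACT2FN = {"gelu_new": lambda x: x, "relu": lambda x: x}

# Literal[tuple(...)] unpacks the tuple, i.e. Literal["gelu_new", "relu"].
ActivationName = Literal[tuple(ACT2FN.keys())]
print(ActivationName)
```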
```diff
@@ -54,6 +64,14 @@ class CehrGPTArguments:
         default=128,
         metadata={"help": "The number of examples from the training set."},
     )
+    hyperparameter_tuning: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={"help": "A flag to indicate if we want to do hyperparameter tuning."},
+    )
+    hyperparameter_tuning_is_grid: Optional[bool] = dataclasses.field(
+        default=True,
+        metadata={"help": "A flag to indicate if we want to use grid search for hyperparameter tuning."},
+    )
     hyperparameter_tuning_percentage: Optional[float] = dataclasses.field(
         default=0.1,
         metadata={
```
```diff
@@ -66,10 +84,6 @@ class CehrGPTArguments:
             "help": "The number of trials that will be used for hyperparameter tuning."
         },
     )
-    hyperparameter_tuning: Optional[bool] = dataclasses.field(
-        default=False,
-        metadata={"help": "A flag to indicate if we want to do hyperparameter tuning."},
-    )
     hyperparameter_batch_sizes: Optional[List[int]] = dataclasses.field(
         default_factory=lambda: [4, 8, 16],
         metadata={"help": "Hyperparameter search batch sizes"},
```
```diff
@@ -78,29 +92,13 @@ class CehrGPTArguments:
         default_factory=lambda: [10],
         metadata={"help": "Hyperparameter search num_train_epochs"},
     )
-    lr_low: Optional[float] = dataclasses.field(
-        default=1e-5,
-        metadata={
-            "help": "The lower bound of the learning rate range for hyperparameter tuning."
-        },
-    )
-    lr_high: Optional[float] = dataclasses.field(
-        default=5e-5,
-        metadata={
-            "help": "The upper bound of the learning rate range for hyperparameter tuning."
-        },
-    )
-    weight_decays_low: Optional[float] = dataclasses.field(
-        default=1e-3,
-        metadata={
-            "help": "The lower bound of the weight decays range for hyperparameter tuning."
-        },
+    hyperparameter_learning_rates: Optional[List[float]] = dataclasses.field(
+        default_factory=lambda: [1e-5],
+        metadata={"help": "Hyperparameter search learning rates"},
     )
-    weight_decays_high: Optional[float] = dataclasses.field(
-        default=1e-2,
-        metadata={
-            "help": "The upper bound of the weight decays range for hyperparameter tuning."
-        },
+    hyperparameter_weight_decays: Optional[List[float]] = dataclasses.field(
+        default_factory=lambda: [1e-2],
+        metadata={"help": "Hyperparameter search weight decays"},
     )
     causal_sfm: Optional[bool] = dataclasses.field(
         default=False,
```
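The removed `lr_low`/`lr_high` and weight-decay bounds implied range-based sampling; the new list-valued fields paired with `hyperparameter_tuning_is_grid` suggest grid search instead, where every combination of candidate values becomes one trial. A self-contained sketch of that expansion, using the fields' default values:

```python
# Illustrative grid expansion over the new list-valued search spaces
# (values mirror the dataclass defaults; nothing here calls the package).
from itertools import product

hyperparameter_learning_rates = [1e-5]
hyperparameter_weight_decays = [1e-2]
hyperparameter_batch_sizes = [4, 8, 16]
hyperparameter_num_epochs = [10]

for lr, wd, bs, ep in product(
    hyperparameter_learning_rates,
    hyperparameter_weight_decays,
    hyperparameter_batch_sizes,
    hyperparameter_num_epochs,
):
    print(f"trial: lr={lr}, weight_decay={wd}, batch_size={bs}, epochs={ep}")
```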
```diff
@@ -168,6 +166,16 @@ class CehrGPTArguments:
             "help": "A threshold to denote how much the specified metric must improve to satisfy early stopping conditions."
         },
     )
+    inner_dim: Optional[int] = dataclasses.field(
+        default=None,
+        metadata={"help": "The dimensionality of the inner feed-forward layer"},
+    )
+    apply_rotary: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to use rotary encoder layers"
+        },
+    )
     sample_packing: Optional[bool] = dataclasses.field(
         default=False,
         metadata={
```
```diff
@@ -177,12 +185,6 @@ class CehrGPTArguments:
     max_tokens_per_batch: int = dataclasses.field(
         default=16384, metadata={"help": "Maximum number of tokens in each batch"}
     )
-    add_end_token_in_sample_packing: Optional[bool] = dataclasses.field(
-        default=False,
-        metadata={
-            "help": "A flag to indicate whether we want to add end token in sample packing"
-        },
-    )
     include_motor_time_to_event: Optional[bool] = dataclasses.field(
         default=False,
         metadata={
```
```diff
@@ -203,7 +205,17 @@ class CehrGPTArguments:
             "help": "The number of times each motor_num_time_pieces piece has to be"
         },
     )
-    concept_dir: Optional[str] = dataclasses.field(
+    motor_use_ontology: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to use the OMOP ontology for MOTOR time-to-event tasks"
+        },
+    )
+    motor_sampling_probability: Optional[float] = dataclasses.field(
+        default=0.0,
+        metadata={"help": "The probability of sampling a MOTOR time-to-event task"},
+    )
+    vocab_dir: Optional[str] = dataclasses.field(
         default=None,
         metadata={"help": "The directory where the concept data is stored."},
     )
```