cehrgpt 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/htn_treatment_pathway.py +546 -0
- cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
- cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
- cehrgpt/data/cehrgpt_data_processor.py +549 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +286 -629
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +60 -14
- cehrgpt/generation/cehrgpt_conditional_generation.py +316 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +35 -15
- cehrgpt/generation/omop_converter_batch.py +11 -4
- cehrgpt/gpt_utils.py +73 -3
- cehrgpt/models/activations.py +27 -0
- cehrgpt/models/config.py +6 -2
- cehrgpt/models/gpt2.py +560 -0
- cehrgpt/models/hf_cehrgpt.py +193 -459
- cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
- cehrgpt/omop/ontology.py +154 -0
- cehrgpt/runners/data_utils.py +17 -6
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +33 -79
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +58 -34
- cehrgpt/runners/hyperparameter_search_util.py +180 -69
- cehrgpt/runners/sample_packing_trainer.py +11 -2
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +27 -31
- cehrgpt-0.1.3.dist-info/METADATA +238 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +33 -22
- cehrgpt-0.1.1.dist-info/METADATA +0 -115
- /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
cehrgpt/omop/ontology.py
ADDED
@@ -0,0 +1,154 @@
+from __future__ import annotations
+
+import collections
+import os
+from typing import Any, Dict, Iterable, Optional, Set, Union
+
+import polars as pl
+from datasets import Dataset
+
+
+# Adapted from femr.ontology
+def _get_all_codes_map(batch: Dict[str, Any]) -> Dict[str, Any]:
+    result = set()
+    for concept_ids in batch["concept_ids"]:
+        for concept_id in concept_ids:
+            if concept_id.isnumeric():
+                result.add(concept_id)
+    return {"unique_concept_ids": list(result)}
+
+
+class Ontology:
+    def __init__(self, vocab_path: str):
+        """Create an Ontology from an Athena download and an optional meds Code Metadata structure.
+
+        NOTE: This is an expensive operation.
+        It is recommended to create an ontology once and then save/load it as necessary.
+        """
+        # Load from code metadata
+        self.parents_map: Dict[str, Set[str]] = collections.defaultdict(set)
+        self.concept_vocabulary_map: Dict[str, str] = collections.defaultdict(str)
+        self.concept_domain_map: Dict[str, str] = collections.defaultdict(str)
+
+        # Load from the athena path ...
+        concept = pl.scan_parquet(os.path.join(vocab_path, "concept/*parquet"))
+        vocabulary_id_col = pl.col("vocabulary_id")
+        concept_id_col = pl.col("concept_id").cast(pl.String)
+        domain_id_col = pl.col("domain_id").cast(pl.String)
+
+        processed_concepts = (
+            concept.select(
+                concept_id_col,
+                domain_id_col,
+                vocabulary_id_col,
+                pl.col("standard_concept").is_null(),
+            )
+            .collect()
+            .rows()
+        )
+
+        non_standard_concepts = set()
+
+        for concept_id, domain_id, vocabulary_id, is_non_standard in processed_concepts:
+            # We don't want to override code metadata
+            if concept_id not in self.concept_vocabulary_map:
+                self.concept_vocabulary_map[concept_id] = vocabulary_id
+
+            if concept_id not in self.concept_domain_map:
+                self.concept_domain_map[concept_id] = domain_id
+            # We don't want to override code metadata
+            if is_non_standard:
+                non_standard_concepts.add(concept_id)
+
+        relationship = pl.scan_parquet(
+            os.path.join(vocab_path, "concept_relationship/*parquet")
+        )
+        relationship_id = pl.col("relationship_id")
+        relationship = relationship.filter(
+            relationship_id == "Maps to",
+            pl.col("concept_id_1") != pl.col("concept_id_2"),
+        )
+        for concept_id_1, concept_id_2 in (
+            relationship.select(
+                pl.col("concept_id_1").cast(pl.String),
+                pl.col("concept_id_2").cast(pl.String),
+            )
+            .collect()
+            .rows()
+        ):
+            if concept_id_1 in non_standard_concepts:
+                self.parents_map[concept_id_1].add(concept_id_2)
+
+        ancestor = pl.scan_parquet(
+            os.path.join(vocab_path, "concept_ancestor/*parquet")
+        )
+        ancestor = ancestor.filter(pl.col("min_levels_of_separation") == 1)
+        for concept_id, parent_concept_id in (
+            ancestor.select(
+                pl.col("descendant_concept_id").cast(pl.String),
+                pl.col("ancestor_concept_id").cast(pl.String),
+            )
+            .collect()
+            .rows()
+        ):
+            self.parents_map[concept_id].add(parent_concept_id)
+        self.all_parents_map: Dict[str, Set[str]] = {}
+
+    def get_domain(self, concept_id: Union[str, int]) -> Optional[str]:
+        return self.concept_domain_map.get(str(concept_id), None)
+
+    def prune_to_dataset(
+        self,
+        dataset: Dataset,
+        remove_ontologies: Set[str] = set(),
+        num_proc: int = 4,
+        batch_size: int = 1024,
+    ) -> None:
+        mapped_dataset = dataset.map(
+            _get_all_codes_map,
+            batched=True,
+            batch_size=batch_size,
+            remove_columns=dataset.column_names,
+            num_proc=num_proc,
+        )
+        valid_concept_ids = set(mapped_dataset["unique_concept_ids"])
+        all_parents = set()
+        for concept_id in valid_concept_ids:
+            all_parents |= self.get_all_parents(concept_id)
+
+        def is_valid(c: str):
+            ontology = self.concept_vocabulary_map.get(c, "")
+            return (c in valid_concept_ids) or (
+                (ontology not in remove_ontologies) and (c in all_parents)
+            )
+
+        concept_ids = set(self.parents_map.keys())
+        for concept_id in concept_ids:
+            m: Any
+            if is_valid(concept_id):
+                for m in (self.parents_map, self.concept_vocabulary_map):
+                    m[concept_id] = {a for a in m[concept_id] if is_valid(a)}
+            else:
+                for m in (self.parents_map, self.concept_vocabulary_map):
+                    if concept_id in m:
+                        del m[concept_id]
+
+        self.all_parents_map = {}
+
+        # Prime the pump
+        for concept_id in self.parents_map.keys():
+            self.get_all_parents(concept_id)
+
+    def get_parents(self, code: str) -> Iterable[str]:
+        """Get the parents for a given code."""
+        return self.parents_map.get(code, set())
+
+    def get_all_parents(self, code: str) -> Set[str]:
+        """Get all parents, including through the ontology."""
+        if code not in self.all_parents_map:
+            result = {code}
+            for parent in self.parents_map.get(code, set()):
+                result |= self.get_all_parents(parent)
+            self.all_parents_map[code] = result
+
+        return self.all_parents_map[code]
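To make the new module easier to evaluate, here is a minimal usage sketch of `Ontology` (not part of the package itself). It assumes a local OMOP/Athena vocabulary export containing `concept/`, `concept_relationship/`, and `concept_ancestor/` parquet folders, and a toy Hugging Face dataset with a `concept_ids` column; the path and concept ids below are illustrative only.

```python
from datasets import Dataset

from cehrgpt.omop.ontology import Ontology

# Hypothetical vocabulary location and toy dataset for illustration only.
vocab_path = "/data/omop_vocab"  # must contain concept/, concept_relationship/, concept_ancestor/
toy_dataset = Dataset.from_dict(
    {"concept_ids": [["201826", "4193704"], ["320128"]]}  # toy numeric OMOP concept ids
)

# Building the ontology scans the vocabulary parquet files (expensive, done once).
ontology = Ontology(vocab_path)

# Restrict the parent maps to concepts observed in the dataset, dropping a few vocabularies.
ontology.prune_to_dataset(
    toy_dataset, remove_ontologies={"SPL", "HemOnc", "LOINC"}, num_proc=1
)

# Query direct parents, the transitive closure of parents, and the OMOP domain.
print(ontology.get_parents("201826"))
print(ontology.get_all_parents("201826"))
print(ontology.get_domain(201826))
```

Per the `is_valid` check in `prune_to_dataset`, a concept is kept if it was observed in the dataset, or if it is an ancestor of an observed concept and its vocabulary is not listed in `remove_ontologies`.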
cehrgpt/runners/data_utils.py
CHANGED
@@ -47,7 +47,7 @@ def prepare_finetune_dataset(
     data_args: DataTrainingArguments,
     training_args: TrainingArguments,
     cehrgpt_args: CehrGPTArguments,
-    cache_file_collector: CacheFileCollector,
+    cache_file_collector: Optional[CacheFileCollector] = None,
 ) -> DatasetDict:
     # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
     if data_args.is_data_in_meds:
@@ -91,8 +91,9 @@ def prepare_finetune_dataset(
             "Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
             stats,
         )
-
-
+        if cache_file_collector:
+            # Clean up the files created from the data generator
+            cache_file_collector.remove_cache_files()
         dataset = load_from_disk(str(meds_extension_path))

     train_set = dataset["train"]
@@ -271,7 +272,7 @@ def create_dataset_splits(data_args: DataTrainingArguments, seed: int):
 def extract_cohort_sequences(
     data_args: DataTrainingArguments,
     cehrgpt_args: CehrGPTArguments,
-    cache_file_collector: CacheFileCollector,
+    cache_file_collector: Optional[CacheFileCollector] = None,
 ) -> DatasetDict:
     """
     Extracts and processes cohort-specific tokenized sequences from a pre-tokenized dataset,.
@@ -309,9 +310,18 @@
         mapping={
             "prediction_time": "index_date",
             "subject_id": "person_id",
+            "boolean_value": "label",
         }
     )
     all_person_ids = cohort["person_id"].unique().to_list()
+    # In case the label column does not exist, we add a fake column to the dataframe so subsequent process can work
+    if "label" not in cohort.columns:
+        cohort = cohort.with_columns(
+            pl.Series(
+                name="label", values=np.zeros_like(cohort["person_id"].to_numpy())
+            )
+        )
+
     # data_args.observation_window
     tokenized_dataset = load_from_disk(cehrgpt_args.tokenized_full_dataset_path)
     filtered_tokenized_dataset = tokenized_dataset.filter(
@@ -353,6 +363,7 @@
         num_proc=data_args.preprocessing_num_workers,
         remove_columns=filtered_tokenized_dataset["train"].column_names,
     )
-    cache_file_collector
-
+    if cache_file_collector:
+        cache_file_collector.add_cache_files(filtered_tokenized_dataset)
+        cache_file_collector.add_cache_files(processed_dataset)
     return processed_dataset
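The `extract_cohort_sequences` change above also adds a fallback that fabricates an all-zero `label` column when the cohort parquet carries neither `boolean_value` nor `label`. A self-contained sketch of that pattern with a toy polars frame (the column values are made up):

```python
import numpy as np
import polars as pl

# Toy cohort frame; in the runner this comes from the cohort parquet files.
cohort = pl.DataFrame(
    {"person_id": [1, 2, 3], "index_date": ["2020-01-01", "2020-02-01", "2020-03-01"]}
)

# Mirror of the guard added in extract_cohort_sequences: if no label column exists,
# attach an all-zero placeholder so downstream processing that expects it still works.
if "label" not in cohort.columns:
    cohort = cohort.with_columns(
        pl.Series(name="label", values=np.zeros_like(cohort["person_id"].to_numpy()))
    )

print(cohort)
```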
cehrgpt/runners/hf_cehrgpt_finetune_runner.py
CHANGED
@@ -1,3 +1,4 @@
+import glob
 import json
 import os
 import random
@@ -175,11 +176,6 @@ def model_init(
         model.config.class_weights = cehrgpt_args.class_weights
         LOG.info(f"Setting class_weights to {model.config.class_weights}")

-    # Enable position embeddings when position embeddings are disabled in pre-training
-    if not model_args.exclude_position_ids and model.cehrgpt.exclude_position_ids:
-        LOG.info(f"Enable the position_embeddings")
-        model.cehrgpt.enable_position_embeddings()
-
     if model.config.max_position_embeddings < model_args.max_position_embeddings:
         LOG.info(
             f"Increase model.config.max_position_embeddings to {model_args.max_position_embeddings}"
@@ -379,7 +375,6 @@ def main():
             SamplePackingCehrGptDataCollator,
             cehrgpt_args.max_tokens_per_batch,
             config.max_position_embeddings,
-            add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
         )
     else:
         trainer_class = Trainer
@@ -406,8 +401,9 @@ def main():
     )

     if training_args.do_train:
+        output_dir = training_args.output_dir
         if cehrgpt_args.hyperparameter_tuning:
-            training_args = perform_hyperparameter_search(
+            training_args, run_id = perform_hyperparameter_search(
                 trainer_class,
                 partial(model_init, model_args, training_args, cehrgpt_args, tokenizer),
                 processed_dataset,
@@ -416,18 +412,28 @@
                 model_args,
                 cehrgpt_args,
             )
-
-
-
-            retrain_with_full_set(
-                trainer_class,
-                model_args,
-                training_args,
-                cehrgpt_args,
-                tokenizer,
-                processed_dataset,
-                data_collator,
+            # We enforce retraining if cehrgpt_args.hyperparameter_tuning_percentage < 1.0
+            cehrgpt_args.retrain_with_full |= (
+                cehrgpt_args.hyperparameter_tuning_percentage < 1.0
             )
+            output_dir = os.path.join(training_args.output_dir, f"run-{run_id}")
+
+        if cehrgpt_args.hyperparameter_tuning and not cehrgpt_args.retrain_with_full:
+            folders = glob.glob(os.path.join(output_dir, "checkpoint-*"))
+            if len(folders) == 0:
+                raise RuntimeError(
+                    f"There must be a checkpoint folder under {output_dir}"
+                )
+            checkpoint_dir = folders[0]
+            LOG.info("Best trial checkpoint folder: %s", checkpoint_dir)
+            for file_name in os.listdir(checkpoint_dir):
+                try:
+                    full_file_name = os.path.join(checkpoint_dir, file_name)
+                    destination = os.path.join(training_args.output_dir, file_name)
+                    if os.path.isfile(full_file_name):
+                        shutil.copy2(full_file_name, destination)
+                except Exception as e:
+                    LOG.error("Failed to copy %s: %s", file_name, str(e))
         else:
             # Initialize Trainer for final training on the combined train+val set
             trainer = trainer_class(
@@ -476,63 +482,6 @@
         do_predict(test_dataloader, model_args, training_args, cehrgpt_args)


-def retrain_with_full_set(
-    trainer_class,
-    model_args: ModelArguments,
-    training_args: TrainingArguments,
-    cehrgpt_args: CehrGPTArguments,
-    tokenizer: CehrGptTokenizer,
-    dataset: DatasetDict,
-    data_collator: CehrGptDataCollator,
-) -> None:
-    """
-    Retrains a model on the full training and validation dataset for final performance evaluation.
-
-    This function consolidates the training and validation datasets into a single
-    dataset for final model training, updates the output directory for the final model,
-    and disables evaluation during training. It resumes from the latest checkpoint if available,
-    trains the model on the combined dataset, and saves the model along with training metrics
-    and state information.
-
-    Args:
-        trainer_class: Trainer or its subclass
-        model_args (ModelArguments): Model configuration and hyperparameters.
-        training_args (TrainingArguments): Training configuration, including output directory,
-            evaluation strategy, and other training parameters.
-        cehrgpt_args (CehrGPTArguments): CehrGPT specific parameters.
-        tokenizer (CehrGptTokenizer): Tokenizer instance specific to CEHR-GPT.
-        dataset (DatasetDict): A dictionary containing the 'train' and 'validation' datasets.
-        data_collator (CehrGptDataCollator): Data collator for handling data batching and tokenization.
-
-    Returns:
-        None
-    """
-    # Initialize Trainer for final training on the combined train+val set
-    full_dataset = concatenate_datasets([dataset["train"], dataset["validation"]])
-    training_args.output_dir = os.path.join(training_args.output_dir, "full")
-    LOG.info(
-        "Final output_dir for final_training_args.output_dir %s",
-        training_args.output_dir,
-    )
-    Path(training_args.output_dir).mkdir(exist_ok=True)
-    # Disable evaluation
-    training_args.evaluation_strategy = "no"
-    checkpoint = get_last_hf_checkpoint(training_args)
-    final_trainer = trainer_class(
-        model=model_init(model_args, training_args, cehrgpt_args, tokenizer),
-        data_collator=data_collator,
-        args=training_args,
-        train_dataset=full_dataset,
-        tokenizer=tokenizer,
-    )
-    final_train_result = final_trainer.train(resume_from_checkpoint=checkpoint)
-    final_trainer.save_model() # Saves the tokenizer too for easy upload
-    metrics = final_train_result.metrics
-    final_trainer.log_metrics("train", metrics)
-    final_trainer.save_metrics("train", metrics)
-    final_trainer.save_state()
-
-
 def do_predict(
     test_dataloader: DataLoader,
     model_args: ModelArguments,
@@ -580,7 +529,15 @@ def do_predict(
             index_dates = batch.pop("index_date").numpy().squeeze()
             if index_dates.ndim == 0:
                 index_dates = np.asarray([index_dates])
-
+
+            index_dates = list(
+                map(
+                    lambda posix_time: datetime.utcfromtimestamp(posix_time).replace(
+                        tzinfo=None
+                    ),
+                    index_dates.tolist(),
+                )
+            )

             batch = {k: v.to(device) for k, v in batch.items()}
             # Forward pass
@@ -644,9 +601,6 @@ def load_lora_model(
     # Enable include_values when include_values is set to be False during pre-training
     if model_args.include_values and not model.cehrgpt.include_values:
         model.cehrgpt.include_values = True
-    # Enable position embeddings when position embeddings are disabled in pre-training
-    if not model_args.exclude_position_ids and model.cehrgpt.exclude_position_ids:
-        model.cehrgpt.exclude_position_ids = False
     if cehrgpt_args.expand_tokenizer:
         tokenizer = CehrGptTokenizer.from_pretrained(training_args.output_dir)
         # Expand tokenizer to adapt to the finetuning dataset
cehrgpt/runners/hf_cehrgpt_pretrain_runner.py
CHANGED
@@ -2,7 +2,9 @@ import os
 from functools import partial
 from typing import Optional, Union

+import datasets
 import numpy as np
+import pandas as pd
 import torch
 import torch.distributed as dist
 from cehrbert.data_generators.hf_data_generator.meds_utils import (
@@ -20,7 +22,7 @@ from cehrbert.runners.runner_util import (
     load_parquet_as_dataset,
 )
 from datasets import Dataset, DatasetDict, IterableDatasetDict, load_from_disk
-from transformers import EarlyStoppingCallback, Trainer,
+from transformers import EarlyStoppingCallback, Trainer, set_seed
 from transformers.trainer_utils import is_main_process
 from transformers.utils import is_flash_attn_2_available, logging

@@ -34,6 +36,7 @@ from cehrgpt.models.config import CEHRGPTConfig
 from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
 from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+from cehrgpt.omop.ontology import Ontology
 from cehrgpt.runners.data_utils import get_torch_dtype
 from cehrgpt.runners.gpt_runner_util import parse_runner_args
 from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
@@ -70,68 +73,64 @@ def load_and_create_tokenizer(
     data_args: DataTrainingArguments,
     model_args: ModelArguments,
     cehrgpt_args: CehrGPTArguments,
-    dataset:
+    dataset: Union[Dataset, DatasetDict],
 ) -> CehrGptTokenizer:

-    concept_name_mapping = {}
-    allowed_motor_codes = list()
-    if cehrgpt_args.concept_dir:
-        import pandas as pd
-        from cehrbert_data.const.artificial_tokens import DEATH_TOKEN
-        from meds.schema import death_code
-
-        LOG.info("Loading concept data from disk at %s", cehrgpt_args.concept_dir)
-        concept_pd = pd.read_parquet(cehrgpt_args.concept_dir)
-        LOG.info(
-            "Creating concept name mapping and motor_time_to_event_codes from disk at %s",
-            cehrgpt_args.concept_dir,
-        )
-        for row in concept_pd.itertuples():
-            concept_name_mapping[str(getattr(row, "concept_id"))] = getattr(
-                row, "concept_name"
-            )
-            if (
-                cehrgpt_args.include_motor_time_to_event
-                and getattr(row, "domain_id")
-                in ["Condition", "Procedure", "Drug", "Visit"]
-                and getattr(row, "standard_concept") == "S"
-            ):
-                allowed_motor_codes.append(str(getattr(row, "concept_id")))
-        LOG.info(
-            "Adding death codes for MOTOR TTE predictions: %s",
-            [DEATH_TOKEN, death_code],
-        )
-        allowed_motor_codes.extend([DEATH_TOKEN, death_code])
     # Try to load the pretrained tokenizer
     tokenizer_abspath = os.path.expanduser(model_args.tokenizer_name_or_path)
-
-
-    except Exception as e:
-        LOG.warning(e)
-        if dataset is None:
+    if not tokenizer_exists(tokenizer_abspath):
+        if cehrgpt_args.include_motor_time_to_event and not cehrgpt_args.vocab_dir:
             raise RuntimeError(
-
-
+                "motor_vocab_dir must be specified if include_motor_time_to_event is True"
+            )
+        ontology: Optional[Ontology] = None
+        concept_name_mapping = {}
+        if cehrgpt_args.vocab_dir:
+            LOG.info("Loading concept data from disk at %s", cehrgpt_args.vocab_dir)
+            concept_pd = pd.read_parquet(
+                os.path.join(cehrgpt_args.vocab_dir, "concept")
             )
+            for row in concept_pd.itertuples():
+                concept_name_mapping[str(getattr(row, "concept_id"))] = getattr(
+                    row, "concept_name"
+                )
+
+            if cehrgpt_args.motor_use_ontology:
+                LOG.info("Creating ontology for MOTOR TTE predictions")
+                ontology = Ontology(cehrgpt_args.vocab_dir)
+                train_val_dataset = datasets.concatenate_datasets(
+                    [dataset["train"], dataset["validation"]]
+                )
+                ontology.prune_to_dataset(
+                    train_val_dataset,
+                    num_proc=data_args.preprocessing_num_workers,
+                    remove_ontologies={"SPL", "HemOnc", "LOINC"},
+                )
+
         LOG.info("Started training the tokenizer ...")
+        train_val_dataset = datasets.concatenate_datasets(
+            [dataset["train"], dataset["validation"]]
+        )
         tokenizer = CehrGptTokenizer.train_tokenizer(
-
+            train_val_dataset,
             concept_name_mapping,
             data_args,
             PretrainedEmbeddings(cehrgpt_args.pretrained_embedding_path),
-
-            (
+            num_motor_tasks=(
                 cehrgpt_args.num_motor_tasks
                 if cehrgpt_args.include_motor_time_to_event
                 else None
            ),
             apply_entropy_filter=cehrgpt_args.apply_entropy_filter,
             min_prevalence=cehrgpt_args.min_prevalence,
+            ontology=ontology,
         )
         LOG.info("Finished training the tokenizer ...")
         tokenizer.save_pretrained(tokenizer_abspath)
         LOG.info("Saved the tokenizer to %s", tokenizer_abspath)
-
+    else:
+        LOG.info("The tokenizer exists and will be loaded from %s", tokenizer_abspath)
+        tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_abspath)
     return tokenizer


@@ -187,7 +186,10 @@ def load_and_create_model(

     model_args_cehrgpt = model_args.as_dict()
     model_args_cehrgpt.pop("attn_implementation")
+    # CEHR-GPT does not support this anymore
+    model_args_cehrgpt.pop("exclude_position_ids")
     model_config = CEHRGPTConfig(
+        activation_function=cehrgpt_args.activation_function,
         vocab_size=tokenizer.vocab_size,
         value_vocab_size=tokenizer.value_vocab_size,
         time_token_vocab_size=tokenizer.time_token_vocab_size,
@@ -207,6 +209,7 @@
         n_pretrained_embeddings_layers=cehrgpt_args.n_pretrained_embeddings_layers,
         use_pretrained_embeddings=len(tokenizer.pretrained_token_ids) > 0,
         pretrained_embedding_dim=pretrained_embedding_dim,
+        apply_rotary=cehrgpt_args.apply_rotary,
         sample_packing_max_positions=(
             cehrgpt_args.max_tokens_per_batch
             if cehrgpt_args.sample_packing
@@ -217,6 +220,8 @@
         motor_time_to_event_weight=cehrgpt_args.motor_time_to_event_weight,
         motor_num_time_pieces=cehrgpt_args.motor_num_time_pieces,
         ve_token_id=tokenizer.ve_token_id,
+        n_inner=cehrgpt_args.inner_dim,
+        decoder_mlp=cehrgpt_args.decoder_mlp,
         **model_args_cehrgpt,
     )

@@ -235,7 +240,6 @@

 def main():
     cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
-
     if cehrgpt_args.sample_packing and data_args.streaming:
         raise RuntimeError(
             f"sample_packing is not supported when streaming is enabled, please set streaming to False"
@@ -530,7 +534,6 @@
             SamplePackingCehrGptDataCollator,
             cehrgpt_args.max_tokens_per_batch,
             model_args.max_position_embeddings,
-            add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
         )
     else:
         trainer_class = Trainer
@@ -552,6 +555,7 @@
             include_motor_time_to_event=cehrgpt_args.include_motor_time_to_event,
             motor_tte_vocab_size=model.config.motor_tte_vocab_size,
             motor_num_time_pieces=cehrgpt_args.motor_num_time_pieces,
+            motor_sampling_probability=cehrgpt_args.motor_sampling_probability,
         ),
         train_dataset=processed_dataset["train"],
         eval_dataset=(
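For reference, the tokenizer-training branch above now builds its concept-name mapping directly from the OMOP `concept` table under `--vocab_dir`. A minimal standalone sketch of that step (the directory path is hypothetical; the `concept_id` and `concept_name` columns match the standard OMOP concept table referenced in the diff):

```python
import os

import pandas as pd

# Assumed layout: <vocab_dir>/concept/*.parquet with concept_id and concept_name columns.
vocab_dir = "/data/omop_vocab"  # hypothetical path
concept_pd = pd.read_parquet(os.path.join(vocab_dir, "concept"))

concept_name_mapping = {}
for row in concept_pd.itertuples():
    # Keys are stringified OMOP concept ids, values are human-readable concept names.
    concept_name_mapping[str(getattr(row, "concept_id"))] = getattr(row, "concept_name")

print(len(concept_name_mapping))
```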
|