cehrgpt-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- __init__.py +0 -0
- cehrgpt/__init__.py +0 -0
- cehrgpt/analysis/__init__.py +0 -0
- cehrgpt/analysis/privacy/__init__.py +0 -0
- cehrgpt/analysis/privacy/attribute_inference.py +275 -0
- cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
- cehrgpt/analysis/privacy/member_inference.py +172 -0
- cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
- cehrgpt/analysis/privacy/reid_inference.py +407 -0
- cehrgpt/analysis/privacy/utils.py +255 -0
- cehrgpt/cehrgpt_args.py +142 -0
- cehrgpt/data/__init__.py +0 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
- cehrgpt/generation/__init__.py +0 -0
- cehrgpt/generation/chatgpt_generation.py +106 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
- cehrgpt/generation/omop_converter_batch.py +644 -0
- cehrgpt/generation/omop_entity.py +515 -0
- cehrgpt/gpt_utils.py +331 -0
- cehrgpt/models/__init__.py +0 -0
- cehrgpt/models/config.py +205 -0
- cehrgpt/models/hf_cehrgpt.py +1817 -0
- cehrgpt/models/hf_modeling_outputs.py +158 -0
- cehrgpt/models/pretrained_embeddings.py +82 -0
- cehrgpt/models/special_tokens.py +30 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
- cehrgpt/omop/__init__.py +0 -0
- cehrgpt/omop/condition_era.py +20 -0
- cehrgpt/omop/observation_period.py +43 -0
- cehrgpt/omop/omop_argparse.py +38 -0
- cehrgpt/omop/omop_table_builder.py +86 -0
- cehrgpt/omop/queries/__init__.py +0 -0
- cehrgpt/omop/queries/condition_era.py +86 -0
- cehrgpt/omop/queries/observation_period.py +135 -0
- cehrgpt/omop/sample_omop_tables.py +71 -0
- cehrgpt/runners/__init__.py +0 -0
- cehrgpt/runners/gpt_runner_util.py +99 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
- cehrgpt/runners/hyperparameter_search_util.py +223 -0
- cehrgpt/time_to_event/__init__.py +0 -0
- cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
- cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
- cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
- cehrgpt/time_to_event/time_to_event_model.py +226 -0
- cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
- cehrgpt/time_to_event/time_to_event_utils.py +55 -0
- cehrgpt/tools/__init__.py +0 -0
- cehrgpt/tools/ehrshot_benchmark.py +74 -0
- cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
- cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
- cehrgpt/tools/upload_omop_tables.py +108 -0
- cehrgpt-0.0.1.dist-info/LICENSE +21 -0
- cehrgpt-0.0.1.dist-info/METADATA +66 -0
- cehrgpt-0.0.1.dist-info/RECORD +60 -0
- cehrgpt-0.0.1.dist-info/WHEEL +5 -0
- cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/runners/hf_cehrgpt_finetune_runner.py
@@ -0,0 +1,746 @@
import json
import os
import random
import shutil
from datetime import datetime
from functools import partial
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from cehrbert.data_generators.hf_data_generator.meds_utils import (
    create_dataset_from_meds_reader,
)
from cehrbert.runners.hf_cehrbert_finetune_runner import compute_metrics
from cehrbert.runners.hf_runner_argument_dataclass import (
    DataTrainingArguments,
    FineTuneModelType,
    ModelArguments,
)
from cehrbert.runners.runner_util import (
    generate_prepared_ds_path,
    get_last_hf_checkpoint,
    get_meds_extension_path,
    load_parquet_as_dataset,
)
from datasets import DatasetDict, concatenate_datasets, load_from_disk
from peft import LoraConfig, PeftModel, get_peft_model
from scipy.special import expit as sigmoid
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    EarlyStoppingCallback,
    Trainer,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
    set_seed,
)
from transformers.tokenization_utils_base import LARGE_INTEGER
from transformers.utils import is_flash_attn_2_available, logging

from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_finetuning_dataset
from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
from cehrgpt.models.hf_cehrgpt import (
    CEHRGPTConfig,
    CehrGptForClassification,
    CEHRGPTPreTrainedModel,
)
from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
from cehrgpt.runners.gpt_runner_util import parse_runner_args
from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
from cehrgpt.runners.hyperparameter_search_util import perform_hyperparameter_search

LOG = logging.get_logger("transformers")

class UpdateNumEpochsBeforeEarlyStoppingCallback(TrainerCallback):
    """
    Callback that tracks the number of epochs completed before early stopping,
    based on the best evaluation metric (e.g., eval_loss).
    """

    def __init__(self, model_folder: str):
        self._model_folder = model_folder
        self._metrics_path = os.path.join(
            model_folder, "num_epochs_trained_before_early_stopping.json"
        )
        self._num_epochs_before_early_stopping = 0
        self._best_val_loss = float("inf")

    @property
    def num_epochs_before_early_stopping(self):
        return self._num_epochs_before_early_stopping

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if os.path.exists(self._metrics_path):
            with open(self._metrics_path, "r") as f:
                metrics = json.load(f)
            self._num_epochs_before_early_stopping = metrics[
                "num_epochs_before_early_stopping"
            ]
            self._best_val_loss = metrics["best_val_loss"]

    def on_evaluate(self, args, state, control, **kwargs):
        # Ensure metrics is available in kwargs
        metrics = kwargs.get("metrics")
        if metrics is not None and "eval_loss" in metrics:
            # Check and update if a new best metric is achieved
            if metrics["eval_loss"] < self._best_val_loss:
                self._num_epochs_before_early_stopping = round(state.epoch)
                self._best_val_loss = metrics["eval_loss"]

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        with open(self._metrics_path, "w") as f:
            json.dump(
                {
                    "num_epochs_before_early_stopping": self._num_epochs_before_early_stopping,
                    "best_val_loss": self._best_val_loss,
                },
                f,
            )

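# For reference, the callback above persists its state to
# `num_epochs_trained_before_early_stopping.json` in the model folder; an illustrative
# (made-up) payload would look like:
#
#     {"num_epochs_before_early_stopping": 7, "best_val_loss": 0.4321}
#
# `main()` later reads `num_epochs_before_early_stopping` off this callback to decide how
# many epochs to train for when retraining on the combined train + validation set.
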
def load_pretrained_tokenizer(
    model_args,
) -> CehrGptTokenizer:
    try:
        return CehrGptTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
    except Exception:
        raise ValueError(
            f"Can not load the pretrained tokenizer from {model_args.tokenizer_name_or_path}"
        )


def load_finetuned_model(
    model_args: ModelArguments,
    training_args: TrainingArguments,
    model_name_or_path: str,
) -> CEHRGPTPreTrainedModel:
    if model_args.finetune_model_type == FineTuneModelType.POOLING.value:
        finetune_model_cls = CehrGptForClassification
    else:
        raise ValueError(
            f"finetune_model_type can be one of the following types {FineTuneModelType.POOLING.value}"
        )

    attn_implementation = (
        "flash_attention_2" if is_flash_attn_2_available() else "eager"
    )
    torch_dtype = torch.bfloat16 if training_args.bf16 else torch.float32
    # Try to create a new model based on the base model
    try:
        return finetune_model_cls.from_pretrained(
            model_name_or_path,
            attn_implementation=attn_implementation,
            torch_dtype=torch_dtype,
        )
    except ValueError:
        raise ValueError(f"Can not load the finetuned model from {model_name_or_path}")


def create_dataset_splits(data_args: DataTrainingArguments, seed: int):
    """
    Creates training, validation, and testing dataset splits based on specified splitting strategies.

    This function splits a dataset into training, validation, and test sets, using either chronological,
    patient-based, or random splitting strategies, depending on the parameters provided in `data_args`.

    - **Chronological split**: Sorts by a specified date and splits based on historical and future data.
    - **Patient-based split**: Splits by unique patient IDs to ensure that patients in each split are distinct.
    - **Random split**: Performs a straightforward random split of the dataset.

    If `data_args.test_data_folder` is provided, a test set is loaded directly from it. Otherwise,
    the test set is created by further splitting the validation set based on `test_eval_ratio`.

    Parameters:
        data_args (DataTrainingArguments): A configuration object containing data-related arguments, including:
            - `data_folder` (str): Path to the main dataset.
            - `test_data_folder` (str, optional): Path to an optional test dataset.
            - `chronological_split` (bool): Whether to split chronologically.
            - `split_by_patient` (bool): Whether to split by unique patient IDs.
            - `validation_split_percentage` (float): Percentage of data to use for validation.
            - `test_eval_ratio` (float): Ratio of test to validation data when creating a test set from validation.
            - `preprocessing_num_workers` (int): Number of processes for parallel data filtering.
            - `preprocessing_batch_size` (int): Batch size for batched operations.
        seed (int): Random seed for reproducibility of splits.

    Returns:
        Tuple[Dataset, Dataset, Dataset]: A tuple containing:
            - `train_set` (Dataset): Training split of the dataset.
            - `validation_set` (Dataset): Validation split of the dataset.
            - `test_set` (Dataset): Test split of the dataset.

    Raises:
        FileNotFoundError: If `data_args.data_folder` or `data_args.test_data_folder` does not exist.
        ValueError: If incompatible arguments are passed for splitting strategies.

    Example Usage:
        data_args = DataTrainingArguments(
            data_folder="data/",
            validation_split_percentage=0.1,
            test_eval_ratio=0.2,
            chronological_split=True
        )
        train_set, validation_set, test_set = create_dataset_splits(data_args, seed=42)
    """
    dataset = load_parquet_as_dataset(data_args.data_folder)
    test_set = (
        None
        if not data_args.test_data_folder
        else load_parquet_as_dataset(data_args.test_data_folder)
    )

    if data_args.chronological_split:
        # Chronological split by sorting on `index_date`
        dataset = dataset.sort("index_date")
        total_size = len(dataset)
        train_end = int((1 - data_args.validation_split_percentage) * total_size)

        # Perform the split
        train_set = dataset.select(range(0, train_end))
        validation_set = dataset.select(range(train_end, total_size))

        if test_set is None:
            test_valid_split = validation_set.train_test_split(
                test_size=data_args.test_eval_ratio, seed=seed
            )
            validation_set, test_set = (
                test_valid_split["train"],
                test_valid_split["test"],
            )

    elif data_args.split_by_patient:
        # Patient-based split
        LOG.info("Using the split_by_patient strategy")
        unique_patient_ids = dataset.unique("person_id")
        LOG.info(f"There are {len(unique_patient_ids)} patients in total")

        np.random.seed(seed)
        np.random.shuffle(unique_patient_ids)

        train_end = int(
            len(unique_patient_ids) * (1 - data_args.validation_split_percentage)
        )
        train_patient_ids = set(unique_patient_ids[:train_end])

        if test_set is None:
            validation_end = int(
                train_end
                + len(unique_patient_ids)
                * data_args.validation_split_percentage
                * data_args.test_eval_ratio
            )
            val_patient_ids = set(unique_patient_ids[train_end:validation_end])
            test_patient_ids = set(unique_patient_ids[validation_end:])
        else:
            val_patient_ids, test_patient_ids = (
                set(unique_patient_ids[train_end:]),
                None,
            )

        # Helper function to apply patient-based filtering
        def filter_by_patient_ids(patient_ids):
            return dataset.filter(
                lambda batch: [pid in patient_ids for pid in batch["person_id"]],
                num_proc=data_args.preprocessing_num_workers,
                batched=True,
                batch_size=data_args.preprocessing_batch_size,
            )

        # Generate splits
        train_set = filter_by_patient_ids(train_patient_ids)
        validation_set = filter_by_patient_ids(val_patient_ids)
        if test_set is None:
            test_set = filter_by_patient_ids(test_patient_ids)

    else:
        # Random split
        train_val = dataset.train_test_split(
            test_size=data_args.validation_split_percentage, seed=seed
        )
        train_set, validation_set = train_val["train"], train_val["test"]

        if test_set is None:
            test_valid_split = validation_set.train_test_split(
                test_size=data_args.test_eval_ratio, seed=seed
            )
            validation_set, test_set = (
                test_valid_split["train"],
                test_valid_split["test"],
            )

    return train_set, validation_set, test_set

def model_init(
    model_args: ModelArguments,
    training_args: TrainingArguments,
    tokenizer: CehrGptTokenizer,
):
    model = load_finetuned_model(
        model_args, training_args, model_args.model_name_or_path
    )
    if model.config.max_position_embeddings < model_args.max_position_embeddings:
        LOG.info(
            f"Increase model.config.max_position_embeddings to {model_args.max_position_embeddings}"
        )
        model.config.max_position_embeddings = model_args.max_position_embeddings
        model.resize_position_embeddings(model_args.max_position_embeddings)
    # Enable include_values if it was set to False during pre-training
    if model_args.include_values and not model.cehrgpt.include_values:
        model.cehrgpt.include_values = True
    # Enable position embeddings if they were disabled during pre-training
    if not model_args.exclude_position_ids and model.cehrgpt.exclude_position_ids:
        model.cehrgpt.exclude_position_ids = False
    # Expand the token embeddings to adapt to the fine-tuning dataset
    if model.config.vocab_size < tokenizer.vocab_size:
        model.resize_token_embeddings(tokenizer.vocab_size)
    # Update the pretrained embedding weights if they are available
    if model.config.use_pretrained_embeddings:
        model.cehrgpt.update_pretrained_embeddings(
            tokenizer.pretrained_token_ids, tokenizer.pretrained_embeddings
        )
    elif tokenizer.pretrained_token_ids:
        model.config.pretrained_embedding_dim = (
            tokenizer.pretrained_embeddings.shape[1]
        )
        model.config.use_pretrained_embeddings = True
        model.cehrgpt.initialize_pretrained_embeddings()
        model.cehrgpt.update_pretrained_embeddings(
            tokenizer.pretrained_token_ids, tokenizer.pretrained_embeddings
        )
    # Expand the value embeddings to adapt to the fine-tuning dataset
    if model.config.include_values:
        if model.config.value_vocab_size < tokenizer.value_vocab_size:
            model.resize_value_embeddings(tokenizer.value_vocab_size)
    # If LoRA is enabled, add LoRA adapters to the model
    if model_args.use_lora:
        # When LoRA is used, the trainer cannot automatically find the label column,
        # so we manually set label_names to "classifier_label" so the model
        # can compute the loss during evaluation
        if training_args.label_names:
            training_args.label_names.append("classifier_label")
        else:
            training_args.label_names = ["classifier_label"]

        if model_args.finetune_model_type == FineTuneModelType.POOLING.value:
            config = LoraConfig(
                r=model_args.lora_rank,
                lora_alpha=model_args.lora_alpha,
                target_modules=model_args.target_modules,
                lora_dropout=model_args.lora_dropout,
                bias="none",
                modules_to_save=["classifier", "age_batch_norm", "dense_layer"],
            )
            model = get_peft_model(model, config)
        else:
            raise ValueError(
                f"The LORA adapter is not supported for {model_args.finetune_model_type}"
            )
    return model

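# `model_init` is intended to be used as a model factory: `main()` below calls it directly to
# build the Trainer's model, and wraps it in `partial(model_init, model_args, training_args,
# tokenizer)` for `perform_hyperparameter_search`, so a fresh model can be created for every
# trial. A minimal illustrative sketch of the same pattern with a plain Hugging Face Trainer
# (not taken verbatim from this file):
#
#     trainer = Trainer(
#         model_init=partial(model_init, model_args, training_args, tokenizer),
#         args=training_args,
#         train_dataset=processed_dataset["train"],
#         eval_dataset=processed_dataset["validation"],
#         data_collator=data_collator,
#     )
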
def main():
    cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
    tokenizer = load_pretrained_tokenizer(model_args)
    prepared_ds_path = generate_prepared_ds_path(
        data_args, model_args, data_folder=data_args.cohort_folder
    )

    processed_dataset = None
    if any(prepared_ds_path.glob("*")):
        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
        processed_dataset = load_from_disk(str(prepared_ds_path))
        LOG.info("Prepared dataset loaded from disk...")
        if cehrgpt_args.expand_tokenizer:
            try:
                tokenizer = CehrGptTokenizer.from_pretrained(training_args.output_dir)
            except Exception:
                LOG.warning(
                    f"CehrGptTokenizer must exist in {training_args.output_dir} "
                    f"when the dataset has been processed and expand_tokenizer is set to True. "
                    f"Please delete the processed dataset at {prepared_ds_path}."
                )
                processed_dataset = None
                shutil.rmtree(prepared_ds_path)

    if processed_dataset is None:
        # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
        if data_args.is_data_in_meds:
            meds_extension_path = get_meds_extension_path(
                data_folder=data_args.cohort_folder,
                dataset_prepared_path=data_args.dataset_prepared_path,
            )
            try:
                LOG.info(
                    f"Trying to load the MEDS extension from disk at {meds_extension_path}..."
                )
                dataset = load_from_disk(meds_extension_path)
                if data_args.streaming:
                    if isinstance(dataset, DatasetDict):
                        dataset = {
                            k: v.to_iterable_dataset(
                                num_shards=training_args.dataloader_num_workers
                            )
                            for k, v in dataset.items()
                        }
                    else:
                        dataset = dataset.to_iterable_dataset(
                            num_shards=training_args.dataloader_num_workers
                        )
            except Exception as e:
                LOG.exception(e)
                dataset = create_dataset_from_meds_reader(
                    data_args, is_pretraining=False
                )
                if not data_args.streaming:
                    dataset.save_to_disk(meds_extension_path)
            train_set = dataset["train"]
            validation_set = dataset["validation"]
            test_set = dataset["test"]
        else:
            train_set, validation_set, test_set = create_dataset_splits(
                data_args=data_args, seed=training_args.seed
            )
        # Organize them into a single DatasetDict
        final_splits = DatasetDict(
            {"train": train_set, "validation": validation_set, "test": test_set}
        )

        if cehrgpt_args.expand_tokenizer:
            new_tokenizer_path = os.path.expanduser(training_args.output_dir)
            try:
                tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
            except Exception:
                # Use the user-provided pretrained embeddings if they exist;
                # otherwise fall back to the pretrained embedding model stored in the tokenizer
                pretrained_concept_embedding_model = PretrainedEmbeddings(
                    cehrgpt_args.pretrained_embedding_path
                )
                if not pretrained_concept_embedding_model.exists:
                    pretrained_concept_embedding_model = (
                        tokenizer.pretrained_concept_embedding_model
                    )
                tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
                    cehrgpt_tokenizer=tokenizer,
                    dataset=final_splits["train"],
                    data_args=data_args,
                    concept_name_mapping={},
                    pretrained_concept_embedding_model=pretrained_concept_embedding_model,
                )
                tokenizer.save_pretrained(os.path.expanduser(training_args.output_dir))

        processed_dataset = create_cehrgpt_finetuning_dataset(
            dataset=final_splits, cehrgpt_tokenizer=tokenizer, data_args=data_args
        )
        if not data_args.streaming:
            processed_dataset.save_to_disk(prepared_ds_path)

    # Set seed before initializing model.
    set_seed(training_args.seed)

    processed_dataset.set_format("pt")

    if cehrgpt_args.few_shot_predict:
        # We need at least two examples so that a validation set exists for early stopping
        num_shots = max(cehrgpt_args.n_shots, 2)
        random_train_indices = random.sample(
            range(len(processed_dataset["train"])), k=num_shots
        )
        test_size = max(int(num_shots * data_args.validation_split_percentage), 1)
        few_shot_train_val_set = processed_dataset["train"].select(random_train_indices)
        train_val = few_shot_train_val_set.train_test_split(
            test_size=test_size, seed=training_args.seed
        )
        few_shot_train_set, few_shot_val_set = train_val["train"], train_val["test"]
        processed_dataset["train"] = few_shot_train_set
        processed_dataset["validation"] = few_shot_val_set

    config = CEHRGPTConfig.from_pretrained(model_args.model_name_or_path)
    if config.max_position_embeddings < model_args.max_position_embeddings:
        config.max_position_embeddings = model_args.max_position_embeddings
    # We suppress the additional learning objectives in fine-tuning
    data_collator = CehrGptDataCollator(
        tokenizer=tokenizer,
        max_length=(
            config.max_position_embeddings - 1
            if config.causal_sfm
            else config.max_position_embeddings
        ),
        include_values=model_args.include_values,
        pretraining=False,
        include_ttv_prediction=False,
        use_sub_time_tokenization=False,
        include_demographics=cehrgpt_args.include_demographics,
    )

    if training_args.do_train:
        if cehrgpt_args.hyperparameter_tuning:
            model_args.early_stopping_patience = LARGE_INTEGER
            training_args = perform_hyperparameter_search(
                partial(model_init, model_args, training_args, tokenizer),
                processed_dataset,
                data_collator,
                training_args,
                model_args,
                cehrgpt_args,
            )
            # Always retrain with the full set when hyperparameter tuning is enabled
            retrain_with_full_set(
                model_args, training_args, tokenizer, processed_dataset, data_collator
            )
        else:
            # Initialize the Trainer for training with early stopping on the validation set
            trainer = Trainer(
                model=model_init(model_args, training_args, tokenizer),
                data_collator=data_collator,
                args=training_args,
                train_dataset=processed_dataset["train"],
                eval_dataset=processed_dataset["validation"],
                callbacks=[
                    EarlyStoppingCallback(model_args.early_stopping_patience),
                    UpdateNumEpochsBeforeEarlyStoppingCallback(
                        training_args.output_dir
                    ),
                ],
                tokenizer=tokenizer,
            )
            # Train the model on the train set with early stopping
            checkpoint = get_last_hf_checkpoint(training_args)
            train_result = trainer.train(resume_from_checkpoint=checkpoint)
            trainer.save_model()  # Saves the tokenizer too for easy upload
            metrics = train_result.metrics
            trainer.log_metrics("train", metrics)
            trainer.save_metrics("train", metrics)
            trainer.save_state()

            # Retrain the model on the full set using the number of epochs observed before early stopping
            if cehrgpt_args.retrain_with_full:
                update_num_epoch_before_early_stopping_callback = None
                for callback in trainer.callback_handler.callbacks:
                    if isinstance(callback, UpdateNumEpochsBeforeEarlyStoppingCallback):
                        update_num_epoch_before_early_stopping_callback = callback

                if update_num_epoch_before_early_stopping_callback is None:
                    raise RuntimeError(
                        f"{UpdateNumEpochsBeforeEarlyStoppingCallback} must be included as a callback!"
                    )
                final_num_epochs = (
                    update_num_epoch_before_early_stopping_callback.num_epochs_before_early_stopping
                )
                training_args.num_train_epochs = final_num_epochs
                LOG.info(
                    "Num Epochs before early stopping: %s",
                    training_args.num_train_epochs,
                )
                retrain_with_full_set(
                    model_args,
                    training_args,
                    tokenizer,
                    processed_dataset,
                    data_collator,
                )

    if training_args.do_predict:
        test_dataloader = DataLoader(
            dataset=processed_dataset["test"],
            batch_size=training_args.per_device_eval_batch_size,
            num_workers=training_args.dataloader_num_workers,
            collate_fn=data_collator,
            pin_memory=training_args.dataloader_pin_memory,
        )
        do_predict(test_dataloader, model_args, training_args, cehrgpt_args)

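# For orientation, a run of `main()` with do_train and do_predict enabled leaves roughly the
# following layout under `training_args.output_dir` (a sketch based on the code above; exact
# checkpoint names come from the Hugging Face Trainer):
#
#     output_dir/
#         num_epochs_trained_before_early_stopping.json   # written by the early-stopping callback
#         full/                                           # created when retraining on train + validation
#         test_predictions/0.parquet, 1.parquet, ...      # per-batch predictions from do_predict
#         test_results.json                               # metrics computed over all test predictions
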
def retrain_with_full_set(
    model_args: ModelArguments,
    training_args: TrainingArguments,
    tokenizer: CehrGptTokenizer,
    dataset: DatasetDict,
    data_collator: CehrGptDataCollator,
) -> None:
    """
    Retrains a model on the full training and validation dataset for final performance evaluation.

    This function consolidates the training and validation datasets into a single
    dataset for final model training, updates the output directory for the final model,
    and disables evaluation during training. It resumes from the latest checkpoint if available,
    trains the model on the combined dataset, and saves the model along with training metrics
    and state information.

    Args:
        model_args (ModelArguments): Model configuration and hyperparameters.
        training_args (TrainingArguments): Training configuration, including output directory,
            evaluation strategy, and other training parameters.
        tokenizer (CehrGptTokenizer): Tokenizer instance specific to CEHR-GPT.
        dataset (DatasetDict): A dictionary containing the 'train' and 'validation' datasets.
        data_collator (CehrGptDataCollator): Data collator for handling data batching and tokenization.

    Returns:
        None
    """
    # Initialize Trainer for final training on the combined train+val set
    full_dataset = concatenate_datasets([dataset["train"], dataset["validation"]])
    training_args.output_dir = os.path.join(training_args.output_dir, "full")
    LOG.info(
        "Final output_dir for final_training_args.output_dir %s",
        training_args.output_dir,
    )
    Path(training_args.output_dir).mkdir(exist_ok=True)
    # Disable evaluation
    training_args.evaluation_strategy = "no"
    checkpoint = get_last_hf_checkpoint(training_args)
    final_trainer = Trainer(
        model=model_init(model_args, training_args, tokenizer),
        data_collator=data_collator,
        args=training_args,
        train_dataset=full_dataset,
        tokenizer=tokenizer,
    )
    final_train_result = final_trainer.train(resume_from_checkpoint=checkpoint)
    final_trainer.save_model()  # Saves the tokenizer too for easy upload
    metrics = final_train_result.metrics
    final_trainer.log_metrics("train", metrics)
    final_trainer.save_metrics("train", metrics)
    final_trainer.save_state()

def do_predict(
    test_dataloader: DataLoader,
    model_args: ModelArguments,
    training_args: TrainingArguments,
    cehrgpt_args: CehrGPTArguments,
):
    """
    Performs inference on the test dataset using a fine-tuned model and saves predictions and evaluation metrics.

    We use this custom do_predict instead of transformers' trainer.predict() because the latter
    leaks memory on large test sets and eventually triggers a CPU OOM error.

    Args:
        test_dataloader (DataLoader): DataLoader containing the test dataset, with batches of input features and labels.
        model_args (ModelArguments): Arguments for configuring and loading the fine-tuned model.
        training_args (TrainingArguments): Arguments related to training, evaluation, and output directories.
        cehrgpt_args (CehrGPTArguments): CEHR-GPT-specific runner arguments.
    Returns:
        None. Results are saved to disk.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model and LoRA adapters if applicable
    model = (
        load_finetuned_model(model_args, training_args, training_args.output_dir)
        if not model_args.use_lora
        else load_lora_model(model_args, training_args, cehrgpt_args)
    )

    model = model.to(device).eval()

    # Ensure prediction folder exists
    test_prediction_folder = Path(training_args.output_dir) / "test_predictions"
    test_prediction_folder.mkdir(parents=True, exist_ok=True)

    LOG.info("Generating predictions for test set at %s", test_prediction_folder)

    test_losses = []
    with torch.no_grad():
        for index, batch in enumerate(tqdm(test_dataloader, desc="Predicting")):
            person_ids = batch.pop("person_id").numpy().squeeze().astype(int)
            index_dates = (
                map(
                    datetime.fromtimestamp,
                    batch.pop("index_date").numpy().squeeze(axis=-1).tolist(),
                )
                if "index_date" in batch
                else None
            )
            batch = {k: v.to(device) for k, v in batch.items()}
            # Forward pass
            output = model(**batch, output_attentions=False, output_hidden_states=False)
            test_losses.append(output.loss.item())

            # Collect logits and labels for prediction
            logits = output.logits.float().cpu().numpy().squeeze()
            labels = (
                batch["classifier_label"].float().cpu().numpy().squeeze().astype(bool)
            )
            probabilities = sigmoid(logits)
            # Save predictions to parquet file
            test_prediction_pd = pd.DataFrame(
                {
                    "subject_id": person_ids,
                    "prediction_time": index_dates,
                    "boolean_prediction_probability": probabilities,
                    "boolean_prediction": logits,
                    "boolean_value": labels,
                }
            )
            test_prediction_pd.to_parquet(test_prediction_folder / f"{index}.parquet")

    LOG.info(
        "Computing metrics using the test set predictions at %s", test_prediction_folder
    )
    # Load all predictions
    test_prediction_pd = pd.read_parquet(test_prediction_folder)
    # Compute metrics and save results
    metrics = compute_metrics(
        references=test_prediction_pd.boolean_value,
        probs=test_prediction_pd.boolean_prediction_probability,
    )
    metrics["test_loss"] = np.mean(test_losses)

    test_results_path = Path(training_args.output_dir) / "test_results.json"
    with open(test_results_path, "w") as f:
        json.dump(metrics, f, indent=4)

    LOG.info("Test results: %s", metrics)

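# The per-batch parquet files written by `do_predict` share one schema, so they can be
# inspected or re-scored outside this runner. A minimal illustrative sketch (column names
# come from the DataFrame built above; the AUC call is an assumption about how one might
# re-score the predictions with scikit-learn):
#
#     import pandas as pd
#     from sklearn.metrics import roc_auc_score
#
#     predictions = pd.read_parquet("<output_dir>/test_predictions")
#     auc = roc_auc_score(
#         predictions["boolean_value"], predictions["boolean_prediction_probability"]
#     )
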
def load_lora_model(
    model_args: ModelArguments,
    training_args: TrainingArguments,
    cehrgpt_args: CehrGPTArguments,
) -> PeftModel:
    LOG.info("Loading base model from %s", model_args.model_name_or_path)
    model = load_finetuned_model(
        model_args, training_args, model_args.model_name_or_path
    )
    # Enable include_values if it was set to False during pre-training
    if model_args.include_values and not model.cehrgpt.include_values:
        model.cehrgpt.include_values = True
    # Enable position embeddings if they were disabled during pre-training
    if not model_args.exclude_position_ids and model.cehrgpt.exclude_position_ids:
        model.cehrgpt.exclude_position_ids = False
    if cehrgpt_args.expand_tokenizer:
        tokenizer = CehrGptTokenizer.from_pretrained(training_args.output_dir)
        # Expand the token embeddings to adapt to the fine-tuning dataset
        if model.config.vocab_size < tokenizer.vocab_size:
            model.resize_token_embeddings(tokenizer.vocab_size)
        if (
            model.config.include_values
            and model.config.value_vocab_size < tokenizer.value_vocab_size
        ):
            model.resize_value_embeddings(tokenizer.value_vocab_size)
    LOG.info("Loading LoRA adapter from %s", training_args.output_dir)
    return PeftModel.from_pretrained(model, model_id=training_args.output_dir)


if __name__ == "__main__":
    main()
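# Example invocation (illustrative; it assumes `parse_runner_args` exposes the dataclass fields
# used above -- model_name_or_path, tokenizer_name_or_path, data_folder, cohort_folder,
# dataset_prepared_path, output_dir -- as HfArgumentParser-style command-line flags; adjust to
# your configuration):
#
#     python -m cehrgpt.runners.hf_cehrgpt_finetune_runner \
#         --model_name_or_path <pretrained_cehrgpt_dir> \
#         --tokenizer_name_or_path <pretrained_tokenizer_dir> \
#         --data_folder <cohort_parquet_dir> \
#         --cohort_folder <cohort_parquet_dir> \
#         --dataset_prepared_path <dataset_cache_dir> \
#         --output_dir <finetuned_output_dir> \
#         --do_train \
#         --do_predict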