cehrgpt 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. __init__.py +0 -0
  2. cehrgpt/__init__.py +0 -0
  3. cehrgpt/analysis/__init__.py +0 -0
  4. cehrgpt/analysis/privacy/__init__.py +0 -0
  5. cehrgpt/analysis/privacy/attribute_inference.py +275 -0
  6. cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
  7. cehrgpt/analysis/privacy/member_inference.py +172 -0
  8. cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
  9. cehrgpt/analysis/privacy/reid_inference.py +407 -0
  10. cehrgpt/analysis/privacy/utils.py +255 -0
  11. cehrgpt/cehrgpt_args.py +142 -0
  12. cehrgpt/data/__init__.py +0 -0
  13. cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
  14. cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
  15. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
  16. cehrgpt/generation/__init__.py +0 -0
  17. cehrgpt/generation/chatgpt_generation.py +106 -0
  18. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
  19. cehrgpt/generation/omop_converter_batch.py +644 -0
  20. cehrgpt/generation/omop_entity.py +515 -0
  21. cehrgpt/gpt_utils.py +331 -0
  22. cehrgpt/models/__init__.py +0 -0
  23. cehrgpt/models/config.py +205 -0
  24. cehrgpt/models/hf_cehrgpt.py +1817 -0
  25. cehrgpt/models/hf_modeling_outputs.py +158 -0
  26. cehrgpt/models/pretrained_embeddings.py +82 -0
  27. cehrgpt/models/special_tokens.py +30 -0
  28. cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
  29. cehrgpt/omop/__init__.py +0 -0
  30. cehrgpt/omop/condition_era.py +20 -0
  31. cehrgpt/omop/observation_period.py +43 -0
  32. cehrgpt/omop/omop_argparse.py +38 -0
  33. cehrgpt/omop/omop_table_builder.py +86 -0
  34. cehrgpt/omop/queries/__init__.py +0 -0
  35. cehrgpt/omop/queries/condition_era.py +86 -0
  36. cehrgpt/omop/queries/observation_period.py +135 -0
  37. cehrgpt/omop/sample_omop_tables.py +71 -0
  38. cehrgpt/runners/__init__.py +0 -0
  39. cehrgpt/runners/gpt_runner_util.py +99 -0
  40. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
  41. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
  42. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
  43. cehrgpt/runners/hyperparameter_search_util.py +223 -0
  44. cehrgpt/time_to_event/__init__.py +0 -0
  45. cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
  46. cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
  47. cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
  48. cehrgpt/time_to_event/time_to_event_model.py +226 -0
  49. cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
  50. cehrgpt/time_to_event/time_to_event_utils.py +55 -0
  51. cehrgpt/tools/__init__.py +0 -0
  52. cehrgpt/tools/ehrshot_benchmark.py +74 -0
  53. cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
  54. cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
  55. cehrgpt/tools/upload_omop_tables.py +108 -0
  56. cehrgpt-0.0.1.dist-info/LICENSE +21 -0
  57. cehrgpt-0.0.1.dist-info/METADATA +66 -0
  58. cehrgpt-0.0.1.dist-info/RECORD +60 -0
  59. cehrgpt-0.0.1.dist-info/WHEEL +5 -0
  60. cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/runners/hf_cehrgpt_finetune_runner.py
@@ -0,0 +1,746 @@
+ import json
+ import os
+ import random
+ import shutil
+ from datetime import datetime
+ from functools import partial
+ from pathlib import Path
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ from cehrbert.data_generators.hf_data_generator.meds_utils import (
+     create_dataset_from_meds_reader,
+ )
+ from cehrbert.runners.hf_cehrbert_finetune_runner import compute_metrics
+ from cehrbert.runners.hf_runner_argument_dataclass import (
+     DataTrainingArguments,
+     FineTuneModelType,
+     ModelArguments,
+ )
+ from cehrbert.runners.runner_util import (
+     generate_prepared_ds_path,
+     get_last_hf_checkpoint,
+     get_meds_extension_path,
+     load_parquet_as_dataset,
+ )
+ from datasets import DatasetDict, concatenate_datasets, load_from_disk
+ from peft import LoraConfig, PeftModel, get_peft_model
+ from scipy.special import expit as sigmoid
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+ from transformers import (
+     EarlyStoppingCallback,
+     Trainer,
+     TrainerCallback,
+     TrainerControl,
+     TrainerState,
+     TrainingArguments,
+     set_seed,
+ )
+ from transformers.tokenization_utils_base import LARGE_INTEGER
+ from transformers.utils import is_flash_attn_2_available, logging
+
+ from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_finetuning_dataset
+ from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
+ from cehrgpt.models.hf_cehrgpt import (
+     CEHRGPTConfig,
+     CehrGptForClassification,
+     CEHRGPTPreTrainedModel,
+ )
+ from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
+ from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+ from cehrgpt.runners.gpt_runner_util import parse_runner_args
+ from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
+ from cehrgpt.runners.hyperparameter_search_util import perform_hyperparameter_search
+
+ LOG = logging.get_logger("transformers")
+
+
+ class UpdateNumEpochsBeforeEarlyStoppingCallback(TrainerCallback):
+     """
+     Callback to record the number of epochs completed before early stopping.
+
+     The count is updated whenever the best evaluation metric (e.g., eval_loss) improves.
+     """
+
+     def __init__(self, model_folder: str):
+         self._model_folder = model_folder
+         self._metrics_path = os.path.join(
+             model_folder, "num_epochs_trained_before_early_stopping.json"
+         )
+         self._num_epochs_before_early_stopping = 0
+         self._best_val_loss = float("inf")
+
+     @property
+     def num_epochs_before_early_stopping(self):
+         return self._num_epochs_before_early_stopping
+
+     def on_train_begin(
+         self,
+         args: TrainingArguments,
+         state: TrainerState,
+         control: TrainerControl,
+         **kwargs,
+     ):
+         if os.path.exists(self._metrics_path):
+             with open(self._metrics_path, "r") as f:
+                 metrics = json.load(f)
+             self._num_epochs_before_early_stopping = metrics[
+                 "num_epochs_before_early_stopping"
+             ]
+             self._best_val_loss = metrics["best_val_loss"]
+
+     def on_evaluate(self, args, state, control, **kwargs):
+         # Ensure metrics is available in kwargs
+         metrics = kwargs.get("metrics")
+         if metrics is not None and "eval_loss" in metrics:
+             # Check and update if a new best metric is achieved
+             if metrics["eval_loss"] < self._best_val_loss:
+                 self._num_epochs_before_early_stopping = round(state.epoch)
+                 self._best_val_loss = metrics["eval_loss"]
+
+     def on_save(
+         self,
+         args: TrainingArguments,
+         state: TrainerState,
+         control: TrainerControl,
+         **kwargs,
+     ):
+         with open(self._metrics_path, "w") as f:
+             json.dump(
+                 {
+                     "num_epochs_before_early_stopping": self._num_epochs_before_early_stopping,
+                     "best_val_loss": self._best_val_loss,
+                 },
+                 f,
+             )
+
+
+ def load_pretrained_tokenizer(
+     model_args,
+ ) -> CehrGptTokenizer:
+     try:
+         return CehrGptTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
+     except Exception:
+         raise ValueError(
+             f"Cannot load the pretrained tokenizer from {model_args.tokenizer_name_or_path}"
+         )
+
+
+ def load_finetuned_model(
+     model_args: ModelArguments,
+     training_args: TrainingArguments,
+     model_name_or_path: str,
+ ) -> CEHRGPTPreTrainedModel:
+     if model_args.finetune_model_type == FineTuneModelType.POOLING.value:
+         finetune_model_cls = CehrGptForClassification
+     else:
+         raise ValueError(
+             f"finetune_model_type must be one of the following types: {FineTuneModelType.POOLING.value}"
+         )
+
+     attn_implementation = (
+         "flash_attention_2" if is_flash_attn_2_available() else "eager"
+     )
+     torch_dtype = torch.bfloat16 if training_args.bf16 else torch.float32
+     # Try to create the finetuning model from the base model checkpoint
+     try:
+         return finetune_model_cls.from_pretrained(
+             model_name_or_path,
+             attn_implementation=attn_implementation,
+             torch_dtype=torch_dtype,
+         )
+     except ValueError:
+         raise ValueError(f"Cannot load the finetuned model from {model_name_or_path}")
+
+
+ def create_dataset_splits(data_args: DataTrainingArguments, seed: int):
+     """
+     Creates training, validation, and testing dataset splits based on specified splitting strategies.
+
+     This function splits a dataset into training, validation, and test sets, using either chronological,
+     patient-based, or random splitting strategies, depending on the parameters provided in `data_args`.
+
+     - **Chronological split**: Sorts by a specified date and splits based on historical and future data.
+     - **Patient-based split**: Splits by unique patient IDs to ensure that patients in each split are distinct.
+     - **Random split**: Performs a straightforward random split of the dataset.
+
+     If `data_args.test_data_folder` is provided, a test set is loaded directly from it. Otherwise,
+     the test set is created by further splitting the validation set based on `test_eval_ratio`.
+
+     Parameters:
+         data_args (DataTrainingArguments): A configuration object containing data-related arguments, including:
+             - `data_folder` (str): Path to the main dataset.
+             - `test_data_folder` (str, optional): Path to an optional test dataset.
+             - `chronological_split` (bool): Whether to split chronologically.
+             - `split_by_patient` (bool): Whether to split by unique patient IDs.
+             - `validation_split_percentage` (float): Percentage of data to use for validation.
+             - `test_eval_ratio` (float): Ratio of test to validation data when creating a test set from validation.
+             - `preprocessing_num_workers` (int): Number of processes for parallel data filtering.
+             - `preprocessing_batch_size` (int): Batch size for batched operations.
+         seed (int): Random seed for reproducibility of splits.
+
+     Returns:
+         Tuple[Dataset, Dataset, Dataset]: A tuple containing:
+             - `train_set` (Dataset): Training split of the dataset.
+             - `validation_set` (Dataset): Validation split of the dataset.
+             - `test_set` (Dataset): Test split of the dataset.
+
+     Raises:
+         FileNotFoundError: If `data_args.data_folder` or `data_args.test_data_folder` does not exist.
+         ValueError: If incompatible arguments are passed for splitting strategies.
+
+     Example Usage:
+         data_args = DataTrainingArguments(
+             data_folder="data/",
+             validation_split_percentage=0.1,
+             test_eval_ratio=0.2,
+             chronological_split=True
+         )
+         train_set, validation_set, test_set = create_dataset_splits(data_args, seed=42)
+     """
+     dataset = load_parquet_as_dataset(data_args.data_folder)
+     test_set = (
+         None
+         if not data_args.test_data_folder
+         else load_parquet_as_dataset(data_args.test_data_folder)
+     )
+
+     if data_args.chronological_split:
+         # Chronological split by sorting on `index_date`
+         dataset = dataset.sort("index_date")
+         total_size = len(dataset)
+         train_end = int((1 - data_args.validation_split_percentage) * total_size)
+
+         # Perform the split
+         train_set = dataset.select(range(0, train_end))
+         validation_set = dataset.select(range(train_end, total_size))
+
+         if test_set is None:
+             test_valid_split = validation_set.train_test_split(
+                 test_size=data_args.test_eval_ratio, seed=seed
+             )
+             validation_set, test_set = (
+                 test_valid_split["train"],
+                 test_valid_split["test"],
+             )
+
+     elif data_args.split_by_patient:
+         # Patient-based split
+         LOG.info("Using the split_by_patient strategy")
+         unique_patient_ids = dataset.unique("person_id")
+         LOG.info(f"There are {len(unique_patient_ids)} patients in total")
+
+         np.random.seed(seed)
+         np.random.shuffle(unique_patient_ids)
+
+         train_end = int(
+             len(unique_patient_ids) * (1 - data_args.validation_split_percentage)
+         )
+         train_patient_ids = set(unique_patient_ids[:train_end])
+
+         if test_set is None:
+             validation_end = int(
+                 train_end
+                 + len(unique_patient_ids)
+                 * data_args.validation_split_percentage
+                 * data_args.test_eval_ratio
+             )
+             val_patient_ids = set(unique_patient_ids[train_end:validation_end])
+             test_patient_ids = set(unique_patient_ids[validation_end:])
+         else:
+             val_patient_ids, test_patient_ids = (
+                 set(unique_patient_ids[train_end:]),
+                 None,
+             )
+
+         # Helper function to apply patient-based filtering
+         def filter_by_patient_ids(patient_ids):
+             return dataset.filter(
+                 lambda batch: [pid in patient_ids for pid in batch["person_id"]],
+                 num_proc=data_args.preprocessing_num_workers,
+                 batched=True,
+                 batch_size=data_args.preprocessing_batch_size,
+             )
+
+         # Generate splits
+         train_set = filter_by_patient_ids(train_patient_ids)
+         validation_set = filter_by_patient_ids(val_patient_ids)
+         if test_set is None:
+             test_set = filter_by_patient_ids(test_patient_ids)
+
+     else:
+         # Random split
+         train_val = dataset.train_test_split(
+             test_size=data_args.validation_split_percentage, seed=seed
+         )
+         train_set, validation_set = train_val["train"], train_val["test"]
+
+         if test_set is None:
+             test_valid_split = validation_set.train_test_split(
+                 test_size=data_args.test_eval_ratio, seed=seed
+             )
+             validation_set, test_set = (
+                 test_valid_split["train"],
+                 test_valid_split["test"],
+             )
+
+     return train_set, validation_set, test_set
+
+
+ def model_init(
+     model_args: ModelArguments,
+     training_args: TrainingArguments,
+     tokenizer: CehrGptTokenizer,
+ ):
+     model = load_finetuned_model(
+         model_args, training_args, model_args.model_name_or_path
+     )
+     if model.config.max_position_embeddings < model_args.max_position_embeddings:
+         LOG.info(
+             f"Increase model.config.max_position_embeddings to {model_args.max_position_embeddings}"
+         )
+         model.config.max_position_embeddings = model_args.max_position_embeddings
+         model.resize_position_embeddings(model_args.max_position_embeddings)
+     # Enable include_values if it was disabled during pre-training
+     if model_args.include_values and not model.cehrgpt.include_values:
+         model.cehrgpt.include_values = True
+     # Enable position embeddings if they were disabled during pre-training
+     if not model_args.exclude_position_ids and model.cehrgpt.exclude_position_ids:
+         model.cehrgpt.exclude_position_ids = False
+     # Expand the tokenizer to adapt to the fine-tuning dataset
+     if model.config.vocab_size < tokenizer.vocab_size:
+         model.resize_token_embeddings(tokenizer.vocab_size)
+     # Update the pretrained embedding weights if they are available
+     if model.config.use_pretrained_embeddings:
+         model.cehrgpt.update_pretrained_embeddings(
+             tokenizer.pretrained_token_ids, tokenizer.pretrained_embeddings
+         )
+     elif tokenizer.pretrained_token_ids:
+         model.config.pretrained_embedding_dim = (
+             tokenizer.pretrained_embeddings.shape[1]
+         )
+         model.config.use_pretrained_embeddings = True
+         model.cehrgpt.initialize_pretrained_embeddings()
+         model.cehrgpt.update_pretrained_embeddings(
+             tokenizer.pretrained_token_ids, tokenizer.pretrained_embeddings
+         )
+     # Expand the value tokenizer to adapt to the fine-tuning dataset
+     if model.config.include_values:
+         if model.config.value_vocab_size < tokenizer.value_vocab_size:
+             model.resize_value_embeddings(tokenizer.value_vocab_size)
+     # If LoRA is enabled, add LoRA adapters to the model
+     if model_args.use_lora:
+         # When LoRA is used, the trainer cannot automatically find this label,
+         # so we manually set label_names to "classifier_label" so the model
+         # can compute the loss during evaluation
+         if training_args.label_names:
+             training_args.label_names.append("classifier_label")
+         else:
+             training_args.label_names = ["classifier_label"]
+
+         if model_args.finetune_model_type == FineTuneModelType.POOLING.value:
+             config = LoraConfig(
+                 r=model_args.lora_rank,
+                 lora_alpha=model_args.lora_alpha,
+                 target_modules=model_args.target_modules,
+                 lora_dropout=model_args.lora_dropout,
+                 bias="none",
+                 modules_to_save=["classifier", "age_batch_norm", "dense_layer"],
+             )
+             model = get_peft_model(model, config)
+         else:
+             raise ValueError(
+                 f"The LORA adapter is not supported for {model_args.finetune_model_type}"
+             )
+     return model
+
+
+ def main():
+     cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
+     tokenizer = load_pretrained_tokenizer(model_args)
+     prepared_ds_path = generate_prepared_ds_path(
+         data_args, model_args, data_folder=data_args.cohort_folder
+     )
+
+     processed_dataset = None
+     if any(prepared_ds_path.glob("*")):
+         LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
+         processed_dataset = load_from_disk(str(prepared_ds_path))
+         LOG.info("Prepared dataset loaded from disk...")
+         if cehrgpt_args.expand_tokenizer:
+             try:
+                 tokenizer = CehrGptTokenizer.from_pretrained(training_args.output_dir)
+             except Exception:
+                 LOG.warning(
+                     f"CehrGptTokenizer must exist in {training_args.output_dir} "
+                     f"when the dataset has been processed and expand_tokenizer is set to True. "
+                     f"Please delete the processed dataset at {prepared_ds_path}."
+                 )
+                 processed_dataset = None
+                 shutil.rmtree(prepared_ds_path)
+
+     if processed_dataset is None:
+         # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
+         if data_args.is_data_in_meds:
+             meds_extension_path = get_meds_extension_path(
+                 data_folder=data_args.cohort_folder,
+                 dataset_prepared_path=data_args.dataset_prepared_path,
+             )
+             try:
+                 LOG.info(
+                     f"Trying to load the MEDS extension from disk at {meds_extension_path}..."
+                 )
+                 dataset = load_from_disk(meds_extension_path)
+                 if data_args.streaming:
+                     if isinstance(dataset, DatasetDict):
+                         dataset = {
+                             k: v.to_iterable_dataset(
+                                 num_shards=training_args.dataloader_num_workers
+                             )
+                             for k, v in dataset.items()
+                         }
+                     else:
+                         dataset = dataset.to_iterable_dataset(
+                             num_shards=training_args.dataloader_num_workers
+                         )
+             except Exception as e:
+                 LOG.exception(e)
+                 dataset = create_dataset_from_meds_reader(
+                     data_args, is_pretraining=False
+                 )
+                 if not data_args.streaming:
+                     dataset.save_to_disk(meds_extension_path)
+             train_set = dataset["train"]
+             validation_set = dataset["validation"]
+             test_set = dataset["test"]
+         else:
+             train_set, validation_set, test_set = create_dataset_splits(
+                 data_args=data_args, seed=training_args.seed
+             )
+         # Organize them into a single DatasetDict
+         final_splits = DatasetDict(
+             {"train": train_set, "validation": validation_set, "test": test_set}
+         )
+
+         if cehrgpt_args.expand_tokenizer:
+             new_tokenizer_path = os.path.expanduser(training_args.output_dir)
+             try:
+                 tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
+             except Exception:
+                 # Use the user-provided pretrained embeddings if they exist;
+                 # otherwise fall back to the pretrained embedding model stored in the tokenizer
+                 pretrained_concept_embedding_model = PretrainedEmbeddings(
+                     cehrgpt_args.pretrained_embedding_path
+                 )
+                 if not pretrained_concept_embedding_model.exists:
+                     pretrained_concept_embedding_model = (
+                         tokenizer.pretrained_concept_embedding_model
+                     )
+                 tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
+                     cehrgpt_tokenizer=tokenizer,
+                     dataset=final_splits["train"],
+                     data_args=data_args,
+                     concept_name_mapping={},
+                     pretrained_concept_embedding_model=pretrained_concept_embedding_model,
+                 )
+                 tokenizer.save_pretrained(os.path.expanduser(training_args.output_dir))
+
+         processed_dataset = create_cehrgpt_finetuning_dataset(
+             dataset=final_splits, cehrgpt_tokenizer=tokenizer, data_args=data_args
+         )
+         if not data_args.streaming:
+             processed_dataset.save_to_disk(prepared_ds_path)
+
+     # Set seed before initializing model.
+     set_seed(training_args.seed)
+
+     processed_dataset.set_format("pt")
+
+     if cehrgpt_args.few_shot_predict:
+         # We need at least two examples so that a validation set is available for early stopping
+         num_shots = max(cehrgpt_args.n_shots, 2)
+         random_train_indices = random.sample(
+             range(len(processed_dataset["train"])), k=num_shots
+         )
+         test_size = max(int(num_shots * data_args.validation_split_percentage), 1)
+         few_shot_train_val_set = processed_dataset["train"].select(random_train_indices)
+         train_val = few_shot_train_val_set.train_test_split(
+             test_size=test_size, seed=training_args.seed
+         )
+         few_shot_train_set, few_shot_val_set = train_val["train"], train_val["test"]
+         processed_dataset["train"] = few_shot_train_set
+         processed_dataset["validation"] = few_shot_val_set
+
+     config = CEHRGPTConfig.from_pretrained(model_args.model_name_or_path)
+     if config.max_position_embeddings < model_args.max_position_embeddings:
+         config.max_position_embeddings = model_args.max_position_embeddings
+     # We suppress the additional learning objectives in fine-tuning
+     data_collator = CehrGptDataCollator(
+         tokenizer=tokenizer,
+         max_length=(
+             config.max_position_embeddings - 1
+             if config.causal_sfm
+             else config.max_position_embeddings
+         ),
+         include_values=model_args.include_values,
+         pretraining=False,
+         include_ttv_prediction=False,
+         use_sub_time_tokenization=False,
+         include_demographics=cehrgpt_args.include_demographics,
+     )
+
+     if training_args.do_train:
+         if cehrgpt_args.hyperparameter_tuning:
+             model_args.early_stopping_patience = LARGE_INTEGER
+             training_args = perform_hyperparameter_search(
+                 partial(model_init, model_args, training_args, tokenizer),
+                 processed_dataset,
+                 data_collator,
+                 training_args,
+                 model_args,
+                 cehrgpt_args,
+             )
+             # Always retrain with the full set when hyperparameter tuning is enabled
+             retrain_with_full_set(
+                 model_args, training_args, tokenizer, processed_dataset, data_collator
+             )
+         else:
+             # Initialize the Trainer with early stopping on the validation set
+             trainer = Trainer(
+                 model=model_init(model_args, training_args, tokenizer),
+                 data_collator=data_collator,
+                 args=training_args,
+                 train_dataset=processed_dataset["train"],
+                 eval_dataset=processed_dataset["validation"],
+                 callbacks=[
+                     EarlyStoppingCallback(model_args.early_stopping_patience),
+                     UpdateNumEpochsBeforeEarlyStoppingCallback(
+                         training_args.output_dir
+                     ),
+                 ],
+                 tokenizer=tokenizer,
+             )
+             # Train the model on the train set, evaluating on the validation set
+             checkpoint = get_last_hf_checkpoint(training_args)
+             train_result = trainer.train(resume_from_checkpoint=checkpoint)
+             trainer.save_model()  # Saves the tokenizer too for easy upload
+             metrics = train_result.metrics
+             trainer.log_metrics("train", metrics)
+             trainer.save_metrics("train", metrics)
+             trainer.save_state()
+
+             # Retrain the model on the full set using the number of epochs reached before early stopping
+             if cehrgpt_args.retrain_with_full:
+                 update_num_epoch_before_early_stopping_callback = None
+                 for callback in trainer.callback_handler.callbacks:
+                     if isinstance(callback, UpdateNumEpochsBeforeEarlyStoppingCallback):
+                         update_num_epoch_before_early_stopping_callback = callback
+
+                 if update_num_epoch_before_early_stopping_callback is None:
+                     raise RuntimeError(
+                         f"{UpdateNumEpochsBeforeEarlyStoppingCallback} must be included as a callback!"
+                     )
+                 final_num_epochs = (
+                     update_num_epoch_before_early_stopping_callback.num_epochs_before_early_stopping
+                 )
+                 training_args.num_train_epochs = final_num_epochs
+                 LOG.info(
+                     "Num Epochs before early stopping: %s",
+                     training_args.num_train_epochs,
+                 )
+                 retrain_with_full_set(
+                     model_args,
+                     training_args,
+                     tokenizer,
+                     processed_dataset,
+                     data_collator,
+                 )
+
+     if training_args.do_predict:
+         test_dataloader = DataLoader(
+             dataset=processed_dataset["test"],
+             batch_size=training_args.per_device_eval_batch_size,
+             num_workers=training_args.dataloader_num_workers,
+             collate_fn=data_collator,
+             pin_memory=training_args.dataloader_pin_memory,
+         )
+         do_predict(test_dataloader, model_args, training_args, cehrgpt_args)
+
+
+ def retrain_with_full_set(
+     model_args: ModelArguments,
+     training_args: TrainingArguments,
+     tokenizer: CehrGptTokenizer,
+     dataset: DatasetDict,
+     data_collator: CehrGptDataCollator,
+ ) -> None:
+     """
+     Retrains a model on the combined training and validation data to produce the final model.
+
+     This function consolidates the training and validation datasets into a single
+     dataset for final model training, updates the output directory for the final model,
+     and disables evaluation during training. It resumes from the latest checkpoint if available,
+     trains the model on the combined dataset, and saves the model along with training metrics
+     and state information.
+
+     Args:
+         model_args (ModelArguments): Model configuration and hyperparameters.
+         training_args (TrainingArguments): Training configuration, including output directory,
+             evaluation strategy, and other training parameters.
+         tokenizer (CehrGptTokenizer): Tokenizer instance specific to CEHR-GPT.
+         dataset (DatasetDict): A dictionary containing the 'train' and 'validation' datasets.
+         data_collator (CehrGptDataCollator): Data collator for handling data batching and tokenization.
+
+     Returns:
+         None
+     """
+     # Initialize the Trainer for final training on the combined train+val set
+     full_dataset = concatenate_datasets([dataset["train"], dataset["validation"]])
+     training_args.output_dir = os.path.join(training_args.output_dir, "full")
+     LOG.info(
+         "Final training output_dir: %s",
+         training_args.output_dir,
+     )
+     Path(training_args.output_dir).mkdir(exist_ok=True)
+     # Disable evaluation
+     training_args.evaluation_strategy = "no"
+     checkpoint = get_last_hf_checkpoint(training_args)
+     final_trainer = Trainer(
+         model=model_init(model_args, training_args, tokenizer),
+         data_collator=data_collator,
+         args=training_args,
+         train_dataset=full_dataset,
+         tokenizer=tokenizer,
+     )
+     final_train_result = final_trainer.train(resume_from_checkpoint=checkpoint)
+     final_trainer.save_model()  # Saves the tokenizer too for easy upload
+     metrics = final_train_result.metrics
+     final_trainer.log_metrics("train", metrics)
+     final_trainer.save_metrics("train", metrics)
+     final_trainer.save_state()
+
+
+ def do_predict(
+     test_dataloader: DataLoader,
+     model_args: ModelArguments,
+     training_args: TrainingArguments,
+     cehrgpt_args: CehrGPTArguments,
+ ):
+     """
+     Performs inference on the test dataset using a fine-tuned model and saves predictions and evaluation metrics.
+
+     We use this custom do_predict instead of transformers' trainer.predict() because trainer.predict()
+     leaks memory; for large test sets it eventually throws a CPU OOM error.
+
+     Args:
+         test_dataloader (DataLoader): DataLoader containing the test dataset, with batches of input features and labels.
+         model_args (ModelArguments): Arguments for configuring and loading the fine-tuned model.
+         training_args (TrainingArguments): Arguments related to training, evaluation, and output directories.
+         cehrgpt_args (CehrGPTArguments): CEHR-GPT-specific runner arguments.
+     Returns:
+         None. Results are saved to disk.
+     """
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     # Load the model and LoRA adapters if applicable
+     model = (
+         load_finetuned_model(model_args, training_args, training_args.output_dir)
+         if not model_args.use_lora
+         else load_lora_model(model_args, training_args, cehrgpt_args)
+     )
+
+     model = model.to(device).eval()
+
+     # Ensure the prediction folder exists
+     test_prediction_folder = Path(training_args.output_dir) / "test_predictions"
+     test_prediction_folder.mkdir(parents=True, exist_ok=True)
+
+     LOG.info("Generating predictions for test set at %s", test_prediction_folder)
+
+     test_losses = []
+     with torch.no_grad():
+         for index, batch in enumerate(tqdm(test_dataloader, desc="Predicting")):
+             person_ids = batch.pop("person_id").numpy().squeeze().astype(int)
+             index_dates = (
+                 map(
+                     datetime.fromtimestamp,
+                     batch.pop("index_date").numpy().squeeze(axis=-1).tolist(),
+                 )
+                 if "index_date" in batch
+                 else None
+             )
+             batch = {k: v.to(device) for k, v in batch.items()}
+             # Forward pass
+             output = model(**batch, output_attentions=False, output_hidden_states=False)
+             test_losses.append(output.loss.item())
+
+             # Collect logits and labels for prediction
+             logits = output.logits.float().cpu().numpy().squeeze()
+             labels = (
+                 batch["classifier_label"].float().cpu().numpy().squeeze().astype(bool)
+             )
+             probabilities = sigmoid(logits)
+             # Save predictions to a parquet file
+             test_prediction_pd = pd.DataFrame(
+                 {
+                     "subject_id": person_ids,
+                     "prediction_time": index_dates,
+                     "boolean_prediction_probability": probabilities,
+                     "boolean_prediction": logits,
+                     "boolean_value": labels,
+                 }
+             )
+             test_prediction_pd.to_parquet(test_prediction_folder / f"{index}.parquet")
+
+     LOG.info(
+         "Computing metrics using the test set predictions at %s", test_prediction_folder
+     )
+     # Load all predictions
+     test_prediction_pd = pd.read_parquet(test_prediction_folder)
+     # Compute metrics and save results
+     metrics = compute_metrics(
+         references=test_prediction_pd.boolean_value,
+         probs=test_prediction_pd.boolean_prediction_probability,
+     )
+     metrics["test_loss"] = np.mean(test_losses)
+
+     test_results_path = Path(training_args.output_dir) / "test_results.json"
+     with open(test_results_path, "w") as f:
+         json.dump(metrics, f, indent=4)
+
+     LOG.info("Test results: %s", metrics)
+
+
+ def load_lora_model(
+     model_args: ModelArguments,
+     training_args: TrainingArguments,
+     cehrgpt_args: CehrGPTArguments,
+ ) -> PeftModel:
+     LOG.info("Loading base model from %s", model_args.model_name_or_path)
+     model = load_finetuned_model(
+         model_args, training_args, model_args.model_name_or_path
+     )
+     # Enable include_values if it was disabled during pre-training
+     if model_args.include_values and not model.cehrgpt.include_values:
+         model.cehrgpt.include_values = True
+     # Enable position embeddings if they were disabled during pre-training
+     if not model_args.exclude_position_ids and model.cehrgpt.exclude_position_ids:
+         model.cehrgpt.exclude_position_ids = False
+     if cehrgpt_args.expand_tokenizer:
+         tokenizer = CehrGptTokenizer.from_pretrained(training_args.output_dir)
+         # Expand the tokenizer to adapt to the fine-tuning dataset
+         if model.config.vocab_size < tokenizer.vocab_size:
+             model.resize_token_embeddings(tokenizer.vocab_size)
+         if (
+             model.config.include_values
+             and model.config.value_vocab_size < tokenizer.value_vocab_size
+         ):
+             model.resize_value_embeddings(tokenizer.value_vocab_size)
+     LOG.info("Loading LoRA adapter from %s", training_args.output_dir)
+     return PeftModel.from_pretrained(model, model_id=training_args.output_dir)
+
+
+ if __name__ == "__main__":
+     main()
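
For orientation, here is a minimal sketch of how this fine-tuning runner might be invoked once the wheel is installed. It assumes that parse_runner_args() follows the HfArgumentParser-style command-line parsing used by the cehrbert runners; the flag names mirror dataclass fields referenced in the module above (model_name_or_path, tokenizer_name_or_path, data_folder, dataset_prepared_path, output_dir, do_train, do_predict), but the exact set of required options is not documented in this diff, so treat the paths and flags as illustrative placeholders rather than the package's documented CLI.

import subprocess
import sys

# Illustrative only: drive the runner as a module, assuming HfArgumentParser-style flags.
# The paths below are placeholders; consult DataTrainingArguments, ModelArguments,
# TrainingArguments, and CehrGPTArguments for the authoritative option names.
subprocess.run(
    [
        sys.executable,
        "-m",
        "cehrgpt.runners.hf_cehrgpt_finetune_runner",
        "--model_name_or_path", "path/to/pretrained_cehrgpt",
        "--tokenizer_name_or_path", "path/to/pretrained_cehrgpt",
        "--data_folder", "path/to/finetuning_cohort",
        "--dataset_prepared_path", "path/to/dataset_prepared",
        "--output_dir", "path/to/output",
        "--do_train",
        "--do_predict",
    ],
    check=True,
)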