cehrgpt-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- __init__.py +0 -0
- cehrgpt/__init__.py +0 -0
- cehrgpt/analysis/__init__.py +0 -0
- cehrgpt/analysis/privacy/__init__.py +0 -0
- cehrgpt/analysis/privacy/attribute_inference.py +275 -0
- cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
- cehrgpt/analysis/privacy/member_inference.py +172 -0
- cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
- cehrgpt/analysis/privacy/reid_inference.py +407 -0
- cehrgpt/analysis/privacy/utils.py +255 -0
- cehrgpt/cehrgpt_args.py +142 -0
- cehrgpt/data/__init__.py +0 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
- cehrgpt/generation/__init__.py +0 -0
- cehrgpt/generation/chatgpt_generation.py +106 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
- cehrgpt/generation/omop_converter_batch.py +644 -0
- cehrgpt/generation/omop_entity.py +515 -0
- cehrgpt/gpt_utils.py +331 -0
- cehrgpt/models/__init__.py +0 -0
- cehrgpt/models/config.py +205 -0
- cehrgpt/models/hf_cehrgpt.py +1817 -0
- cehrgpt/models/hf_modeling_outputs.py +158 -0
- cehrgpt/models/pretrained_embeddings.py +82 -0
- cehrgpt/models/special_tokens.py +30 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
- cehrgpt/omop/__init__.py +0 -0
- cehrgpt/omop/condition_era.py +20 -0
- cehrgpt/omop/observation_period.py +43 -0
- cehrgpt/omop/omop_argparse.py +38 -0
- cehrgpt/omop/omop_table_builder.py +86 -0
- cehrgpt/omop/queries/__init__.py +0 -0
- cehrgpt/omop/queries/condition_era.py +86 -0
- cehrgpt/omop/queries/observation_period.py +135 -0
- cehrgpt/omop/sample_omop_tables.py +71 -0
- cehrgpt/runners/__init__.py +0 -0
- cehrgpt/runners/gpt_runner_util.py +99 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
- cehrgpt/runners/hyperparameter_search_util.py +223 -0
- cehrgpt/time_to_event/__init__.py +0 -0
- cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
- cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
- cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
- cehrgpt/time_to_event/time_to_event_model.py +226 -0
- cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
- cehrgpt/time_to_event/time_to_event_utils.py +55 -0
- cehrgpt/tools/__init__.py +0 -0
- cehrgpt/tools/ehrshot_benchmark.py +74 -0
- cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
- cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
- cehrgpt/tools/upload_omop_tables.py +108 -0
- cehrgpt-0.0.1.dist-info/LICENSE +21 -0
- cehrgpt-0.0.1.dist-info/METADATA +66 -0
- cehrgpt-0.0.1.dist-info/RECORD +60 -0
- cehrgpt-0.0.1.dist-info/WHEEL +5 -0
- cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/runners/hf_cehrgpt_pretrain_runner.py

@@ -0,0 +1,370 @@

```python
import os
from typing import Optional, Union

import torch
from cehrbert.data_generators.hf_data_generator.meds_utils import (
    create_dataset_from_meds_reader,
)
from cehrbert.runners.hf_runner_argument_dataclass import (
    DataTrainingArguments,
    ModelArguments,
)
from cehrbert.runners.runner_util import (
    generate_prepared_ds_path,
    get_last_hf_checkpoint,
    get_meds_extension_path,
    load_parquet_as_dataset,
)
from datasets import Dataset, DatasetDict, IterableDatasetDict, load_from_disk
from transformers import AutoConfig, Trainer, TrainingArguments, set_seed
from transformers.utils import is_flash_attn_2_available, logging

from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_pretraining_dataset
from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
from cehrgpt.models.config import CEHRGPTConfig
from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
from cehrgpt.runners.gpt_runner_util import parse_runner_args
from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments

LOG = logging.get_logger("transformers")


def tokenizer_exists(tokenizer_name_or_path: str) -> bool:
    # Try to load the pretrained tokenizer
    try:
        CehrGptTokenizer.from_pretrained(os.path.abspath(tokenizer_name_or_path))
        return True
    except Exception:
        LOG.info(f"The tokenizer does not exist at {tokenizer_name_or_path}")
        return False


def load_and_create_tokenizer(
    data_args: DataTrainingArguments,
    model_args: ModelArguments,
    cehrgpt_args: CehrGPTArguments,
    dataset: Optional[Union[Dataset, DatasetDict]] = None,
) -> CehrGptTokenizer:
    # Try to load the pretrained tokenizer
    tokenizer_abspath = os.path.expanduser(model_args.tokenizer_name_or_path)
    try:
        tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_abspath)
    except Exception as e:
        LOG.warning(e)
        if dataset is None:
            raise RuntimeError(
                f"Failed to load the tokenizer from {tokenizer_abspath} with the error \n{e}\n"
                f"Tried to create the tokenizer, but no dataset was provided."
            )
        tokenizer = CehrGptTokenizer.train_tokenizer(
            dataset,
            {},
            data_args,
            PretrainedEmbeddings(cehrgpt_args.pretrained_embedding_path),
        )
        tokenizer.save_pretrained(tokenizer_abspath)

    return tokenizer


def load_and_create_model(
    model_args: ModelArguments,
    cehrgpt_args: CehrGPTArguments,
    training_args: TrainingArguments,
    tokenizer: CehrGptTokenizer,
) -> CEHRGPT2LMHeadModel:
    attn_implementation = (
        "flash_attention_2" if is_flash_attn_2_available() else "eager"
    )
    torch_dtype = torch.bfloat16 if training_args.bf16 else torch.float32
    model_abspath = os.path.expanduser(model_args.model_name_or_path)
    if cehrgpt_args.continue_pretrain:
        try:
            return CEHRGPT2LMHeadModel.from_pretrained(
                model_abspath,
                attn_implementation=attn_implementation,
                torch_dtype=torch_dtype,
            )
        except Exception as e:
            LOG.error(
                f"When continue_pretrain is set to True, it assumes that CEHR-GPT has been trained "
                f"and will be used to pretrain on new datasets. The CEHR-GPT checkpoint must exist at {model_abspath}"
            )
            raise e
    try:
        model_config = AutoConfig.from_pretrained(
            model_abspath, attn_implementation=attn_implementation
        )
    except Exception as e:
        LOG.warning(e)
        if cehrgpt_args.causal_sfm:
            model_args.max_position_embeddings += 1
        if len(tokenizer.pretrained_token_ids) > 0:
            pretrained_embedding_dim = tokenizer.pretrained_embeddings.shape[1]
        else:
            pretrained_embedding_dim = model_args.hidden_size
        model_config = CEHRGPTConfig(
            vocab_size=tokenizer.vocab_size,
            value_vocab_size=tokenizer.value_vocab_size,
            time_token_vocab_size=tokenizer.time_token_vocab_size,
            bos_token_id=tokenizer.end_token_id,
            eos_token_id=tokenizer.end_token_id,
            lab_token_ids=tokenizer.lab_token_ids,
            token_to_time_token_mapping=tokenizer.token_to_time_token_mapping,
            attn_implementation=attn_implementation,
            causal_sfm=cehrgpt_args.causal_sfm,
            demographics_size=cehrgpt_args.demographics_size,
            lab_token_penalty=cehrgpt_args.lab_token_penalty,
            lab_token_loss_weight=cehrgpt_args.lab_token_loss_weight,
            entropy_penalty=cehrgpt_args.entropy_penalty,
            entropy_penalty_alpha=cehrgpt_args.entropy_penalty_alpha,
            n_pretrained_embeddings_layers=cehrgpt_args.n_pretrained_embeddings_layers,
            use_pretrained_embeddings=len(tokenizer.pretrained_token_ids) > 0,
            pretrained_embedding_dim=pretrained_embedding_dim,
            **model_args.as_dict(),
        )
    model = CEHRGPT2LMHeadModel(model_config)
    if tokenizer.pretrained_token_ids:
        model.cehrgpt.update_pretrained_embeddings(
            tokenizer.pretrained_token_ids,
            tokenizer.pretrained_embeddings,
        )
    if model.config.torch_dtype == torch.bfloat16:
        return model.bfloat16()
    elif model.config.torch_dtype == torch.float16:
        return model.half()
    return model


def main():
    cehrgpt_args, data_args, model_args, training_args = parse_runner_args()

    if data_args.streaming:
        # This disables the warning message https://github.com/huggingface/transformers/issues/5486,
        # which only appears when streaming is enabled
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        # The iterable dataset doesn't have sharding implemented, so the number of workers has to be set to 0.
        # Otherwise the trainer will throw an error
        training_args.dataloader_num_workers = 0
        training_args.dataloader_prefetch_factor = 0

    prepared_ds_path = generate_prepared_ds_path(data_args, model_args)
    if os.path.exists(os.path.join(data_args.data_folder, "dataset_dict.json")):
        LOG.info(f"Loading prepared dataset from disk at {data_args.data_folder}...")
        processed_dataset = load_from_disk(data_args.data_folder)
        # If the data has been processed in the past, it is assumed the tokenizer was created before, so
        # we load the CEHR-GPT tokenizer from the output folder; otherwise an exception is raised.
        tokenizer_name_or_path = os.path.expanduser(
            training_args.output_dir
            if cehrgpt_args.expand_tokenizer
            else model_args.tokenizer_name_or_path
        )
        if not tokenizer_exists(tokenizer_name_or_path):
            raise RuntimeError(
                f"The dataset has been tokenized but the corresponding tokenizer: "
                f"{model_args.tokenizer_name_or_path} does not exist"
            )
        cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
    elif any(prepared_ds_path.glob("*")):
        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
        processed_dataset = load_from_disk(str(prepared_ds_path))
        LOG.info("Prepared dataset loaded from disk...")
        # If the data has been processed in the past, it is assumed the tokenizer was created before, so
        # we load the CEHR-GPT tokenizer from the output folder; otherwise an exception is raised.
        tokenizer_name_or_path = os.path.expanduser(
            training_args.output_dir
            if cehrgpt_args.expand_tokenizer
            else model_args.tokenizer_name_or_path
        )
        if not tokenizer_exists(tokenizer_name_or_path):
            raise RuntimeError(
                f"The dataset has been tokenized but the corresponding tokenizer: "
                f"{model_args.tokenizer_name_or_path} does not exist"
            )
        cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
    else:
        # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
        if data_args.is_data_in_meds:
            meds_extension_path = get_meds_extension_path(
                data_folder=data_args.data_folder,
                dataset_prepared_path=data_args.dataset_prepared_path,
            )
            try:
                LOG.info(
                    "Trying to load the MEDS extension from disk at %s...",
                    meds_extension_path,
                )
                dataset = load_from_disk(meds_extension_path)
                if data_args.streaming:
                    if isinstance(dataset, DatasetDict):
                        dataset = {
                            k: v.to_iterable_dataset(
                                num_shards=training_args.dataloader_num_workers
                            )
                            for k, v in dataset.items()
                        }
                    else:
                        dataset = dataset.to_iterable_dataset(
                            num_shards=training_args.dataloader_num_workers
                        )
            except FileNotFoundError as e:
                LOG.exception(e)
                dataset = create_dataset_from_meds_reader(
                    data_args, is_pretraining=True
                )
                if not data_args.streaming:
                    dataset.save_to_disk(meds_extension_path)
        else:
            # Load the dataset from the parquet files
            dataset = load_parquet_as_dataset(
                data_args.data_folder, split="train", streaming=data_args.streaming
            )
            # If streaming is enabled, we need to manually split the data into train/val
            if data_args.streaming and data_args.validation_split_num:
                dataset = dataset.shuffle(buffer_size=10_000, seed=training_args.seed)
                train_set = dataset.skip(data_args.validation_split_num)
                val_set = dataset.take(data_args.validation_split_num)
                dataset = DatasetDict({"train": train_set, "test": val_set})
            elif data_args.validation_split_percentage:
                dataset = dataset.train_test_split(
                    test_size=data_args.validation_split_percentage,
                    seed=training_args.seed,
                )
            else:
                raise RuntimeError(
                    f"Cannot split the data. If streaming is enabled, validation_split_num needs to be "
                    f"defined, otherwise validation_split_percentage needs to be provided. "
                    f"The current values are:\n"
                    f"validation_split_percentage: {data_args.validation_split_percentage}\n"
                    f"validation_split_num: {data_args.validation_split_num}\n"
                    f"streaming: {data_args.streaming}"
                )

        # Create the CEHR-GPT tokenizer if it's not available in the output folder
        cehrgpt_tokenizer = load_and_create_tokenizer(
            data_args=data_args,
            model_args=model_args,
            cehrgpt_args=cehrgpt_args,
            dataset=dataset,
        )
        # Retrain the tokenizer in case we want to pretrain the model further using different datasets
        if cehrgpt_args.expand_tokenizer:
            new_tokenizer_path = os.path.expanduser(training_args.output_dir)
            try:
                cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
            except Exception:
                cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
                    cehrgpt_tokenizer=cehrgpt_tokenizer,
                    dataset=dataset["train"],
                    data_args=data_args,
                    concept_name_mapping={},
                    pretrained_concept_embedding_model=PretrainedEmbeddings(
                        cehrgpt_args.pretrained_embedding_path
                    ),
                )
                cehrgpt_tokenizer.save_pretrained(
                    os.path.expanduser(training_args.output_dir)
                )

        # Sort the patient features chronologically and tokenize the data
        processed_dataset = create_cehrgpt_pretraining_dataset(
            dataset=dataset, cehrgpt_tokenizer=cehrgpt_tokenizer, data_args=data_args
        )
        # Only save the data to disk if it is not streaming
        if not data_args.streaming:
            processed_dataset.save_to_disk(prepared_ds_path)

    def filter_func(examples):
        if cehrgpt_args.drop_long_sequences:
            return [
                model_args.max_position_embeddings >= _ >= data_args.min_num_tokens
                for _ in examples["num_of_concepts"]
            ]
        else:
            return [_ >= data_args.min_num_tokens for _ in examples["num_of_concepts"]]

    # Create the args for batched filtering
    filter_args = {"batched": True, "batch_size": data_args.preprocessing_batch_size}
    # If the dataset is not in streaming mode, we can add num_proc to enable parallelization
    if not data_args.streaming:
        filter_args["num_proc"] = data_args.preprocessing_num_workers

    # The filter can't be applied to a DatasetDict of IterableDataset (in case of streaming), so
    # we need to iterate through all the datasets and apply the filter separately
    if isinstance(processed_dataset, (DatasetDict, IterableDatasetDict)):
        for key in processed_dataset.keys():
            processed_dataset[key] = processed_dataset[key].filter(
                filter_func, **filter_args
            )
    else:
        processed_dataset = processed_dataset.filter(filter_func, **filter_args)

    model = load_and_create_model(
        model_args, cehrgpt_args, training_args, cehrgpt_tokenizer
    )

    # Expand the model's token embeddings to adapt to the new pretraining dataset
    if model.config.vocab_size < cehrgpt_tokenizer.vocab_size:
        model.resize_token_embeddings(cehrgpt_tokenizer.vocab_size)
        # Update the pretrained embedding weights if they are available
        if model.config.use_pretrained_embeddings:
            model.cehrgpt.update_pretrained_embeddings(
                cehrgpt_tokenizer.pretrained_token_ids,
                cehrgpt_tokenizer.pretrained_embeddings,
            )
        elif cehrgpt_tokenizer.pretrained_token_ids:
            model.config.pretrained_embedding_dim = (
                cehrgpt_tokenizer.pretrained_embeddings.shape[1]
            )
            model.config.use_pretrained_embeddings = True
            model.cehrgpt.initialize_pretrained_embeddings()
            model.cehrgpt.update_pretrained_embeddings(
                cehrgpt_tokenizer.pretrained_token_ids,
                cehrgpt_tokenizer.pretrained_embeddings,
            )

    # Detecting last checkpoint.
    last_checkpoint = get_last_hf_checkpoint(training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed)

    if not data_args.streaming:
        processed_dataset.set_format("pt")

    trainer = Trainer(
        model=model,
        data_collator=CehrGptDataCollator(
            tokenizer=cehrgpt_tokenizer,
            max_length=model_args.max_position_embeddings,
            shuffle_records=data_args.shuffle_records,
            include_ttv_prediction=model_args.include_ttv_prediction,
            use_sub_time_tokenization=model_args.use_sub_time_tokenization,
            include_values=model_args.include_values,
        ),
        train_dataset=processed_dataset["train"],
        eval_dataset=processed_dataset["test"],
        args=training_args,
    )

    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    trainer.save_model()  # Saves the tokenizer too for easy upload
    metrics = train_result.metrics

    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()


if __name__ == "__main__":
    main()
```
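When streaming is enabled, `main()` above cannot use `train_test_split`, so it carves a validation split out of the iterable dataset with `shuffle`/`skip`/`take`. The following is a minimal, self-contained sketch of that pattern using a toy in-memory dataset in place of the package's parquet loader; the variable names (`validation_split_num`, `seed`) stand in for the corresponding `data_args`/`training_args` fields and the sketch is illustrative, not part of the wheel.

```python
# Minimal sketch of the streaming train/val split used in the pretrain runner.
# A toy in-memory dataset replaces load_parquet_as_dataset; the shuffle/skip/take
# calls are the same Hugging Face `datasets` APIs the runner relies on.
from datasets import Dataset, IterableDatasetDict

validation_split_num = 20  # analogous to data_args.validation_split_num
seed = 42                  # analogous to training_args.seed

# Toy records standing in for tokenized patient sequences.
toy = Dataset.from_dict({"num_of_concepts": list(range(100))})
stream = toy.to_iterable_dataset(num_shards=1)

# Approximate shuffle with a fixed-size buffer, then split by position:
# the first `validation_split_num` examples become the validation set,
# everything after them becomes the training set.
stream = stream.shuffle(buffer_size=10_000, seed=seed)
train_set = stream.skip(validation_split_num)
val_set = stream.take(validation_split_num)

splits = IterableDatasetDict({"train": train_set, "test": val_set})
print(sum(1 for _ in splits["test"]))  # -> 20
```

The split is positional rather than random per-record, which is why the runner shuffles the stream with a buffer before calling `skip`/`take`.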
cehrgpt/runners/hf_gpt_runner_argument_dataclass.py

@@ -0,0 +1,137 @@

```python
import dataclasses
from typing import List, Optional


@dataclasses.dataclass
class CehrGPTArguments:
    """Arguments pertaining to what data we are going to input into our model for training and eval."""

    include_demographics: Optional[bool] = dataclasses.field(
        default=False,
        metadata={
            "help": "A flag to indicate whether we always want to include the demographics for sequences that are longer than the model context window."
        },
    )
    continue_pretrain: Optional[bool] = dataclasses.field(
        default=False,
        metadata={
            "help": "A flag to indicate whether we want to continue to pretrain cehrgpt on the new dataset"
        },
    )
    pretrained_embedding_path: Optional[str] = dataclasses.field(
        default=None,
        metadata={"help": "The path to the concept pretrained embeddings"},
    )
    retrain_with_full: Optional[bool] = dataclasses.field(
        default=False,
        metadata={
            "help": "A flag to indicate whether we want to retrain the model on the full set after early stopping"
        },
    )
    expand_tokenizer: Optional[bool] = dataclasses.field(
        default=False,
        metadata={
            "help": "A flag to indicate whether we want to expand the tokenizer for fine-tuning."
        },
    )
    few_shot_predict: Optional[bool] = dataclasses.field(
        default=False,
        metadata={
            "help": "A flag to indicate whether we want to use a few shots to train the model"
        },
    )
    n_shots: Optional[int] = dataclasses.field(
        default=128,
        metadata={"help": "The number of examples from the training set."},
    )
    hyperparameter_tuning_percentage: Optional[float] = dataclasses.field(
        default=0.1,
        metadata={
            "help": "The percentage of the train/val set that will be used for hyperparameter tuning."
        },
    )
    n_trials: Optional[int] = dataclasses.field(
        default=10,
        metadata={
            "help": "The number of trials that will be used for hyperparameter tuning."
        },
    )
    hyperparameter_tuning: Optional[bool] = dataclasses.field(
        default=False,
        metadata={"help": "A flag to indicate if we want to do hyperparameter tuning."},
    )
    hyperparameter_batch_sizes: Optional[List[int]] = dataclasses.field(
        default_factory=lambda: [4, 8, 16],
        metadata={"help": "Hyperparameter search batch sizes"},
    )
    hyperparameter_num_train_epochs: Optional[List[int]] = dataclasses.field(
        default_factory=lambda: [10],
        metadata={"help": "Hyperparameter search num_train_epochs"},
    )
    lr_low: Optional[float] = dataclasses.field(
        default=1e-5,
        metadata={
            "help": "The lower bound of the learning rate range for hyperparameter tuning."
        },
    )
    lr_high: Optional[float] = dataclasses.field(
        default=5e-5,
        metadata={
            "help": "The upper bound of the learning rate range for hyperparameter tuning."
        },
    )
    weight_decays_low: Optional[float] = dataclasses.field(
        default=1e-3,
        metadata={
            "help": "The lower bound of the weight decay range for hyperparameter tuning."
        },
    )
    weight_decays_high: Optional[float] = dataclasses.field(
        default=1e-2,
        metadata={
            "help": "The upper bound of the weight decay range for hyperparameter tuning."
        },
    )
    causal_sfm: Optional[bool] = dataclasses.field(
        default=False,
        metadata={
            "help": "A flag to indicate whether the GPT conforms to the causal Standard Fairness Model"
        },
    )
    demographics_size: Optional[int] = dataclasses.field(
        default=4,
        metadata={
            "help": "The number of demographics tokens in the patient sequence. "
            "It defaults to 4, assuming the demographics tokens follow this pattern [Year][Age][Gender][Race]"
        },
    )
    drop_long_sequences: Optional[bool] = dataclasses.field(
        default=False,
        metadata={
            "help": "A flag to indicate whether we want to drop sequences that are longer than the model context window (max_position_embeddings)."
        },
    )
    lab_token_penalty: Optional[bool] = dataclasses.field(
        default=False,
        metadata={
            "help": "A flag to indicate whether we want to use lab token loss penalty."
        },
    )
    lab_token_loss_weight: Optional[float] = dataclasses.field(
        default=1.0,
        metadata={"help": "Lab token loss penalty coefficient"},
    )
    entropy_penalty: Optional[bool] = dataclasses.field(
        default=False,
        metadata={"help": "A flag to indicate whether we want to use entropy penalty."},
    )
    entropy_penalty_alpha: Optional[float] = dataclasses.field(
        default=0.01,
        metadata={"help": "Entropy penalty coefficient"},
    )
    n_pretrained_embeddings_layers: Optional[int] = dataclasses.field(
        default=2,
        metadata={
            "help": "The number of feed forward layers for transforming pretrained embeddings to internal embeddings"
        },
    )
```