cehrgpt 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. __init__.py +0 -0
  2. cehrgpt/__init__.py +0 -0
  3. cehrgpt/analysis/__init__.py +0 -0
  4. cehrgpt/analysis/privacy/__init__.py +0 -0
  5. cehrgpt/analysis/privacy/attribute_inference.py +275 -0
  6. cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
  7. cehrgpt/analysis/privacy/member_inference.py +172 -0
  8. cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
  9. cehrgpt/analysis/privacy/reid_inference.py +407 -0
  10. cehrgpt/analysis/privacy/utils.py +255 -0
  11. cehrgpt/cehrgpt_args.py +142 -0
  12. cehrgpt/data/__init__.py +0 -0
  13. cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
  14. cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
  15. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
  16. cehrgpt/generation/__init__.py +0 -0
  17. cehrgpt/generation/chatgpt_generation.py +106 -0
  18. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
  19. cehrgpt/generation/omop_converter_batch.py +644 -0
  20. cehrgpt/generation/omop_entity.py +515 -0
  21. cehrgpt/gpt_utils.py +331 -0
  22. cehrgpt/models/__init__.py +0 -0
  23. cehrgpt/models/config.py +205 -0
  24. cehrgpt/models/hf_cehrgpt.py +1817 -0
  25. cehrgpt/models/hf_modeling_outputs.py +158 -0
  26. cehrgpt/models/pretrained_embeddings.py +82 -0
  27. cehrgpt/models/special_tokens.py +30 -0
  28. cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
  29. cehrgpt/omop/__init__.py +0 -0
  30. cehrgpt/omop/condition_era.py +20 -0
  31. cehrgpt/omop/observation_period.py +43 -0
  32. cehrgpt/omop/omop_argparse.py +38 -0
  33. cehrgpt/omop/omop_table_builder.py +86 -0
  34. cehrgpt/omop/queries/__init__.py +0 -0
  35. cehrgpt/omop/queries/condition_era.py +86 -0
  36. cehrgpt/omop/queries/observation_period.py +135 -0
  37. cehrgpt/omop/sample_omop_tables.py +71 -0
  38. cehrgpt/runners/__init__.py +0 -0
  39. cehrgpt/runners/gpt_runner_util.py +99 -0
  40. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
  41. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
  42. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
  43. cehrgpt/runners/hyperparameter_search_util.py +223 -0
  44. cehrgpt/time_to_event/__init__.py +0 -0
  45. cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
  46. cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
  47. cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
  48. cehrgpt/time_to_event/time_to_event_model.py +226 -0
  49. cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
  50. cehrgpt/time_to_event/time_to_event_utils.py +55 -0
  51. cehrgpt/tools/__init__.py +0 -0
  52. cehrgpt/tools/ehrshot_benchmark.py +74 -0
  53. cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
  54. cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
  55. cehrgpt/tools/upload_omop_tables.py +108 -0
  56. cehrgpt-0.0.1.dist-info/LICENSE +21 -0
  57. cehrgpt-0.0.1.dist-info/METADATA +66 -0
  58. cehrgpt-0.0.1.dist-info/RECORD +60 -0
  59. cehrgpt-0.0.1.dist-info/WHEEL +5 -0
  60. cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/runners/hf_cehrgpt_pretrain_runner.py
@@ -0,0 +1,370 @@
+ import os
+ from typing import Optional, Union
+
+ import torch
+ from cehrbert.data_generators.hf_data_generator.meds_utils import (
+     create_dataset_from_meds_reader,
+ )
+ from cehrbert.runners.hf_runner_argument_dataclass import (
+     DataTrainingArguments,
+     ModelArguments,
+ )
+ from cehrbert.runners.runner_util import (
+     generate_prepared_ds_path,
+     get_last_hf_checkpoint,
+     get_meds_extension_path,
+     load_parquet_as_dataset,
+ )
+ from datasets import Dataset, DatasetDict, IterableDatasetDict, load_from_disk
+ from transformers import AutoConfig, Trainer, TrainingArguments, set_seed
+ from transformers.utils import is_flash_attn_2_available, logging
+
+ from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_pretraining_dataset
+ from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
+ from cehrgpt.models.config import CEHRGPTConfig
+ from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
+ from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
+ from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+ from cehrgpt.runners.gpt_runner_util import parse_runner_args
+ from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
+
+ LOG = logging.get_logger("transformers")
+
+
+ def tokenizer_exists(tokenizer_name_or_path: str) -> bool:
+     # Try to load the pretrained tokenizer
+     try:
+         CehrGptTokenizer.from_pretrained(os.path.abspath(tokenizer_name_or_path))
+         return True
+     except Exception:
+         LOG.info(f"The tokenizer does not exist at {tokenizer_name_or_path}")
+         return False
+
+
+ def load_and_create_tokenizer(
+     data_args: DataTrainingArguments,
+     model_args: ModelArguments,
+     cehrgpt_args: CehrGPTArguments,
+     dataset: Optional[Union[Dataset, DatasetDict]] = None,
+ ) -> CehrGptTokenizer:
+     # Try to load the pretrained tokenizer
+     tokenizer_abspath = os.path.expanduser(model_args.tokenizer_name_or_path)
+     try:
+         tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_abspath)
+     except Exception as e:
+         LOG.warning(e)
+         if dataset is None:
+             raise RuntimeError(
+                 f"Failed to load the tokenizer from {tokenizer_abspath} with the error \n{e}\n"
+                 f"Tried to create a new tokenizer, but no dataset was provided."
+             )
+         tokenizer = CehrGptTokenizer.train_tokenizer(
+             dataset,
+             {},
+             data_args,
+             PretrainedEmbeddings(cehrgpt_args.pretrained_embedding_path),
+         )
+         tokenizer.save_pretrained(tokenizer_abspath)
+
+     return tokenizer
+
+
+ def load_and_create_model(
+     model_args: ModelArguments,
+     cehrgpt_args: CehrGPTArguments,
+     training_args: TrainingArguments,
+     tokenizer: CehrGptTokenizer,
+ ) -> CEHRGPT2LMHeadModel:
+     attn_implementation = (
+         "flash_attention_2" if is_flash_attn_2_available() else "eager"
+     )
+     torch_dtype = torch.bfloat16 if training_args.bf16 else torch.float32
+     model_abspath = os.path.expanduser(model_args.model_name_or_path)
+     if cehrgpt_args.continue_pretrain:
+         try:
+             return CEHRGPT2LMHeadModel.from_pretrained(
+                 model_abspath,
+                 attn_implementation=attn_implementation,
+                 torch_dtype=torch_dtype,
+             )
+         except Exception as e:
+             LOG.error(
+                 f"When continue_pretrain is set to True, it assumes that CEHR-GPT has been trained "
+                 f"and will be used to pretrain on new datasets. The CEHR-GPT checkpoint must exist at {model_abspath}"
+             )
+             raise e
+     try:
+         model_config = AutoConfig.from_pretrained(
+             model_abspath, attn_implementation=attn_implementation
+         )
+     except Exception as e:
+         LOG.warning(e)
+         if cehrgpt_args.causal_sfm:
+             model_args.max_position_embeddings += 1
+         if len(tokenizer.pretrained_token_ids) > 0:
+             pretrained_embedding_dim = tokenizer.pretrained_embeddings.shape[1]
+         else:
+             pretrained_embedding_dim = model_args.hidden_size
+         model_config = CEHRGPTConfig(
+             vocab_size=tokenizer.vocab_size,
+             value_vocab_size=tokenizer.value_vocab_size,
+             time_token_vocab_size=tokenizer.time_token_vocab_size,
+             bos_token_id=tokenizer.end_token_id,
+             eos_token_id=tokenizer.end_token_id,
+             lab_token_ids=tokenizer.lab_token_ids,
+             token_to_time_token_mapping=tokenizer.token_to_time_token_mapping,
+             attn_implementation=attn_implementation,
+             causal_sfm=cehrgpt_args.causal_sfm,
+             demographics_size=cehrgpt_args.demographics_size,
+             lab_token_penalty=cehrgpt_args.lab_token_penalty,
+             lab_token_loss_weight=cehrgpt_args.lab_token_loss_weight,
+             entropy_penalty=cehrgpt_args.entropy_penalty,
+             entropy_penalty_alpha=cehrgpt_args.entropy_penalty_alpha,
+             n_pretrained_embeddings_layers=cehrgpt_args.n_pretrained_embeddings_layers,
+             use_pretrained_embeddings=len(tokenizer.pretrained_token_ids) > 0,
+             pretrained_embedding_dim=pretrained_embedding_dim,
+             **model_args.as_dict(),
+         )
+     model = CEHRGPT2LMHeadModel(model_config)
+     if tokenizer.pretrained_token_ids:
+         model.cehrgpt.update_pretrained_embeddings(
+             tokenizer.pretrained_token_ids,
+             tokenizer.pretrained_embeddings,
+         )
+     if model.config.torch_dtype == torch.bfloat16:
+         return model.bfloat16()
+     elif model.config.torch_dtype == torch.float16:
+         return model.half()
+     return model
+
+
+ def main():
+     cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
+
+     if data_args.streaming:
+         # This is for disabling the warning message https://github.com/huggingface/transformers/issues/5486
+         # This happens only when streaming is enabled
+         os.environ["TOKENIZERS_PARALLELISM"] = "false"
+         # The iterable dataset doesn't have sharding implemented, so the number of workers has to be set to 0.
+         # Otherwise the trainer will throw an error
+         training_args.dataloader_num_workers = 0
+         training_args.dataloader_prefetch_factor = 0
+
+     prepared_ds_path = generate_prepared_ds_path(data_args, model_args)
+     if os.path.exists(os.path.join(data_args.data_folder, "dataset_dict.json")):
+         LOG.info(f"Loading prepared dataset from disk at {data_args.data_folder}...")
+         processed_dataset = load_from_disk(data_args.data_folder)
+         # If the data has been processed in the past, it is assumed the tokenizer has been created before;
+         # we load the CEHR-GPT tokenizer from the output folder, otherwise an exception will be raised.
+         tokenizer_name_or_path = os.path.expanduser(
+             training_args.output_dir
+             if cehrgpt_args.expand_tokenizer
+             else model_args.tokenizer_name_or_path
+         )
+         if not tokenizer_exists(tokenizer_name_or_path):
+             raise RuntimeError(
+                 f"The dataset has been tokenized but the corresponding tokenizer: "
+                 f"{model_args.tokenizer_name_or_path} does not exist"
+             )
+         cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
+     elif any(prepared_ds_path.glob("*")):
+         LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
+         processed_dataset = load_from_disk(str(prepared_ds_path))
+         LOG.info("Prepared dataset loaded from disk...")
+         # If the data has been processed in the past, it is assumed the tokenizer has been created before;
+         # we load the CEHR-GPT tokenizer from the output folder, otherwise an exception will be raised.
+         tokenizer_name_or_path = os.path.expanduser(
+             training_args.output_dir
+             if cehrgpt_args.expand_tokenizer
+             else model_args.tokenizer_name_or_path
+         )
+         if not tokenizer_exists(tokenizer_name_or_path):
+             raise RuntimeError(
+                 f"The dataset has been tokenized but the corresponding tokenizer: "
+                 f"{model_args.tokenizer_name_or_path} does not exist"
+             )
+         cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
+     else:
+         # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
+         if data_args.is_data_in_meds:
+             meds_extension_path = get_meds_extension_path(
+                 data_folder=data_args.data_folder,
+                 dataset_prepared_path=data_args.dataset_prepared_path,
+             )
+             try:
+                 LOG.info(
+                     "Trying to load the MEDS extension from disk at %s...",
+                     meds_extension_path,
+                 )
+                 dataset = load_from_disk(meds_extension_path)
+                 if data_args.streaming:
+                     if isinstance(dataset, DatasetDict):
+                         dataset = {
+                             k: v.to_iterable_dataset(
+                                 num_shards=training_args.dataloader_num_workers
+                             )
+                             for k, v in dataset.items()
+                         }
+                     else:
+                         dataset = dataset.to_iterable_dataset(
+                             num_shards=training_args.dataloader_num_workers
+                         )
+             except FileNotFoundError as e:
+                 LOG.exception(e)
+                 dataset = create_dataset_from_meds_reader(
+                     data_args, is_pretraining=True
+                 )
+                 if not data_args.streaming:
+                     dataset.save_to_disk(meds_extension_path)
+         else:
+             # Load the dataset from the parquet files
+             dataset = load_parquet_as_dataset(
+                 data_args.data_folder, split="train", streaming=data_args.streaming
+             )
+         # If streaming is enabled, we need to manually split the data into train/val
+         if data_args.streaming and data_args.validation_split_num:
+             dataset = dataset.shuffle(buffer_size=10_000, seed=training_args.seed)
+             train_set = dataset.skip(data_args.validation_split_num)
+             val_set = dataset.take(data_args.validation_split_num)
+             dataset = DatasetDict({"train": train_set, "test": val_set})
+         elif data_args.validation_split_percentage:
+             dataset = dataset.train_test_split(
+                 test_size=data_args.validation_split_percentage,
+                 seed=training_args.seed,
+             )
+         else:
+             raise RuntimeError(
+                 f"Cannot split the data. If streaming is enabled, validation_split_num needs to be "
+                 f"defined, otherwise validation_split_percentage needs to be provided. "
+                 f"The current values are:\n"
+                 f"validation_split_percentage: {data_args.validation_split_percentage}\n"
+                 f"validation_split_num: {data_args.validation_split_num}\n"
+                 f"streaming: {data_args.streaming}"
+             )
+
+         # Create the CEHR-GPT tokenizer if it's not available in the output folder
+         cehrgpt_tokenizer = load_and_create_tokenizer(
+             data_args=data_args,
+             model_args=model_args,
+             cehrgpt_args=cehrgpt_args,
+             dataset=dataset,
+         )
+         # Retrain the tokenizer in case we want to pretrain the model further using different datasets
+         if cehrgpt_args.expand_tokenizer:
+             new_tokenizer_path = os.path.expanduser(training_args.output_dir)
+             try:
+                 cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
+             except Exception:
+                 cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
+                     cehrgpt_tokenizer=cehrgpt_tokenizer,
+                     dataset=dataset["train"],
+                     data_args=data_args,
+                     concept_name_mapping={},
+                     pretrained_concept_embedding_model=PretrainedEmbeddings(
+                         cehrgpt_args.pretrained_embedding_path
+                     ),
+                 )
+                 cehrgpt_tokenizer.save_pretrained(
+                     os.path.expanduser(training_args.output_dir)
+                 )
+
+         # Sort the patient features chronologically and tokenize the data
+         processed_dataset = create_cehrgpt_pretraining_dataset(
+             dataset=dataset, cehrgpt_tokenizer=cehrgpt_tokenizer, data_args=data_args
+         )
+         # Only save the data to disk if it is not streaming
+         if not data_args.streaming:
+             processed_dataset.save_to_disk(prepared_ds_path)
+
+     def filter_func(examples):
+         if cehrgpt_args.drop_long_sequences:
+             return [
+                 model_args.max_position_embeddings >= _ >= data_args.min_num_tokens
+                 for _ in examples["num_of_concepts"]
+             ]
+         else:
+             return [_ >= data_args.min_num_tokens for _ in examples["num_of_concepts"]]
+
+     # Create the args for batched filtering
+     filter_args = {"batched": True, "batch_size": data_args.preprocessing_batch_size}
+     # If the dataset is not in a streaming mode, we could add num_proc to enable parallelization
+     if not data_args.streaming:
+         filter_args["num_proc"] = data_args.preprocessing_num_workers
+
+     # The filter can't be applied to a DatasetDict of IterableDataset (in case of streaming);
+     # we need to iterate through all the datasets and apply the filter separately
+     if isinstance(processed_dataset, DatasetDict) or isinstance(
+         processed_dataset, IterableDatasetDict
+     ):
+         for key in processed_dataset.keys():
+             processed_dataset[key] = processed_dataset[key].filter(
+                 filter_func, **filter_args
+             )
+     else:
+         processed_dataset = processed_dataset.filter(filter_func, **filter_args)
+
+     model = load_and_create_model(
+         model_args, cehrgpt_args, training_args, cehrgpt_tokenizer
+     )
+
+     # Expand tokenizer to adapt to the new pretraining dataset
+     if model.config.vocab_size < cehrgpt_tokenizer.vocab_size:
+         model.resize_token_embeddings(cehrgpt_tokenizer.vocab_size)
+         # Update the pretrained embedding weights if they are available
+         if model.config.use_pretrained_embeddings:
+             model.cehrgpt.update_pretrained_embeddings(
+                 cehrgpt_tokenizer.pretrained_token_ids,
+                 cehrgpt_tokenizer.pretrained_embeddings,
+             )
+         elif cehrgpt_tokenizer.pretrained_token_ids:
+             model.config.pretrained_embedding_dim = (
+                 cehrgpt_tokenizer.pretrained_embeddings.shape[1]
+             )
+             model.config.use_pretrained_embeddings = True
+             model.cehrgpt.initialize_pretrained_embeddings()
+             model.cehrgpt.update_pretrained_embeddings(
+                 cehrgpt_tokenizer.pretrained_token_ids,
+                 cehrgpt_tokenizer.pretrained_embeddings,
+             )
+
+     # Detecting last checkpoint.
+     last_checkpoint = get_last_hf_checkpoint(training_args)
+
+     # Set seed before initializing model.
+     set_seed(training_args.seed)
+
+     if not data_args.streaming:
+         processed_dataset.set_format("pt")
+
+     trainer = Trainer(
+         model=model,
+         data_collator=CehrGptDataCollator(
+             tokenizer=cehrgpt_tokenizer,
+             max_length=model_args.max_position_embeddings,
+             shuffle_records=data_args.shuffle_records,
+             include_ttv_prediction=model_args.include_ttv_prediction,
+             use_sub_time_tokenization=model_args.use_sub_time_tokenization,
+             include_values=model_args.include_values,
+         ),
+         train_dataset=processed_dataset["train"],
+         eval_dataset=processed_dataset["test"],
+         args=training_args,
+     )
+
+     checkpoint = None
+     if training_args.resume_from_checkpoint is not None:
+         checkpoint = training_args.resume_from_checkpoint
+     elif last_checkpoint is not None:
+         checkpoint = last_checkpoint
+
+     train_result = trainer.train(resume_from_checkpoint=checkpoint)
+     trainer.save_model()  # Saves the tokenizer too for easy upload
+     metrics = train_result.metrics
+
+     trainer.log_metrics("train", metrics)
+     trainer.save_metrics("train", metrics)
+     trainer.save_state()
+
+
+ if __name__ == "__main__":
+     main()
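
Editorial note, not part of the package diff: the following is a minimal, self-contained sketch of the manual train/validation split that main() performs when streaming is enabled (shuffle the iterable dataset, then hold out the first validation_split_num examples with skip/take). The toy records, split size, and seed are illustrative assumptions, not values from the package.

from datasets import Dataset, IterableDatasetDict

validation_split_num = 2  # illustrative stand-in for data_args.validation_split_num
seed = 42                 # illustrative stand-in for training_args.seed

# Toy patient-sequence-like records; the real runner streams parquet/MEDS data instead.
toy = Dataset.from_dict({"num_of_concepts": [5, 12, 7, 30, 9, 21]}).to_iterable_dataset(num_shards=1)
toy = toy.shuffle(buffer_size=10_000, seed=seed)

splits = IterableDatasetDict(
    {
        "train": toy.skip(validation_split_num),  # everything after the held-out prefix
        "test": toy.take(validation_split_num),   # first N shuffled examples for validation
    }
)

for example in splits["test"]:
    print(example)

IterableDatasetDict is used here because the splits are iterable; the runner itself wraps them in a DatasetDict, but the shuffle/skip/take pattern is the same.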
cehrgpt/runners/hf_gpt_runner_argument_dataclass.py
@@ -0,0 +1,137 @@
+ import dataclasses
+ from typing import List, Optional
+
+
+ @dataclasses.dataclass
+ class CehrGPTArguments:
+     """Arguments specific to CEHR-GPT pretraining, fine-tuning, and evaluation."""
+
+     include_demographics: Optional[bool] = dataclasses.field(
+         default=False,
+         metadata={
+             "help": "A flag to indicate whether to always include the demographic tokens for sequences longer than the model context window."
+         },
+     )
+     continue_pretrain: Optional[bool] = dataclasses.field(
+         default=False,
+         metadata={
+             "help": "A flag to indicate whether to continue pretraining CEHR-GPT on a new dataset."
+         },
+     )
+     pretrained_embedding_path: Optional[str] = dataclasses.field(
+         default=None,
+         metadata={"help": "The path to the pretrained concept embeddings."},
+     )
+     retrain_with_full: Optional[bool] = dataclasses.field(
+         default=False,
+         metadata={
+             "help": "A flag to indicate whether to retrain the model on the full training set after early stopping."
+         },
+     )
+     expand_tokenizer: Optional[bool] = dataclasses.field(
+         default=False,
+         metadata={
+             "help": "A flag to indicate whether to expand the tokenizer for fine-tuning."
+         },
+     )
+     few_shot_predict: Optional[bool] = dataclasses.field(
+         default=False,
+         metadata={
+             "help": "A flag to indicate whether to train the model with only a few examples (few-shot)."
+         },
+     )
+     n_shots: Optional[int] = dataclasses.field(
+         default=128,
+         metadata={"help": "The number of training examples to use for few-shot training."},
+     )
+     hyperparameter_tuning_percentage: Optional[float] = dataclasses.field(
+         default=0.1,
+         metadata={
+             "help": "The fraction of the train/validation data used for hyperparameter tuning."
+         },
+     )
+     n_trials: Optional[int] = dataclasses.field(
+         default=10,
+         metadata={
+             "help": "The number of trials to run for hyperparameter tuning."
+         },
+     )
+     hyperparameter_tuning: Optional[bool] = dataclasses.field(
+         default=False,
+         metadata={"help": "A flag to indicate whether to run hyperparameter tuning."},
+     )
+     hyperparameter_batch_sizes: Optional[List[int]] = dataclasses.field(
+         default_factory=lambda: [4, 8, 16],
+         metadata={"help": "The batch sizes to search over during hyperparameter tuning."},
+     )
+     hyperparameter_num_train_epochs: Optional[List[int]] = dataclasses.field(
+         default_factory=lambda: [10],
+         metadata={"help": "The num_train_epochs values to search over during hyperparameter tuning."},
+     )
+     lr_low: Optional[float] = dataclasses.field(
+         default=1e-5,
+         metadata={
+             "help": "The lower bound of the learning rate range for hyperparameter tuning."
+         },
+     )
+     lr_high: Optional[float] = dataclasses.field(
+         default=5e-5,
+         metadata={
+             "help": "The upper bound of the learning rate range for hyperparameter tuning."
+         },
+     )
+     weight_decays_low: Optional[float] = dataclasses.field(
+         default=1e-3,
+         metadata={
+             "help": "The lower bound of the weight decay range for hyperparameter tuning."
+         },
+     )
+     weight_decays_high: Optional[float] = dataclasses.field(
+         default=1e-2,
+         metadata={
+             "help": "The upper bound of the weight decay range for hyperparameter tuning."
+         },
+     )
+     causal_sfm: Optional[bool] = dataclasses.field(
+         default=False,
+         metadata={
+             "help": "A flag to indicate whether the GPT model conforms to the causal Standard Fairness Model."
+         },
+     )
+     demographics_size: Optional[int] = dataclasses.field(
+         default=4,
+         metadata={
+             "help": "The number of demographic tokens at the start of the patient sequence. "
+             "Defaults to 4, assuming the demographic tokens follow the pattern [Year][Age][Gender][Race]."
+         },
+     )
+     drop_long_sequences: Optional[bool] = dataclasses.field(
+         default=False,
+         metadata={
+             "help": "A flag to indicate whether to drop sequences longer than the model's max_position_embeddings."
+         },
+     )
+     lab_token_penalty: Optional[bool] = dataclasses.field(
+         default=False,
+         metadata={
+             "help": "A flag to indicate whether to apply a loss penalty to lab tokens."
+         },
+     )
+     lab_token_loss_weight: Optional[float] = dataclasses.field(
+         default=1.0,
+         metadata={"help": "The loss weight (penalty coefficient) applied to lab tokens."},
+     )
+     entropy_penalty: Optional[bool] = dataclasses.field(
+         default=False,
+         metadata={"help": "A flag to indicate whether to apply an entropy penalty."},
+     )
+     entropy_penalty_alpha: Optional[float] = dataclasses.field(
+         default=0.01,
+         metadata={"help": "The entropy penalty coefficient."},
+     )
+     n_pretrained_embeddings_layers: Optional[int] = dataclasses.field(
+         default=2,
+         metadata={
+             "help": "The number of feed-forward layers used to transform pretrained embeddings into the model's internal embeddings."
+         },
+     )
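
Editorial note, not part of the package diff: a brief, hedged sketch of how CehrGPTArguments maps to command-line flags via Hugging Face's HfArgumentParser. The runner's parse_runner_args presumably does equivalent wiring together with the cehrbert and transformers argument dataclasses; that wiring, and the flag values below, are assumptions for illustration only.

from transformers import HfArgumentParser

from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments

parser = HfArgumentParser(CehrGPTArguments)
# Each dataclass field becomes a --flag; boolean fields accept true/false values.
# The values passed here are hypothetical, chosen only to show the parsing.
(cehrgpt_args,) = parser.parse_args_into_dataclasses(
    args=["--n_trials", "20", "--lr_low", "2e-5", "--drop_long_sequences", "true"]
)
print(cehrgpt_args.n_trials, cehrgpt_args.lr_low, cehrgpt_args.drop_long_sequences)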