cehrgpt 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
  2. cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
  3. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +279 -2
  4. cehrgpt/data/sample_packing_sampler.py +151 -0
  5. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  6. cehrgpt/generation/omop_converter_batch.py +3 -0
  7. cehrgpt/models/config.py +10 -0
  8. cehrgpt/models/hf_cehrgpt.py +244 -73
  9. cehrgpt/models/tokenization_hf_cehrgpt.py +6 -2
  10. cehrgpt/runners/data_utils.py +243 -0
  11. cehrgpt/runners/gpt_runner_util.py +0 -10
  12. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +154 -260
  13. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +250 -90
  14. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -0
  15. cehrgpt/runners/hyperparameter_search_util.py +4 -1
  16. cehrgpt/runners/sample_packing_trainer.py +168 -0
  17. cehrgpt/simulations/__init__.py +0 -0
  18. cehrgpt/simulations/generate_plots.py +95 -0
  19. cehrgpt/simulations/run_simulation.sh +24 -0
  20. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  21. cehrgpt/simulations/time_token_simulation.py +177 -0
  22. cehrgpt/tools/generate_causal_patient_split_by_age.py +146 -0
  23. cehrgpt/tools/linear_prob/__init__.py +0 -0
  24. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
  25. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  26. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +57 -9
  27. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +30 -18
  28. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
  29. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
  30. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
cehrgpt/runners/hf_cehrgpt_pretrain_runner.py

@@ -1,8 +1,12 @@
 import os
+from functools import partial
 from typing import Optional, Union
 
+import numpy as np
 import torch
+import torch.distributed as dist
 from cehrbert.data_generators.hf_data_generator.meds_utils import (
+    CacheFileCollector,
     create_dataset_from_meds_reader,
 )
 from cehrbert.runners.hf_runner_argument_dataclass import (
@@ -16,21 +20,41 @@ from cehrbert.runners.runner_util import (
     load_parquet_as_dataset,
 )
 from datasets import Dataset, DatasetDict, IterableDatasetDict, load_from_disk
-from transformers import AutoConfig, Trainer, TrainingArguments, set_seed
+from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
+from transformers.trainer_utils import is_main_process
 from transformers.utils import is_flash_attn_2_available, logging
 
 from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_pretraining_dataset
-from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
+from cehrgpt.data.hf_cehrgpt_dataset_collator import (
+    CehrGptDataCollator,
+    SamplePackingCehrGptDataCollator,
+)
+from cehrgpt.data.hf_cehrgpt_dataset_mapping import MedToCehrGPTDatasetMapping
 from cehrgpt.models.config import CEHRGPTConfig
 from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
 from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
 from cehrgpt.runners.gpt_runner_util import parse_runner_args
-from src.cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
+from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
+from cehrgpt.runners.sample_packing_trainer import SamplePackingTrainer
 
 LOG = logging.get_logger("transformers")
 
 
+class CustomEarlyStoppingCallback(EarlyStoppingCallback):
+    def check_metric_value(self, args, state, control, metric_value):
+        # best_metric is set by code for load_best_model
+        operator = np.greater if args.greater_is_better else np.less
+        if state.best_metric is None or (
+            operator(metric_value, state.best_metric)
+            and abs(metric_value - state.best_metric) / state.best_metric
+            > self.early_stopping_threshold
+        ):
+            self.early_stopping_patience_counter = 0
+        else:
+            self.early_stopping_patience_counter += 1
+
+
 def tokenizer_exists(tokenizer_name_or_path: str) -> bool:
     # Try to load the pretrained tokenizer
     try:
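The new CustomEarlyStoppingCallback changes the stock EarlyStoppingCallback test from an absolute to a relative improvement: the patience counter only resets when the metric beats state.best_metric by more than early_stopping_threshold expressed as a fraction of the previous best. A minimal standalone sketch of that check (illustration only, not code from the package):

```python
# Illustration only: standalone restatement of the relative-improvement check used by
# CustomEarlyStoppingCallback above (not an import from cehrgpt).
import numpy as np


def improved(metric_value: float, best_metric: float, greater_is_better: bool, threshold: float) -> bool:
    """Return True when the new metric beats the best one by more than `threshold`, relative to the best."""
    operator = np.greater if greater_is_better else np.less
    return bool(
        operator(metric_value, best_metric)
        and abs(metric_value - best_metric) / best_metric > threshold
    )


# With the default early_stopping_threshold of 0.01, an eval loss only counts as an
# improvement over 1.000 if it is more than 1% better in relative terms.
print(improved(0.9950, 1.000, greater_is_better=False, threshold=0.01))  # False: 0.5% relative
print(improved(0.9850, 1.000, greater_is_better=False, threshold=0.01))  # True: 1.5% relative
```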
@@ -58,13 +82,16 @@ def load_and_create_tokenizer(
                 f"Failed to load the tokenizer from {tokenizer_abspath} with the error \n{e}\n"
                 f"Tried to create the tokenizer, however the dataset is not provided."
             )
+        LOG.info("Started training the tokenizer ...")
         tokenizer = CehrGptTokenizer.train_tokenizer(
             dataset,
             {},
             data_args,
             PretrainedEmbeddings(cehrgpt_args.pretrained_embedding_path),
         )
+        LOG.info("Finished training the tokenizer ...")
         tokenizer.save_pretrained(tokenizer_abspath)
+        LOG.info("Saved the tokenizer to %s", tokenizer_abspath)
 
     return tokenizer
 
@@ -82,11 +109,25 @@ def load_and_create_model(
     model_abspath = os.path.expanduser(model_args.model_name_or_path)
     if cehrgpt_args.continue_pretrain:
         try:
-            return CEHRGPT2LMHeadModel.from_pretrained(
+            pretrained_model = CEHRGPT2LMHeadModel.from_pretrained(
                 model_abspath,
                 attn_implementation=attn_implementation,
                 torch_dtype=torch_dtype,
             )
+            if (
+                pretrained_model.config.max_position_embeddings
+                < model_args.max_position_embeddings
+            ):
+                LOG.info(
+                    f"Increase model.config.max_position_embeddings to {model_args.max_position_embeddings}"
+                )
+                pretrained_model.config.max_position_embeddings = (
+                    model_args.max_position_embeddings
+                )
+                pretrained_model.resize_position_embeddings(
+                    model_args.max_position_embeddings
+                )
+            return pretrained_model
         except Exception as e:
             LOG.error(
                 f"When continue_pretrain is set to True, it assumes that CEHR-GPT has been trained "
@@ -94,7 +135,7 @@ def load_and_create_model(
             )
             raise e
     try:
-        model_config = AutoConfig.from_pretrained(
+        model_config = CEHRGPTConfig.from_pretrained(
             model_abspath, attn_implementation=attn_implementation
         )
     except Exception as e:
@@ -105,6 +146,7 @@ def load_and_create_model(
             pretrained_embedding_dim = tokenizer.pretrained_embeddings.shape[1]
         else:
             pretrained_embedding_dim = model_args.hidden_size
+
         model_config = CEHRGPTConfig(
             vocab_size=tokenizer.vocab_size,
             value_vocab_size=tokenizer.value_vocab_size,
@@ -116,15 +158,23 @@ def load_and_create_model(
             attn_implementation=attn_implementation,
             causal_sfm=cehrgpt_args.causal_sfm,
             demographics_size=cehrgpt_args.demographics_size,
+            next_token_prediction_loss_weight=cehrgpt_args.next_token_prediction_loss_weight,
             lab_token_penalty=cehrgpt_args.lab_token_penalty,
             lab_token_loss_weight=cehrgpt_args.lab_token_loss_weight,
+            value_prediction_loss_weight=cehrgpt_args.value_prediction_loss_weight,
             entropy_penalty=cehrgpt_args.entropy_penalty,
             entropy_penalty_alpha=cehrgpt_args.entropy_penalty_alpha,
             n_pretrained_embeddings_layers=cehrgpt_args.n_pretrained_embeddings_layers,
             use_pretrained_embeddings=len(tokenizer.pretrained_token_ids) > 0,
             pretrained_embedding_dim=pretrained_embedding_dim,
+            sample_packing_max_positions=(
+                cehrgpt_args.max_tokens_per_batch
+                if cehrgpt_args.sample_packing
+                else model_args.max_position_embeddings
+            ),
             **model_args.as_dict(),
         )
+
         model = CEHRGPT2LMHeadModel(model_config)
         if tokenizer.pretrained_token_ids:
             model.cehrgpt.update_pretrained_embeddings(
@@ -141,6 +191,11 @@ def load_and_create_model(
 def main():
     cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
 
+    if cehrgpt_args.sample_packing and data_args.streaming:
+        raise RuntimeError(
+            f"sample_packing is not supported when streaming is enabled, please set streaming to False"
+        )
+
     if data_args.streaming:
         # This is for disabling the warning message https://github.com/huggingface/transformers/issues/5486
         # This happens only when streaming is enabled
@@ -148,8 +203,10 @@ def main():
         # The iterable dataset doesn't have sharding implemented, so the number of works has to be set to 0
         # Otherwise the trainer will throw an error
         training_args.dataloader_num_workers = 0
-        training_args.dataloader_prefetch_factor = 0
+        training_args.dataloader_prefetch_factor = None
 
+    processed_dataset: Optional[DatasetDict] = None
+    cache_file_collector = CacheFileCollector()
     prepared_ds_path = generate_prepared_ds_path(data_args, model_args)
     if os.path.exists(os.path.join(data_args.data_folder, "dataset_dict.json")):
         LOG.info(f"Loading prepared dataset from disk at {data_args.data_folder}...")
@@ -185,96 +242,157 @@ def main():
         )
         cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
     else:
-        # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
-        if data_args.is_data_in_meds:
-            meds_extension_path = get_meds_extension_path(
-                data_folder=data_args.data_folder,
-                dataset_prepared_path=data_args.dataset_prepared_path,
-            )
-            try:
-                LOG.info(
-                    "Trying to load the MEDS extension from disk at %s...",
-                    meds_extension_path,
+        # Only run tokenization and data transformation in the main process in torch distributed training
+        # otherwise the multiple processes will create tokenizers at the same time
+        if is_main_process(training_args.local_rank):
+            # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
+            if data_args.is_data_in_meds:
+                meds_extension_path = get_meds_extension_path(
+                    data_folder=data_args.data_folder,
+                    dataset_prepared_path=data_args.dataset_prepared_path,
                 )
-                dataset = load_from_disk(meds_extension_path)
-                if data_args.streaming:
-                    if isinstance(dataset, DatasetDict):
-                        dataset = {
-                            k: v.to_iterable_dataset(
+                try:
+                    LOG.info(
+                        "Trying to load the MEDS extension from disk at %s...",
+                        meds_extension_path,
+                    )
+                    dataset = load_from_disk(meds_extension_path)
+                    if data_args.streaming:
+                        if isinstance(dataset, DatasetDict):
+                            dataset = {
+                                k: v.to_iterable_dataset(
+                                    num_shards=training_args.dataloader_num_workers
+                                )
+                                for k, v in dataset.items()
+                            }
+                        else:
+                            dataset = dataset.to_iterable_dataset(
                                 num_shards=training_args.dataloader_num_workers
                             )
-                            for k, v in dataset.items()
-                        }
-                    else:
-                        dataset = dataset.to_iterable_dataset(
-                            num_shards=training_args.dataloader_num_workers
+                except FileNotFoundError as e:
+                    LOG.warning(e)
+                    dataset = create_dataset_from_meds_reader(
+                        data_args=data_args,
+                        dataset_mappings=[
+                            MedToCehrGPTDatasetMapping(
+                                data_args=data_args,
+                                include_inpatient_hour_token=cehrgpt_args.include_inpatient_hour_token,
+                            )
+                        ],
+                        cache_file_collector=cache_file_collector,
+                    )
+                    if not data_args.streaming:
+                        dataset.save_to_disk(str(meds_extension_path))
+                        stats = dataset.cleanup_cache_files()
+                        LOG.info(
+                            "Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
+                            stats,
                         )
-            except FileNotFoundError as e:
-                LOG.exception(e)
-                dataset = create_dataset_from_meds_reader(
-                    data_args, is_pretraining=True
-                )
-                if not data_args.streaming:
-                    dataset.save_to_disk(meds_extension_path)
-        else:
-            # Load the dataset from the parquet files
-            dataset = load_parquet_as_dataset(
-                data_args.data_folder, split="train", streaming=data_args.streaming
-            )
-            # If streaming is enabled, we need to manually split the data into train/val
-            if data_args.streaming and data_args.validation_split_num:
-                dataset = dataset.shuffle(buffer_size=10_000, seed=training_args.seed)
-                train_set = dataset.skip(data_args.validation_split_num)
-                val_set = dataset.take(data_args.validation_split_num)
-                dataset = DatasetDict({"train": train_set, "test": val_set})
-            elif data_args.validation_split_percentage:
-                dataset = dataset.train_test_split(
-                    test_size=data_args.validation_split_percentage,
-                    seed=training_args.seed,
-                )
+                        # Clean up the files created from the data generator
+                        cache_file_collector.remove_cache_files()
+                        dataset = load_from_disk(str(meds_extension_path))
             else:
-                raise RuntimeError(
-                    f"Can not split the data. If streaming is enabled, validation_split_num needs to be "
-                    f"defined, otherwise validation_split_percentage needs to be provided. "
-                    f"The current values are:\n"
-                    f"validation_split_percentage: {data_args.validation_split_percentage}\n"
-                    f"validation_split_num: {data_args.validation_split_num}\n"
-                    f"streaming: {data_args.streaming}"
+                # Load the dataset from the parquet files
+                dataset = load_parquet_as_dataset(
+                    os.path.expanduser(data_args.data_folder),
+                    split="train",
+                    streaming=data_args.streaming,
                 )
+                # If streaming is enabled, we need to manually split the data into train/val
+                if data_args.streaming and data_args.validation_split_num:
+                    dataset = dataset.shuffle(
+                        buffer_size=10_000, seed=training_args.seed
+                    )
+                    train_set = dataset.skip(data_args.validation_split_num)
+                    val_set = dataset.take(data_args.validation_split_num)
+                    dataset = DatasetDict({"train": train_set, "validation": val_set})
+                elif data_args.validation_split_percentage:
+                    dataset = dataset.train_test_split(
+                        test_size=data_args.validation_split_percentage,
+                        seed=training_args.seed,
+                    )
+                    dataset = DatasetDict(
+                        {"train": dataset["train"], "validation": dataset["test"]}
+                    )
+                else:
+                    raise RuntimeError(
+                        f"Can not split the data. If streaming is enabled, validation_split_num needs to be "
+                        f"defined, otherwise validation_split_percentage needs to be provided. "
+                        f"The current values are:\n"
+                        f"validation_split_percentage: {data_args.validation_split_percentage}\n"
+                        f"validation_split_num: {data_args.validation_split_num}\n"
+                        f"streaming: {data_args.streaming}"
+                    )
 
-        # Create the CEHR-GPT tokenizer if it's not available in the output folder
-        cehrgpt_tokenizer = load_and_create_tokenizer(
-            data_args=data_args,
-            model_args=model_args,
-            cehrgpt_args=cehrgpt_args,
-            dataset=dataset,
-        )
-        # Retrain the tokenizer in case we want to pretrain the model further using different datasets
-        if cehrgpt_args.expand_tokenizer:
-            new_tokenizer_path = os.path.expanduser(training_args.output_dir)
-            try:
-                cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
-            except Exception:
-                cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
-                    cehrgpt_tokenizer=cehrgpt_tokenizer,
-                    dataset=dataset["train"],
-                    data_args=data_args,
-                    concept_name_mapping={},
-                    pretrained_concept_embedding_model=PretrainedEmbeddings(
-                        cehrgpt_args.pretrained_embedding_path
-                    ),
-                )
-                cehrgpt_tokenizer.save_pretrained(
-                    os.path.expanduser(training_args.output_dir)
+            # Create the CEHR-GPT tokenizer if it's not available in the output folder
+            cehrgpt_tokenizer = load_and_create_tokenizer(
+                data_args=data_args,
+                model_args=model_args,
+                cehrgpt_args=cehrgpt_args,
+                dataset=dataset,
+            )
+
+            # Retrain the tokenizer in case we want to pretrain the model further using different datasets
+            if cehrgpt_args.expand_tokenizer:
+                new_tokenizer_path = os.path.expanduser(training_args.output_dir)
+                try:
+                    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
+                        new_tokenizer_path
+                    )
+                except Exception:
+                    cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
+                        cehrgpt_tokenizer=cehrgpt_tokenizer,
+                        dataset=dataset["train"],
+                        data_args=data_args,
+                        concept_name_mapping={},
+                        pretrained_concept_embedding_model=PretrainedEmbeddings(
+                            cehrgpt_args.pretrained_embedding_path
+                        ),
+                    )
+                    cehrgpt_tokenizer.save_pretrained(
+                        os.path.expanduser(training_args.output_dir)
+                    )
+
+            # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
+            if not data_args.streaming:
+                all_columns = dataset["train"].column_names
+                if "visit_concept_ids" in all_columns:
+                    dataset = dataset.remove_columns(["visit_concept_ids"])
+
+            # sort the patient features chronologically and tokenize the data
+            processed_dataset = create_cehrgpt_pretraining_dataset(
+                dataset=dataset,
+                cehrgpt_tokenizer=cehrgpt_tokenizer,
+                data_args=data_args,
+                cache_file_collector=cache_file_collector,
+            )
+            # only save the data to the disk if it is not streaming
+            if not data_args.streaming:
+                processed_dataset.save_to_disk(str(prepared_ds_path))
+                stats = processed_dataset.cleanup_cache_files()
+                LOG.info(
+                    "Clean up the cached files for the cehrgpt pretraining dataset: %s",
+                    stats,
                 )
+            cache_file_collector.remove_cache_files()
+
+        # After main-process-only operations, synchronize all processes to ensure consistency
+        if dist.is_available() and dist.is_initialized():
+            dist.barrier()
 
-        # sort the patient features chronologically and tokenize the data
-        processed_dataset = create_cehrgpt_pretraining_dataset(
-            dataset=dataset, cehrgpt_tokenizer=cehrgpt_tokenizer, data_args=data_args
+        # Loading tokenizer in all processes in torch distributed training
+        tokenizer_name_or_path = os.path.expanduser(
+            training_args.output_dir
+            if cehrgpt_args.expand_tokenizer
+            else model_args.tokenizer_name_or_path
         )
-        # only save the data to the disk if it is not streaming
+        cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
+        # Load the dataset from disk again to in torch distributed training
        if not data_args.streaming:
-            processed_dataset.save_to_disk(prepared_ds_path)
+            processed_dataset = load_from_disk(str(prepared_ds_path))
+
+    if processed_dataset is None:
+        raise RuntimeError("The processed dataset cannot be None")
 
     def filter_func(examples):
         if cehrgpt_args.drop_long_sequences:
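The restructured main() prepares the tokenizer and dataset only on the main rank (is_main_process), synchronizes all ranks with dist.barrier(), and then has every rank reload the cached artifacts from disk. A minimal sketch of that prepare-once-then-load pattern, with a hypothetical build_and_save callback standing in for the tokenizer training and dataset tokenization steps above:

```python
# Minimal sketch of the prepare-on-main-process pattern used above. `build_and_save` and
# `prepare_once_then_load` are illustrative names, not part of the cehrgpt API.
import torch.distributed as dist
from datasets import load_from_disk
from transformers.trainer_utils import is_main_process


def prepare_once_then_load(local_rank: int, prepared_ds_path: str, build_and_save):
    if is_main_process(local_rank):
        # Only rank 0 builds the tokenizer/dataset and writes them to disk.
        build_and_save(prepared_ds_path)
    if dist.is_available() and dist.is_initialized():
        # Everyone waits until the artifacts exist before reading them.
        dist.barrier()
    # All ranks (rank 0 included) then load the same on-disk dataset.
    return load_from_disk(prepared_ds_path)
```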
@@ -333,22 +451,64 @@ def main():
     # Set seed before initializing model.
     set_seed(training_args.seed)
 
-    if not data_args.streaming:
+    if not data_args.streaming and not cehrgpt_args.sample_packing:
         processed_dataset.set_format("pt")
 
-    trainer = Trainer(
+    callbacks = []
+    if cehrgpt_args.use_early_stopping:
+        callbacks.append(
+            CustomEarlyStoppingCallback(
+                model_args.early_stopping_patience,
+                cehrgpt_args.early_stopping_threshold,
+            )
+        )
+
+    if cehrgpt_args.sample_packing:
+        trainer_class = partial(
+            SamplePackingTrainer,
+            max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
+            max_position_embeddings=model_args.max_position_embeddings,
+            train_lengths=processed_dataset["train"]["num_of_concepts"],
+            validation_lengths=(
+                processed_dataset["validation"]
+                if "validation" in processed_dataset
+                else processed_dataset["test"]
+            )["num_of_concepts"],
+        )
+        training_args.per_device_train_batch_size = 1
+        training_args.per_device_eval_batch_size = 1
+        data_collator_fn = partial(
+            SamplePackingCehrGptDataCollator,
+            cehrgpt_args.max_tokens_per_batch,
+            model_args.max_position_embeddings,
+            add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
+        )
+    else:
+        trainer_class = Trainer
+        data_collator_fn = CehrGptDataCollator
+
+    trainer = trainer_class(
         model=model,
-        data_collator=CehrGptDataCollator(
+        data_collator=data_collator_fn(
             tokenizer=cehrgpt_tokenizer,
-            max_length=model_args.max_position_embeddings,
+            max_length=(
+                cehrgpt_args.max_tokens_per_batch
+                if cehrgpt_args.sample_packing
+                else model_args.max_position_embeddings
+            ),
             shuffle_records=data_args.shuffle_records,
             include_ttv_prediction=model_args.include_ttv_prediction,
             use_sub_time_tokenization=model_args.use_sub_time_tokenization,
             include_values=model_args.include_values,
         ),
         train_dataset=processed_dataset["train"],
-        eval_dataset=processed_dataset["test"],
+        eval_dataset=(
+            processed_dataset["validation"]
+            if "validation" in processed_dataset
+            else processed_dataset["test"]
+        ),
         args=training_args,
+        callbacks=callbacks,
     )
 
     checkpoint = None
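With sample packing enabled, per_device_train_batch_size and per_device_eval_batch_size are forced to 1 and the SamplePackingTrainer / SamplePackingCehrGptDataCollator pair assembles each "batch" from as many sequences as fit under max_tokens_per_batch, using the per-example num_of_concepts lengths. A hedged sketch of the underlying idea, greedy packing under a token budget (the real logic lives in cehrgpt/data/sample_packing_sampler.py and SamplePackingTrainer and may differ):

```python
# Hedged sketch of greedy sample packing under a token budget; illustrative only.
from typing import List


def pack_by_token_budget(lengths: List[int], max_tokens_per_batch: int) -> List[List[int]]:
    """Greedily group example indices so each group stays within the token budget."""
    batches, current, current_tokens = [], [], 0
    for idx, length in enumerate(lengths):
        if current and current_tokens + length > max_tokens_per_batch:
            batches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += length
    if current:
        batches.append(current)
    return batches


# Sequences of these lengths fit into two packed "batches" of <= 16384 tokens each,
# which is why the per-device batch size can stay at 1.
print(pack_by_token_budget([9000, 5000, 3000, 8000, 2000], max_tokens_per_batch=16384))
# [[0, 1], [2, 3, 4]]
```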
cehrgpt/runners/hf_gpt_runner_argument_dataclass.py

@@ -6,6 +6,10 @@ from typing import List, Optional
 class CehrGPTArguments:
     """Arguments pertaining to what data we are going to input our model for training and eval."""
 
+    include_inpatient_hour_token: Optional[bool] = dataclasses.field(
+        default=True,
+        metadata={"help": "Include inpatient hour token"},
+    )
     include_demographics: Optional[bool] = dataclasses.field(
         default=False,
         metadata={
@@ -111,6 +115,9 @@ class CehrGPTArguments:
             "help": "The lower bound of the learning rate range for hyperparameter tuning."
         },
     )
+    next_token_prediction_loss_weight: float = dataclasses.field(
+        default=1.0, metadata={"help": "The weight of the next token prediction loss"}
+    )
     lab_token_penalty: Optional[bool] = dataclasses.field(
         default=False,
         metadata={
@@ -121,6 +128,10 @@ class CehrGPTArguments:
         default=1.0,
         metadata={"help": "lab_token_loss_weight penalty co-efficient"},
     )
+    value_prediction_loss_weight: Optional[float] = dataclasses.field(
+        default=1.0,
+        metadata={"help": "The weight of the value prediction loss"},
+    )
     entropy_penalty: Optional[bool] = dataclasses.field(
         default=False,
         metadata={"help": "A flag to indicate whether we want to use entropy penalty."},
@@ -135,3 +146,38 @@ class CehrGPTArguments:
             "help": "The number of feed forward layers for transforming pretrained embeddings to internal embeddings"
         },
     )
+    meds_repartition: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to repartition the meds train tune sets"
+        },
+    )
+    use_early_stopping: Optional[bool] = dataclasses.field(
+        default=True,
+        metadata={"help": "A flag to indicate whether we want to use early stopping."},
+    )
+    early_stopping_threshold: Optional[float] = dataclasses.field(
+        default=0.01,
+        metadata={
+            "help": "A threshold to denote how much the specified metric must improve to satisfy early stopping conditions."
+        },
+    )
+    sample_packing: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to use sample packing for efficient training."
+        },
+    )
+    max_tokens_per_batch: int = dataclasses.field(
+        default=16384, metadata={"help": "Maximum number of tokens in each batch"}
+    )
+    add_end_token_in_sample_packing: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to add end token in sample packing"
+        },
+    )
+    average_over_sequence: bool = dataclasses.field(
+        default=False,
+        metadata={"help": "Whether or not to average tokens per sequence"},
+    )
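All of the new CehrGPTArguments fields ship with defaults, so sample packing and early stopping can be toggled directly on the dataclass. A hedged usage sketch, assuming the remaining (unshown) fields also have defaults, as every field visible in this diff does:

```python
# Hedged usage sketch: field names come from the diff above; the rest of the runner
# wiring (parse_runner_args, data/model/training arguments) is omitted, and we assume
# all other CehrGPTArguments fields have defaults.
from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments

cehrgpt_args = CehrGPTArguments(
    sample_packing=True,
    max_tokens_per_batch=16384,
    add_end_token_in_sample_packing=False,
    use_early_stopping=True,
    early_stopping_threshold=0.01,
)
print(cehrgpt_args.sample_packing, cehrgpt_args.max_tokens_per_batch)
```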
cehrgpt/runners/hyperparameter_search_util.py

@@ -126,6 +126,7 @@ def sample_dataset(data: Dataset, percentage: float, seed: int) -> Dataset:
 
 
 def perform_hyperparameter_search(
+    trainer_class,
     model_init: Callable,
     dataset: DatasetDict,
     data_collator: CehrGptDataCollator,
@@ -142,6 +143,7 @@ def perform_hyperparameter_search(
     After the search, it updates the provided `TrainingArguments` with the best hyperparameters found.
 
     Args:
+        trainer_class: A Trainer or its subclass
         model_init (Callable): A function to initialize the model, used for each hyperparameter trial.
         dataset (DatasetDict): A Hugging Face DatasetDict containing "train" and "validation" datasets.
         data_collator (CehrGptDataCollator): A data collator for processing batches.
@@ -157,6 +159,7 @@ def perform_hyperparameter_search(
     Example:
         ```
         best_training_args = perform_hyperparameter_search(
+            trainer_class=Trainer,
             model_init=my_model_init,
             dataset=my_dataset_dict,
             data_collator=my_data_collator,
@@ -187,7 +190,7 @@ def perform_hyperparameter_search(
         cehrgpt_args.hyperparameter_tuning_percentage,
         training_args.seed,
     )
-    hyperparam_trainer = Trainer(
+    hyperparam_trainer = trainer_class(
         model_init=model_init,
         data_collator=data_collator,
         train_dataset=sampled_train,