cehrgpt 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
  2. cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
  3. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +99 -88
  4. cehrgpt/data/sample_packing_sampler.py +151 -0
  5. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  6. cehrgpt/models/config.py +10 -0
  7. cehrgpt/models/hf_cehrgpt.py +243 -73
  8. cehrgpt/models/tokenization_hf_cehrgpt.py +4 -0
  9. cehrgpt/runners/data_utils.py +243 -0
  10. cehrgpt/runners/gpt_runner_util.py +0 -10
  11. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +152 -279
  12. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +229 -105
  13. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +42 -0
  14. cehrgpt/runners/hyperparameter_search_util.py +4 -1
  15. cehrgpt/runners/sample_packing_trainer.py +168 -0
  16. cehrgpt/simulations/generate_plots.py +95 -0
  17. cehrgpt/simulations/run_simulation.sh +24 -0
  18. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  19. cehrgpt/simulations/time_token_simulation.py +177 -0
  20. cehrgpt/tools/linear_prob/__init__.py +0 -0
  21. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
  22. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  23. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +7 -5
  24. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +28 -26
  25. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
  26. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  27. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  28. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  29. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  30. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  31. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  32. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  33. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  34. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  35. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
  36. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,12 @@
 import os
+from functools import partial
 from typing import Optional, Union

+import numpy as np
 import torch
+import torch.distributed as dist
 from cehrbert.data_generators.hf_data_generator.meds_utils import (
+    CacheFileCollector,
     create_dataset_from_meds_reader,
 )
 from cehrbert.runners.hf_runner_argument_dataclass import (
@@ -16,11 +20,15 @@ from cehrbert.runners.runner_util import (
     load_parquet_as_dataset,
 )
 from datasets import Dataset, DatasetDict, IterableDatasetDict, load_from_disk
-from transformers import AutoConfig, Trainer, TrainingArguments, set_seed
+from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
+from transformers.trainer_utils import is_main_process
 from transformers.utils import is_flash_attn_2_available, logging

 from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_pretraining_dataset
-from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
+from cehrgpt.data.hf_cehrgpt_dataset_collator import (
+    CehrGptDataCollator,
+    SamplePackingCehrGptDataCollator,
+)
 from cehrgpt.data.hf_cehrgpt_dataset_mapping import MedToCehrGPTDatasetMapping
 from cehrgpt.models.config import CEHRGPTConfig
 from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
@@ -28,10 +36,25 @@ from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
 from cehrgpt.runners.gpt_runner_util import parse_runner_args
 from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
+from cehrgpt.runners.sample_packing_trainer import SamplePackingTrainer

 LOG = logging.get_logger("transformers")


+class CustomEarlyStoppingCallback(EarlyStoppingCallback):
+    def check_metric_value(self, args, state, control, metric_value):
+        # best_metric is set by code for load_best_model
+        operator = np.greater if args.greater_is_better else np.less
+        if state.best_metric is None or (
+            operator(metric_value, state.best_metric)
+            and abs(metric_value - state.best_metric) / state.best_metric
+            > self.early_stopping_threshold
+        ):
+            self.early_stopping_patience_counter = 0
+        else:
+            self.early_stopping_patience_counter += 1
+
+
 def tokenizer_exists(tokenizer_name_or_path: str) -> bool:
     # Try to load the pretrained tokenizer
     try:
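The CustomEarlyStoppingCallback added above differs from the stock transformers EarlyStoppingCallback in one detail: the improvement test is relative to the current best metric rather than an absolute difference. A minimal sketch of that rule with illustrative numbers; the helper below is not part of the package.

import numpy as np


def improves_enough(metric_value, best_metric, greater_is_better, threshold):
    """True when the new metric beats the best by more than `threshold`,
    expressed as a fraction of the current best (relative improvement)."""
    if best_metric is None:
        return True
    operator = np.greater if greater_is_better else np.less
    return bool(
        operator(metric_value, best_metric)
        and abs(metric_value - best_metric) / best_metric > threshold
    )


# With a 1% threshold, an eval_loss drop from 0.500 to 0.494 (1.2%) resets the
# patience counter, while a drop to 0.498 (0.4%) does not.
print(improves_enough(0.494, 0.500, greater_is_better=False, threshold=0.01))  # True
print(improves_enough(0.498, 0.500, greater_is_better=False, threshold=0.01))  # False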
@@ -59,13 +82,16 @@ def load_and_create_tokenizer(
                 f"Failed to load the tokenizer from {tokenizer_abspath} with the error \n{e}\n"
                 f"Tried to create the tokenizer, however the dataset is not provided."
             )
+        LOG.info("Started training the tokenizer ...")
         tokenizer = CehrGptTokenizer.train_tokenizer(
             dataset,
             {},
             data_args,
             PretrainedEmbeddings(cehrgpt_args.pretrained_embedding_path),
         )
+        LOG.info("Finished training the tokenizer ...")
         tokenizer.save_pretrained(tokenizer_abspath)
+        LOG.info("Saved the tokenizer to %s", tokenizer_abspath)

     return tokenizer

@@ -120,6 +146,7 @@ def load_and_create_model(
             pretrained_embedding_dim = tokenizer.pretrained_embeddings.shape[1]
         else:
             pretrained_embedding_dim = model_args.hidden_size
+
         model_config = CEHRGPTConfig(
             vocab_size=tokenizer.vocab_size,
             value_vocab_size=tokenizer.value_vocab_size,
@@ -131,15 +158,23 @@ def load_and_create_model(
             attn_implementation=attn_implementation,
             causal_sfm=cehrgpt_args.causal_sfm,
             demographics_size=cehrgpt_args.demographics_size,
+            next_token_prediction_loss_weight=cehrgpt_args.next_token_prediction_loss_weight,
             lab_token_penalty=cehrgpt_args.lab_token_penalty,
             lab_token_loss_weight=cehrgpt_args.lab_token_loss_weight,
+            value_prediction_loss_weight=cehrgpt_args.value_prediction_loss_weight,
             entropy_penalty=cehrgpt_args.entropy_penalty,
             entropy_penalty_alpha=cehrgpt_args.entropy_penalty_alpha,
             n_pretrained_embeddings_layers=cehrgpt_args.n_pretrained_embeddings_layers,
             use_pretrained_embeddings=len(tokenizer.pretrained_token_ids) > 0,
             pretrained_embedding_dim=pretrained_embedding_dim,
+            sample_packing_max_positions=(
+                cehrgpt_args.max_tokens_per_batch
+                if cehrgpt_args.sample_packing
+                else model_args.max_position_embeddings
+            ),
             **model_args.as_dict(),
         )
+
     model = CEHRGPT2LMHeadModel(model_config)
     if tokenizer.pretrained_token_ids:
         model.cehrgpt.update_pretrained_embeddings(
@@ -156,6 +191,11 @@ def load_and_create_model(
 def main():
     cehrgpt_args, data_args, model_args, training_args = parse_runner_args()

+    if cehrgpt_args.sample_packing and data_args.streaming:
+        raise RuntimeError(
+            f"sample_packing is not supported when streaming is enabled, please set streaming to False"
+        )
+
     if data_args.streaming:
         # This is for disabling the warning message https://github.com/huggingface/transformers/issues/5486
         # This happens only when streaming is enabled
@@ -165,6 +205,8 @@ def main():
         training_args.dataloader_num_workers = 0
         training_args.dataloader_prefetch_factor = None

+    processed_dataset: Optional[DatasetDict] = None
+    cache_file_collector = CacheFileCollector()
     prepared_ds_path = generate_prepared_ds_path(data_args, model_args)
     if os.path.exists(os.path.join(data_args.data_folder, "dataset_dict.json")):
         LOG.info(f"Loading prepared dataset from disk at {data_args.data_folder}...")
@@ -200,118 +242,158 @@ def main():
         )
         cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
     else:
-        # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
-        if data_args.is_data_in_meds:
-            meds_extension_path = get_meds_extension_path(
-                data_folder=data_args.data_folder,
-                dataset_prepared_path=data_args.dataset_prepared_path,
-            )
-            try:
-                LOG.info(
-                    "Trying to load the MEDS extension from disk at %s...",
-                    meds_extension_path,
+        # Only run tokenization and data transformation in the main process in torch distributed training
+        # otherwise the multiple processes will create tokenizers at the same time
+        if is_main_process(training_args.local_rank):
+            # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
+            if data_args.is_data_in_meds:
+                meds_extension_path = get_meds_extension_path(
+                    data_folder=data_args.data_folder,
+                    dataset_prepared_path=data_args.dataset_prepared_path,
                 )
-            dataset = load_from_disk(meds_extension_path)
-            if data_args.streaming:
-                if isinstance(dataset, DatasetDict):
-                    dataset = {
-                        k: v.to_iterable_dataset(
+                try:
+                    LOG.info(
+                        "Trying to load the MEDS extension from disk at %s...",
+                        meds_extension_path,
+                    )
+                    dataset = load_from_disk(meds_extension_path)
+                    if data_args.streaming:
+                        if isinstance(dataset, DatasetDict):
+                            dataset = {
+                                k: v.to_iterable_dataset(
+                                    num_shards=training_args.dataloader_num_workers
+                                )
+                                for k, v in dataset.items()
+                            }
+                        else:
+                            dataset = dataset.to_iterable_dataset(
                                 num_shards=training_args.dataloader_num_workers
                             )
-                        for k, v in dataset.items()
-                    }
-                else:
-                    dataset = dataset.to_iterable_dataset(
-                        num_shards=training_args.dataloader_num_workers
-                    )
-            except FileNotFoundError as e:
-                LOG.exception(e)
-                dataset = create_dataset_from_meds_reader(
-                    data_args=data_args,
-                    dataset_mappings=[
-                        MedToCehrGPTDatasetMapping(
-                            data_args=data_args,
-                            is_pretraining=True,
-                            include_inpatient_hour_token=cehrgpt_args.include_inpatient_hour_token,
+                except FileNotFoundError as e:
+                    LOG.warning(e)
+                    dataset = create_dataset_from_meds_reader(
+                        data_args=data_args,
+                        dataset_mappings=[
+                            MedToCehrGPTDatasetMapping(
+                                data_args=data_args,
+                                include_inpatient_hour_token=cehrgpt_args.include_inpatient_hour_token,
+                            )
+                        ],
+                        cache_file_collector=cache_file_collector,
+                    )
+                    if not data_args.streaming:
+                        dataset.save_to_disk(str(meds_extension_path))
+                        stats = dataset.cleanup_cache_files()
+                        LOG.info(
+                            "Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
+                            stats,
                         )
-                    ],
+                        # Clean up the files created from the data generator
+                        cache_file_collector.remove_cache_files()
+                        dataset = load_from_disk(str(meds_extension_path))
+            else:
+                # Load the dataset from the parquet files
+                dataset = load_parquet_as_dataset(
+                    os.path.expanduser(data_args.data_folder),
+                    split="train",
+                    streaming=data_args.streaming,
                 )
-            if not data_args.streaming:
-                dataset.save_to_disk(str(meds_extension_path))
-                stats = dataset.cleanup_cache_files()
-                LOG.info(
-                    "Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
-                    stats,
+            # If streaming is enabled, we need to manually split the data into train/val
+            if data_args.streaming and data_args.validation_split_num:
+                dataset = dataset.shuffle(
+                    buffer_size=10_000, seed=training_args.seed
                 )
-                    dataset = load_from_disk(str(meds_extension_path))
-        else:
-            # Load the dataset from the parquet files
-            dataset = load_parquet_as_dataset(
-                os.path.expanduser(data_args.data_folder),
-                split="train",
-                streaming=data_args.streaming,
+                train_set = dataset.skip(data_args.validation_split_num)
+                val_set = dataset.take(data_args.validation_split_num)
+                dataset = DatasetDict({"train": train_set, "validation": val_set})
+            elif data_args.validation_split_percentage:
+                dataset = dataset.train_test_split(
+                    test_size=data_args.validation_split_percentage,
+                    seed=training_args.seed,
+                )
+                dataset = DatasetDict(
+                    {"train": dataset["train"], "validation": dataset["test"]}
+                )
+            else:
+                raise RuntimeError(
+                    f"Can not split the data. If streaming is enabled, validation_split_num needs to be "
+                    f"defined, otherwise validation_split_percentage needs to be provided. "
+                    f"The current values are:\n"
+                    f"validation_split_percentage: {data_args.validation_split_percentage}\n"
+                    f"validation_split_num: {data_args.validation_split_num}\n"
+                    f"streaming: {data_args.streaming}"
+                )
+
+            # Create the CEHR-GPT tokenizer if it's not available in the output folder
+            cehrgpt_tokenizer = load_and_create_tokenizer(
+                data_args=data_args,
+                model_args=model_args,
+                cehrgpt_args=cehrgpt_args,
+                dataset=dataset,
             )
-        # If streaming is enabled, we need to manually split the data into train/val
-        if data_args.streaming and data_args.validation_split_num:
-            dataset = dataset.shuffle(buffer_size=10_000, seed=training_args.seed)
-            train_set = dataset.skip(data_args.validation_split_num)
-            val_set = dataset.take(data_args.validation_split_num)
-            dataset = DatasetDict({"train": train_set, "test": val_set})
-        elif data_args.validation_split_percentage:
-            dataset = dataset.train_test_split(
-                test_size=data_args.validation_split_percentage,
-                seed=training_args.seed,
-            )
-        else:
-            raise RuntimeError(
-                f"Can not split the data. If streaming is enabled, validation_split_num needs to be "
-                f"defined, otherwise validation_split_percentage needs to be provided. "
-                f"The current values are:\n"
-                f"validation_split_percentage: {data_args.validation_split_percentage}\n"
-                f"validation_split_num: {data_args.validation_split_num}\n"
-                f"streaming: {data_args.streaming}"
-            )

-        # Create the CEHR-GPT tokenizer if it's not available in the output folder
-        cehrgpt_tokenizer = load_and_create_tokenizer(
-            data_args=data_args,
-            model_args=model_args,
-            cehrgpt_args=cehrgpt_args,
-            dataset=dataset,
-        )
-        # Retrain the tokenizer in case we want to pretrain the model further using different datasets
-        if cehrgpt_args.expand_tokenizer:
-            new_tokenizer_path = os.path.expanduser(training_args.output_dir)
-            try:
-                cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
-            except Exception:
-                cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
-                    cehrgpt_tokenizer=cehrgpt_tokenizer,
-                    dataset=dataset["train"],
-                    data_args=data_args,
-                    concept_name_mapping={},
-                    pretrained_concept_embedding_model=PretrainedEmbeddings(
-                        cehrgpt_args.pretrained_embedding_path
-                    ),
-                )
-            cehrgpt_tokenizer.save_pretrained(
-                os.path.expanduser(training_args.output_dir)
+            # Retrain the tokenizer in case we want to pretrain the model further using different datasets
+            if cehrgpt_args.expand_tokenizer:
+                new_tokenizer_path = os.path.expanduser(training_args.output_dir)
+                try:
+                    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
+                        new_tokenizer_path
+                    )
+                except Exception:
+                    cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
+                        cehrgpt_tokenizer=cehrgpt_tokenizer,
+                        dataset=dataset["train"],
+                        data_args=data_args,
+                        concept_name_mapping={},
+                        pretrained_concept_embedding_model=PretrainedEmbeddings(
+                            cehrgpt_args.pretrained_embedding_path
+                        ),
+                    )
+                cehrgpt_tokenizer.save_pretrained(
+                    os.path.expanduser(training_args.output_dir)
+                )
+
+            # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
+            if not data_args.streaming:
+                all_columns = dataset["train"].column_names
+                if "visit_concept_ids" in all_columns:
+                    dataset = dataset.remove_columns(["visit_concept_ids"])
+
+            # sort the patient features chronologically and tokenize the data
+            processed_dataset = create_cehrgpt_pretraining_dataset(
+                dataset=dataset,
+                cehrgpt_tokenizer=cehrgpt_tokenizer,
+                data_args=data_args,
+                cache_file_collector=cache_file_collector,
+            )
+            # only save the data to the disk if it is not streaming
+            if not data_args.streaming:
+                processed_dataset.save_to_disk(str(prepared_ds_path))
+                stats = processed_dataset.cleanup_cache_files()
+                LOG.info(
+                    "Clean up the cached files for the cehrgpt pretraining dataset: %s",
+                    stats,
                 )
+                cache_file_collector.remove_cache_files()
+
+        # After main-process-only operations, synchronize all processes to ensure consistency
+        if dist.is_available() and dist.is_initialized():
+            dist.barrier()

-        # sort the patient features chronologically and tokenize the data
-        processed_dataset = create_cehrgpt_pretraining_dataset(
-            dataset=dataset, cehrgpt_tokenizer=cehrgpt_tokenizer, data_args=data_args
+        # Loading tokenizer in all processes in torch distributed training
+        tokenizer_name_or_path = os.path.expanduser(
+            training_args.output_dir
+            if cehrgpt_args.expand_tokenizer
+            else model_args.tokenizer_name_or_path
         )
-        # only save the data to the disk if it is not streaming
+        cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
+        # Load the dataset from disk again to in torch distributed training
         if not data_args.streaming:
-            processed_dataset.save_to_disk(str(prepared_ds_path))
-            stats = processed_dataset.cleanup_cache_files()
-            LOG.info(
-                "Clean up the cached files for the cehrgpt pretraining dataset: %s",
-                stats,
-            )
             processed_dataset = load_from_disk(str(prepared_ds_path))

+    if processed_dataset is None:
+        raise RuntimeError("The processed dataset cannot be None")
+
     def filter_func(examples):
         if cehrgpt_args.drop_long_sequences:
             return [
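The restructured block above (apparently from hf_cehrgpt_pretrain_runner.py) moves tokenizer training and dataset transformation behind an is_main_process check, synchronizes all ranks with a barrier, and then has every rank reload the prepared artifacts from disk. A generic sketch of that pattern, not the cehrgpt code itself:

import torch.distributed as dist
from datasets import Dataset, load_from_disk
from transformers.trainer_utils import is_main_process


def build_dataset_once(local_rank: int, prepared_path: str) -> Dataset:
    if is_main_process(local_rank):
        # Stand-in for the expensive tokenization/transformation step
        Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5]]}).save_to_disk(prepared_path)
    # Make every rank wait until rank 0 has finished writing
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    # All ranks (including rank 0) read back the same prepared dataset
    return load_from_disk(prepared_path)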
@@ -369,22 +451,64 @@ def main():
     # Set seed before initializing model.
     set_seed(training_args.seed)

-    if not data_args.streaming:
+    if not data_args.streaming and not cehrgpt_args.sample_packing:
         processed_dataset.set_format("pt")

-    trainer = Trainer(
+    callbacks = []
+    if cehrgpt_args.use_early_stopping:
+        callbacks.append(
+            CustomEarlyStoppingCallback(
+                model_args.early_stopping_patience,
+                cehrgpt_args.early_stopping_threshold,
+            )
+        )
+
+    if cehrgpt_args.sample_packing:
+        trainer_class = partial(
+            SamplePackingTrainer,
+            max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
+            max_position_embeddings=model_args.max_position_embeddings,
+            train_lengths=processed_dataset["train"]["num_of_concepts"],
+            validation_lengths=(
+                processed_dataset["validation"]
+                if "validation" in processed_dataset
+                else processed_dataset["test"]
+            )["num_of_concepts"],
+        )
+        training_args.per_device_train_batch_size = 1
+        training_args.per_device_eval_batch_size = 1
+        data_collator_fn = partial(
+            SamplePackingCehrGptDataCollator,
+            cehrgpt_args.max_tokens_per_batch,
+            model_args.max_position_embeddings,
+            add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
+        )
+    else:
+        trainer_class = Trainer
+        data_collator_fn = CehrGptDataCollator
+
+    trainer = trainer_class(
         model=model,
-        data_collator=CehrGptDataCollator(
+        data_collator=data_collator_fn(
             tokenizer=cehrgpt_tokenizer,
-            max_length=model_args.max_position_embeddings,
+            max_length=(
+                cehrgpt_args.max_tokens_per_batch
+                if cehrgpt_args.sample_packing
+                else model_args.max_position_embeddings
+            ),
             shuffle_records=data_args.shuffle_records,
             include_ttv_prediction=model_args.include_ttv_prediction,
             use_sub_time_tokenization=model_args.use_sub_time_tokenization,
             include_values=model_args.include_values,
         ),
         train_dataset=processed_dataset["train"],
-        eval_dataset=processed_dataset["test"],
+        eval_dataset=(
+            processed_dataset["validation"]
+            if "validation" in processed_dataset
+            else processed_dataset["test"]
+        ),
         args=training_args,
+        callbacks=callbacks,
     )

     checkpoint = None
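Sample packing replaces the fixed per-device batch size with a token budget: per_device_*_batch_size is pinned to 1 and the SamplePackingTrainer / SamplePackingCehrGptDataCollator pair (together with the new sample_packing_sampler.py) decide how many sequences fit into max_tokens_per_batch. The greedy grouping below only illustrates the idea; it is an assumption, not the package's actual sampler logic:

from typing import Iterable, List


def pack_by_token_budget(lengths: Iterable[int], max_tokens_per_batch: int) -> List[List[int]]:
    """Greedily assign sequence indices to batches whose total length fits the budget."""
    batches: List[List[int]] = []
    current: List[int] = []
    current_tokens = 0
    for idx, length in enumerate(lengths):
        length = min(length, max_tokens_per_batch)  # a single over-long sequence still gets its own batch
        if current and current_tokens + length > max_tokens_per_batch:
            batches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += length
    if current:
        batches.append(current)
    return batches


print(pack_by_token_budget([100, 500, 420, 60, 900], max_tokens_per_batch=1000))
# [[0, 1], [2, 3], [4]]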
@@ -115,6 +115,9 @@ class CehrGPTArguments:
             "help": "The lower bound of the learning rate range for hyperparameter tuning."
         },
     )
+    next_token_prediction_loss_weight: float = dataclasses.field(
+        default=1.0, metadata={"help": "The weight of the next token prediction loss"}
+    )
     lab_token_penalty: Optional[bool] = dataclasses.field(
         default=False,
         metadata={
@@ -125,6 +128,10 @@ class CehrGPTArguments:
         default=1.0,
         metadata={"help": "lab_token_loss_weight penalty co-efficient"},
     )
+    value_prediction_loss_weight: Optional[float] = dataclasses.field(
+        default=1.0,
+        metadata={"help": "The weight of the value prediction loss"},
+    )
     entropy_penalty: Optional[bool] = dataclasses.field(
         default=False,
         metadata={"help": "A flag to indicate whether we want to use entropy penalty."},
@@ -139,3 +146,38 @@ class CehrGPTArguments:
             "help": "The number of feed forward layers for transforming pretrained embeddings to internal embeddings"
         },
     )
+    meds_repartition: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to repartition the meds train tune sets"
+        },
+    )
+    use_early_stopping: Optional[bool] = dataclasses.field(
+        default=True,
+        metadata={"help": "A flag to indicate whether we want to use early stopping."},
+    )
+    early_stopping_threshold: Optional[float] = dataclasses.field(
+        default=0.01,
+        metadata={
+            "help": "A threshold to denote how much the specified metric must improve to satisfy early stopping conditions."
+        },
+    )
+    sample_packing: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to use sample packing for efficient training."
+        },
+    )
+    max_tokens_per_batch: int = dataclasses.field(
+        default=16384, metadata={"help": "Maximum number of tokens in each batch"}
+    )
+    add_end_token_in_sample_packing: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to add end token in sample packing"
+        },
+    )
+    average_over_sequence: bool = dataclasses.field(
+        default=False,
+        metadata={"help": "Whether or not to average tokens per sequence"},
+    )
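These new CehrGPTArguments fields are plain dataclass fields, so (assuming the runner exposes them through Hugging Face's HfArgumentParser, as cehrbert-style runners typically do) they surface as command-line flags. A self-contained sketch with a trimmed-down, hypothetical stand-in dataclass; the real one lives in cehrgpt/runners/hf_gpt_runner_argument_dataclass.py:

import dataclasses
from typing import Optional

from transformers import HfArgumentParser


@dataclasses.dataclass
class PackingArgs:  # hypothetical stand-in mirroring a subset of CehrGPTArguments
    sample_packing: Optional[bool] = dataclasses.field(
        default=False, metadata={"help": "Use sample packing for efficient training."}
    )
    max_tokens_per_batch: int = dataclasses.field(
        default=16384, metadata={"help": "Maximum number of tokens in each batch"}
    )
    use_early_stopping: Optional[bool] = dataclasses.field(
        default=True, metadata={"help": "Enable early stopping."}
    )
    early_stopping_threshold: Optional[float] = dataclasses.field(
        default=0.01, metadata={"help": "Relative improvement required to reset patience."}
    )


(args,) = HfArgumentParser(PackingArgs).parse_args_into_dataclasses(
    args=["--sample_packing", "true", "--max_tokens_per_batch", "8192"]
)
print(args.sample_packing, args.max_tokens_per_batch)  # True 8192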
@@ -126,6 +126,7 @@ def sample_dataset(data: Dataset, percentage: float, seed: int) -> Dataset:


 def perform_hyperparameter_search(
+    trainer_class,
     model_init: Callable,
     dataset: DatasetDict,
     data_collator: CehrGptDataCollator,
@@ -142,6 +143,7 @@ def perform_hyperparameter_search(
     After the search, it updates the provided `TrainingArguments` with the best hyperparameters found.

     Args:
+        trainer_class: A Trainer or its subclass
         model_init (Callable): A function to initialize the model, used for each hyperparameter trial.
         dataset (DatasetDict): A Hugging Face DatasetDict containing "train" and "validation" datasets.
         data_collator (CehrGptDataCollator): A data collator for processing batches.
@@ -157,6 +159,7 @@ def perform_hyperparameter_search(
     Example:
         ```
         best_training_args = perform_hyperparameter_search(
+            trainer_class=Trainer,
             model_init=my_model_init,
             dataset=my_dataset_dict,
             data_collator=my_data_collator,
@@ -187,7 +190,7 @@ def perform_hyperparameter_search(
             cehrgpt_args.hyperparameter_tuning_percentage,
             training_args.seed,
         )
-        hyperparam_trainer = Trainer(
+        hyperparam_trainer = trainer_class(
             model_init=model_init,
             data_collator=data_collator,
             train_dataset=sampled_train,