cehrgpt 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (30)
  1. cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
  2. cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
  3. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +279 -2
  4. cehrgpt/data/sample_packing_sampler.py +151 -0
  5. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  6. cehrgpt/generation/omop_converter_batch.py +3 -0
  7. cehrgpt/models/config.py +10 -0
  8. cehrgpt/models/hf_cehrgpt.py +244 -73
  9. cehrgpt/models/tokenization_hf_cehrgpt.py +6 -2
  10. cehrgpt/runners/data_utils.py +243 -0
  11. cehrgpt/runners/gpt_runner_util.py +0 -10
  12. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +154 -260
  13. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +250 -90
  14. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -0
  15. cehrgpt/runners/hyperparameter_search_util.py +4 -1
  16. cehrgpt/runners/sample_packing_trainer.py +168 -0
  17. cehrgpt/simulations/__init__.py +0 -0
  18. cehrgpt/simulations/generate_plots.py +95 -0
  19. cehrgpt/simulations/run_simulation.sh +24 -0
  20. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  21. cehrgpt/simulations/time_token_simulation.py +177 -0
  22. cehrgpt/tools/generate_causal_patient_split_by_age.py +146 -0
  23. cehrgpt/tools/linear_prob/__init__.py +0 -0
  24. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
  25. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  26. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +57 -9
  27. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +30 -18
  28. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
  29. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
  30. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
cehrgpt/runners/hf_cehrgpt_finetune_runner.py
@@ -9,20 +9,16 @@ from pathlib import Path
 import numpy as np
 import pandas as pd
 import torch
-from cehrbert.data_generators.hf_data_generator.meds_utils import (
-    create_dataset_from_meds_reader,
-)
+import torch.distributed as dist
+from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
 from cehrbert.runners.hf_cehrbert_finetune_runner import compute_metrics
 from cehrbert.runners.hf_runner_argument_dataclass import (
-    DataTrainingArguments,
     FineTuneModelType,
     ModelArguments,
 )
 from cehrbert.runners.runner_util import (
     generate_prepared_ds_path,
     get_last_hf_checkpoint,
-    get_meds_extension_path,
-    load_parquet_as_dataset,
 )
 from datasets import DatasetDict, concatenate_datasets, load_from_disk
 from peft import LoraConfig, PeftModel, get_peft_model
@@ -38,11 +34,15 @@ from transformers import (
     TrainingArguments,
     set_seed,
 )
-from transformers.tokenization_utils_base import LARGE_INTEGER
+from transformers.trainer_utils import is_main_process
 from transformers.utils import is_flash_attn_2_available, logging

 from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_finetuning_dataset
-from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
+from cehrgpt.data.hf_cehrgpt_dataset_collator import (
+    CehrGptDataCollator,
+    SamplePackingCehrGptDataCollator,
+)
+from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
 from cehrgpt.models.hf_cehrgpt import (
     CEHRGPTConfig,
     CehrGptForClassification,
@@ -50,9 +50,12 @@ from cehrgpt.models.hf_cehrgpt import (
 )
 from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+from cehrgpt.runners.data_utils import prepare_finetune_dataset
 from cehrgpt.runners.gpt_runner_util import parse_runner_args
+from cehrgpt.runners.hf_cehrgpt_pretrain_runner import tokenizer_exists
 from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
 from cehrgpt.runners.hyperparameter_search_util import perform_hyperparameter_search
+from cehrgpt.runners.sample_packing_trainer import SamplePackingTrainer

 LOG = logging.get_logger("transformers")

@@ -155,140 +158,6 @@ def load_finetuned_model(
     raise ValueError(f"Can not load the finetuned model from {model_name_or_path}")


-def create_dataset_splits(data_args: DataTrainingArguments, seed: int):
-    """
-    Creates training, validation, and testing dataset splits based on specified splitting strategies.
-
-    This function splits a dataset into training, validation, and test sets, using either chronological,
-    patient-based, or random splitting strategies, depending on the parameters provided in `data_args`.
-
-    - **Chronological split**: Sorts by a specified date and splits based on historical and future data.
-    - **Patient-based split**: Splits by unique patient IDs to ensure that patients in each split are distinct.
-    - **Random split**: Performs a straightforward random split of the dataset.
-
-    If `data_args.test_data_folder` is provided, a test set is loaded directly from it. Otherwise,
-    the test set is created by further splitting the validation set based on `test_eval_ratio`.
-
-    Parameters:
-        data_args (DataTrainingArguments): A configuration object containing data-related arguments, including:
-            - `data_folder` (str): Path to the main dataset.
-            - `test_data_folder` (str, optional): Path to an optional test dataset.
-            - `chronological_split` (bool): Whether to split chronologically.
-            - `split_by_patient` (bool): Whether to split by unique patient IDs.
-            - `validation_split_percentage` (float): Percentage of data to use for validation.
-            - `test_eval_ratio` (float): Ratio of test to validation data when creating a test set from validation.
-            - `preprocessing_num_workers` (int): Number of processes for parallel data filtering.
-            - `preprocessing_batch_size` (int): Batch size for batched operations.
-        seed (int): Random seed for reproducibility of splits.
-
-    Returns:
-        Tuple[Dataset, Dataset, Dataset]: A tuple containing:
-            - `train_set` (Dataset): Training split of the dataset.
-            - `validation_set` (Dataset): Validation split of the dataset.
-            - `test_set` (Dataset): Test split of the dataset.
-
-    Raises:
-        FileNotFoundError: If `data_args.data_folder` or `data_args.test_data_folder` does not exist.
-        ValueError: If incompatible arguments are passed for splitting strategies.
-
-    Example Usage:
-        data_args = DataTrainingArguments(
-            data_folder="data/",
-            validation_split_percentage=0.1,
-            test_eval_ratio=0.2,
-            chronological_split=True
-        )
-        train_set, validation_set, test_set = create_dataset_splits(data_args, seed=42)
-    """
-    dataset = load_parquet_as_dataset(data_args.data_folder)
-    test_set = (
-        None
-        if not data_args.test_data_folder
-        else load_parquet_as_dataset(data_args.test_data_folder)
-    )
-
-    if data_args.chronological_split:
-        # Chronological split by sorting on `index_date`
-        dataset = dataset.sort("index_date")
-        total_size = len(dataset)
-        train_end = int((1 - data_args.validation_split_percentage) * total_size)
-
-        # Perform the split
-        train_set = dataset.select(range(0, train_end))
-        validation_set = dataset.select(range(train_end, total_size))
-
-        if test_set is None:
-            test_valid_split = validation_set.train_test_split(
-                test_size=data_args.test_eval_ratio, seed=seed
-            )
-            validation_set, test_set = (
-                test_valid_split["train"],
-                test_valid_split["test"],
-            )
-
-    elif data_args.split_by_patient:
-        # Patient-based split
-        LOG.info("Using the split_by_patient strategy")
-        unique_patient_ids = dataset.unique("person_id")
-        LOG.info(f"There are {len(unique_patient_ids)} patients in total")
-
-        np.random.seed(seed)
-        np.random.shuffle(unique_patient_ids)
-
-        train_end = int(
-            len(unique_patient_ids) * (1 - data_args.validation_split_percentage)
-        )
-        train_patient_ids = set(unique_patient_ids[:train_end])
-
-        if test_set is None:
-            validation_end = int(
-                train_end
-                + len(unique_patient_ids)
-                * data_args.validation_split_percentage
-                * data_args.test_eval_ratio
-            )
-            val_patient_ids = set(unique_patient_ids[train_end:validation_end])
-            test_patient_ids = set(unique_patient_ids[validation_end:])
-        else:
-            val_patient_ids, test_patient_ids = (
-                set(unique_patient_ids[train_end:]),
-                None,
-            )
-
-        # Helper function to apply patient-based filtering
-        def filter_by_patient_ids(patient_ids):
-            return dataset.filter(
-                lambda batch: [pid in patient_ids for pid in batch["person_id"]],
-                num_proc=data_args.preprocessing_num_workers,
-                batched=True,
-                batch_size=data_args.preprocessing_batch_size,
-            )
-
-        # Generate splits
-        train_set = filter_by_patient_ids(train_patient_ids)
-        validation_set = filter_by_patient_ids(val_patient_ids)
-        if test_set is None:
-            test_set = filter_by_patient_ids(test_patient_ids)
-
-    else:
-        # Random split
-        train_val = dataset.train_test_split(
-            test_size=data_args.validation_split_percentage, seed=seed
-        )
-        train_set, validation_set = train_val["train"], train_val["test"]
-
-        if test_set is None:
-            test_valid_split = validation_set.train_test_split(
-                test_size=data_args.test_eval_ratio, seed=seed
-            )
-            validation_set, test_set = (
-                test_valid_split["train"],
-                test_valid_split["test"],
-            )
-
-    return train_set, validation_set, test_set
-
-
 def model_init(
     model_args: ModelArguments,
     training_args: TrainingArguments,
@@ -363,16 +232,16 @@ def main():
     prepared_ds_path = generate_prepared_ds_path(
         data_args, model_args, data_folder=data_args.cohort_folder
     )
-
+    cache_file_collector = CacheFileCollector()
     processed_dataset = None
     if any(prepared_ds_path.glob("*")):
         LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
         processed_dataset = load_from_disk(str(prepared_ds_path))
         LOG.info("Prepared dataset loaded from disk...")
         if cehrgpt_args.expand_tokenizer:
-            try:
+            if tokenizer_exists(training_args.output_dir):
                 tokenizer = CehrGptTokenizer.from_pretrained(training_args.output_dir)
-            except Exception:
+            else:
                 LOG.warning(
                     f"CehrGptTokenizer must exist in {training_args.output_dir} "
                     f"when the dataset has been processed and expand_tokenizer is set to True. "
@@ -382,81 +251,77 @@ def main():
             shutil.rmtree(prepared_ds_path)

     if processed_dataset is None:
-        # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
-        if data_args.is_data_in_meds:
-            meds_extension_path = get_meds_extension_path(
-                data_folder=data_args.cohort_folder,
-                dataset_prepared_path=data_args.dataset_prepared_path,
+        if is_main_process(training_args.local_rank):
+            final_splits = prepare_finetune_dataset(
+                data_args, training_args, cehrgpt_args, cache_file_collector
             )
-            try:
-                LOG.info(
-                    f"Trying to load the MEDS extension from disk at {meds_extension_path}..."
-                )
-                dataset = load_from_disk(meds_extension_path)
-                if data_args.streaming:
-                    if isinstance(dataset, DatasetDict):
-                        dataset = {
-                            k: v.to_iterable_dataset(
-                                num_shards=training_args.dataloader_num_workers
-                            )
-                            for k, v in dataset.items()
-                        }
-                    else:
-                        dataset = dataset.to_iterable_dataset(
-                            num_shards=training_args.dataloader_num_workers
+            if cehrgpt_args.expand_tokenizer:
+                new_tokenizer_path = os.path.expanduser(training_args.output_dir)
+                if tokenizer_exists(new_tokenizer_path):
+                    tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
+                else:
+                    # Try to use the defined pretrained embeddings if exists, Otherwise we default to the pretrained model
+                    # embedded in the pretrained model
+                    pretrained_concept_embedding_model = PretrainedEmbeddings(
+                        cehrgpt_args.pretrained_embedding_path
+                    )
+                    if not pretrained_concept_embedding_model.exists:
+                        pretrained_concept_embedding_model = (
+                            tokenizer.pretrained_concept_embedding_model
                         )
-            except Exception as e:
-                LOG.exception(e)
-                dataset = create_dataset_from_meds_reader(
-                    data_args, is_pretraining=False
-                )
-                if not data_args.streaming:
-                    dataset.save_to_disk(meds_extension_path)
-            train_set = dataset["train"]
-            validation_set = dataset["validation"]
-            test_set = dataset["test"]
-        else:
-            train_set, validation_set, test_set = create_dataset_splits(
-                data_args=data_args, seed=training_args.seed
-            )
-        # Organize them into a single DatasetDict
-        final_splits = DatasetDict(
-            {"train": train_set, "validation": validation_set, "test": test_set}
-        )
-
-        if cehrgpt_args.expand_tokenizer:
-            new_tokenizer_path = os.path.expanduser(training_args.output_dir)
-            try:
-                tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
-            except Exception:
-                # Try to use the defined pretrained embeddings if exists,
-                # Otherwise we default to the pretrained model embedded in the pretrained model
-                pretrained_concept_embedding_model = PretrainedEmbeddings(
-                    cehrgpt_args.pretrained_embedding_path
-                )
-                if not pretrained_concept_embedding_model.exists:
-                    pretrained_concept_embedding_model = (
-                        tokenizer.pretrained_concept_embedding_model
+                    tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
+                        cehrgpt_tokenizer=tokenizer,
+                        dataset=final_splits["train"],
+                        data_args=data_args,
+                        concept_name_mapping={},
+                        pretrained_concept_embedding_model=pretrained_concept_embedding_model,
                     )
-                tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
-                    cehrgpt_tokenizer=tokenizer,
-                    dataset=final_splits["train"],
-                    data_args=data_args,
-                    concept_name_mapping={},
-                    pretrained_concept_embedding_model=pretrained_concept_embedding_model,
+                    tokenizer.save_pretrained(
+                        os.path.expanduser(training_args.output_dir)
+                    )
+
+            # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
+            if not data_args.streaming:
+                all_columns = final_splits["train"].column_names
+                if "visit_concept_ids" in all_columns:
+                    final_splits = final_splits.remove_columns(["visit_concept_ids"])
+
+            processed_dataset = create_cehrgpt_finetuning_dataset(
+                dataset=final_splits,
+                cehrgpt_tokenizer=tokenizer,
+                data_args=data_args,
+                cache_file_collector=cache_file_collector,
+            )
+            if not data_args.streaming:
+                processed_dataset.save_to_disk(str(prepared_ds_path))
+                stats = processed_dataset.cleanup_cache_files()
+                LOG.info(
+                    "Clean up the cached files for the cehrgpt finetuning dataset : %s",
+                    stats,
                 )
-            tokenizer.save_pretrained(os.path.expanduser(training_args.output_dir))

-        processed_dataset = create_cehrgpt_finetuning_dataset(
-            dataset=final_splits, cehrgpt_tokenizer=tokenizer, data_args=data_args
+            # Remove any cached files if there are any
+            cache_file_collector.remove_cache_files()
+
+        # After main-process-only operations, synchronize all processes to ensure consistency
+        if dist.is_available() and dist.is_initialized():
+            dist.barrier()
+
+        # Loading tokenizer in all processes in torch distributed training
+        tokenizer_name_or_path = os.path.expanduser(
+            training_args.output_dir
+            if cehrgpt_args.expand_tokenizer
+            else model_args.tokenizer_name_or_path
         )
-        if not data_args.streaming:
-            processed_dataset.save_to_disk(prepared_ds_path)
+        tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
+        # Load the dataset from disk again to in torch distributed training
+        processed_dataset = load_from_disk(str(prepared_ds_path))

     # Set seed before initializing model.
     set_seed(training_args.seed)

-    processed_dataset.set_format("pt")
+    if not data_args.streaming and not cehrgpt_args.sample_packing:
+        processed_dataset.set_format("pt")

     if cehrgpt_args.few_shot_predict:
         # At least we need two examples to have a validation set for early stopping
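Note: the hunk above moves dataset preparation onto the main process and synchronizes the other ranks with a barrier before every rank reloads the prepared dataset from disk. A minimal sketch of that pattern (assuming a torch.distributed process group may or may not be initialized; prepare_splits and the path argument are hypothetical stand-ins for the runner's preparation step and cache location):

import torch.distributed as dist
from datasets import load_from_disk
from transformers.trainer_utils import is_main_process

def build_prepared_dataset(local_rank: int, prepared_ds_path: str):
    if is_main_process(local_rank):
        # Rank 0 does the expensive tokenization/caching exactly once.
        dataset = prepare_splits()              # hypothetical preprocessing step
        dataset.save_to_disk(prepared_ds_path)
    # Every rank (including rank 0) waits until the prepared dataset is on disk.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    # All ranks then read the same on-disk artifact, so they see identical data.
    return load_from_disk(prepared_ds_path)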
@@ -476,13 +341,40 @@ def main():
     config = CEHRGPTConfig.from_pretrained(model_args.model_name_or_path)
     if config.max_position_embeddings < model_args.max_position_embeddings:
         config.max_position_embeddings = model_args.max_position_embeddings
+
+    # persist this parameter in case this is overwritten by sample packing
+    per_device_eval_batch_size = training_args.per_device_eval_batch_size
+
+    if cehrgpt_args.sample_packing:
+        trainer_class = partial(
+            SamplePackingTrainer,
+            max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
+            max_position_embeddings=config.max_position_embeddings,
+            train_lengths=processed_dataset["train"]["num_of_concepts"],
+            validation_lengths=processed_dataset["validation"]["num_of_concepts"],
+        )
+        training_args.per_device_train_batch_size = 1
+        training_args.per_device_eval_batch_size = 1
+        data_collator_fn = partial(
+            SamplePackingCehrGptDataCollator,
+            cehrgpt_args.max_tokens_per_batch,
+            config.max_position_embeddings,
+        )
+    else:
+        trainer_class = Trainer
+        data_collator_fn = CehrGptDataCollator
+
     # We suppress the additional learning objectives in fine-tuning
-    data_collator = CehrGptDataCollator(
+    data_collator = data_collator_fn(
         tokenizer=tokenizer,
         max_length=(
-            config.max_position_embeddings - 1
-            if config.causal_sfm
-            else config.max_position_embeddings
+            cehrgpt_args.max_tokens_per_batch
+            if cehrgpt_args.sample_packing
+            else (
+                config.max_position_embeddings - 1
+                if config.causal_sfm
+                else config.max_position_embeddings
+            )
         ),
         include_values=model_args.include_values,
         pretraining=False,
@@ -493,8 +385,8 @@ def main():

     if training_args.do_train:
         if cehrgpt_args.hyperparameter_tuning:
-            model_args.early_stopping_patience = LARGE_INTEGER
             training_args = perform_hyperparameter_search(
+                trainer_class,
                 partial(model_init, model_args, training_args, tokenizer),
                 processed_dataset,
                 data_collator,
@@ -502,13 +394,20 @@ def main():
                 model_args,
                 cehrgpt_args,
             )
+
+        if cehrgpt_args.retrain_with_full:
             # Always retrain with the full set when hyperparameter tuning is set to true
             retrain_with_full_set(
-                model_args, training_args, tokenizer, processed_dataset, data_collator
+                trainer_class,
+                model_args,
+                training_args,
+                tokenizer,
+                processed_dataset,
+                data_collator,
             )
     else:
         # Initialize Trainer for final training on the combined train+val set
-        trainer = Trainer(
+        trainer = trainer_class(
            model=model_init(model_args, training_args, tokenizer),
            data_collator=data_collator,
            args=training_args,
@@ -531,45 +430,31 @@ def main():
         trainer.save_metrics("train", metrics)
         trainer.save_state()

-    # Retrain the model with full set using the num of epoches before earlying stopping
-    if cehrgpt_args.retrain_with_full:
-        update_num_epoch_before_early_stopping_callback = None
-        for callback in trainer.callback_handler.callbacks:
-            if isinstance(callback, UpdateNumEpochsBeforeEarlyStoppingCallback):
-                update_num_epoch_before_early_stopping_callback = callback
-
-        if update_num_epoch_before_early_stopping_callback is None:
-            raise RuntimeError(
-                f"{UpdateNumEpochsBeforeEarlyStoppingCallback} must be included as a callback!"
-            )
-        final_num_epochs = (
-            update_num_epoch_before_early_stopping_callback.num_epochs_before_early_stopping
-        )
-        training_args.num_train_epochs = final_num_epochs
-        LOG.info(
-            "Num Epochs before early stopping: %s",
-            training_args.num_train_epochs,
-        )
-        retrain_with_full_set(
-            model_args,
-            training_args,
-            tokenizer,
-            processed_dataset,
-            data_collator,
-        )
-
     if training_args.do_predict:
+        if cehrgpt_args.sample_packing:
+            batch_sampler = SamplePackingBatchSampler(
+                lengths=processed_dataset["test"]["num_of_concepts"],
+                max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
+                max_position_embeddings=config.max_position_embeddings,
+                drop_last=training_args.dataloader_drop_last,
+                seed=training_args.seed,
+            )
+            per_device_eval_batch_size = 1
+        else:
+            batch_sampler = None
         test_dataloader = DataLoader(
             dataset=processed_dataset["test"],
-            batch_size=training_args.per_device_eval_batch_size,
+            batch_size=per_device_eval_batch_size,
             num_workers=training_args.dataloader_num_workers,
             collate_fn=data_collator,
             pin_memory=training_args.dataloader_pin_memory,
+            batch_sampler=batch_sampler,
         )
         do_predict(test_dataloader, model_args, training_args, cehrgpt_args)


 def retrain_with_full_set(
+    trainer_class,
     model_args: ModelArguments,
     training_args: TrainingArguments,
     tokenizer: CehrGptTokenizer,
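Note: for prediction, SamplePackingBatchSampler is supplied as the DataLoader's batch_sampler so batches are bounded by a token budget (max_tokens_per_batch) rather than a fixed example count; the shipped sampler in cehrgpt/data/sample_packing_sampler.py also receives max_position_embeddings and a seed, which the simplified greedy sketch below omits. This sketch is illustrative only, assuming lengths holds the per-example token counts (num_of_concepts):

from typing import Iterator, List, Sequence

from torch.utils.data import Sampler


class GreedyTokenBudgetBatchSampler(Sampler[List[int]]):
    """Illustrative only: yield index lists whose summed lengths stay under a token budget."""

    def __init__(self, lengths: Sequence[int], max_tokens_per_batch: int, drop_last: bool = False):
        self.lengths = list(lengths)
        self.max_tokens_per_batch = max_tokens_per_batch
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        batch, used = [], 0
        for idx, length in enumerate(self.lengths):
            if batch and used + length > self.max_tokens_per_batch:
                yield batch          # budget reached: emit the packed batch
                batch, used = [], 0
            batch.append(idx)
            used += length
        if batch and not self.drop_last:
            yield batch              # trailing partial batch

    def __len__(self) -> int:
        # Count batches by walking the same greedy packing once.
        return sum(1 for _ in iter(self))

Usage-wise, DataLoader only accepts a batch_sampler when batch_size stays at its default of 1, which is why the runner above forces per_device_eval_batch_size = 1 when sample packing is enabled.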
@@ -586,6 +471,7 @@ def retrain_with_full_set(
     and state information.

     Args:
+        trainer_class: Trainer or its subclass
         model_args (ModelArguments): Model configuration and hyperparameters.
         training_args (TrainingArguments): Training configuration, including output directory,
             evaluation strategy, and other training parameters.
@@ -607,7 +493,7 @@ def retrain_with_full_set(
     # Disable evaluation
     training_args.evaluation_strategy = "no"
     checkpoint = get_last_hf_checkpoint(training_args)
-    final_trainer = Trainer(
+    final_trainer = trainer_class(
         model=model_init(model_args, training_args, tokenizer),
         data_collator=data_collator,
         args=training_args,
@@ -662,15 +548,15 @@ def do_predict(
     test_losses = []
     with torch.no_grad():
         for index, batch in enumerate(tqdm(test_dataloader, desc="Predicting")):
-            person_ids = batch.pop("person_id").numpy().squeeze().astype(int)
-            index_dates = (
-                map(
-                    datetime.fromtimestamp,
-                    batch.pop("index_date").numpy().squeeze(axis=-1).tolist(),
-                )
-                if "index_date" in batch
-                else None
-            )
+            person_ids = batch.pop("person_id").numpy().astype(int).squeeze()
+            if person_ids.ndim == 0:
+                person_ids = np.asarray([person_ids])
+
+            index_dates = batch.pop("index_date").numpy().squeeze()
+            if index_dates.ndim == 0:
+                index_dates = np.asarray([index_dates])
+            index_dates = list(map(datetime.fromtimestamp, index_dates.tolist()))
+
             batch = {k: v.to(device) for k, v in batch.items()}
             # Forward pass
             output = model(**batch, output_attentions=False, output_hidden_states=False)
@@ -678,17 +564,25 @@ def do_predict(

             # Collect logits and labels for prediction
             logits = output.logits.float().cpu().numpy().squeeze()
+            if logits.ndim == 0:
+                logits = np.asarray([logits])
+            probabilities = sigmoid(logits)
+
             labels = (
-                batch["classifier_label"].float().cpu().numpy().squeeze().astype(bool)
+                batch["classifier_label"].float().cpu().numpy().astype(bool).squeeze()
             )
-            probabilities = sigmoid(logits)
+            if labels.ndim == 0:
+                labels = np.asarray([labels])
+
             # Save predictions to parquet file
             test_prediction_pd = pd.DataFrame(
                 {
                     "subject_id": person_ids,
                     "prediction_time": index_dates,
-                    "boolean_prediction_probability": probabilities,
-                    "boolean_prediction": logits,
+                    "predicted_boolean_probability": probabilities,
+                    "predicted_boolean_value": pd.Series(
+                        [None] * len(person_ids), dtype=bool
+                    ),
                     "boolean_value": labels,
                 }
             )
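Note: the ndim guards added in do_predict handle a batch that contains a single example, where numpy's squeeze() collapses a length-1 axis all the way down to a 0-d scalar that can no longer be iterated or used as a DataFrame column. A small, self-contained illustration of the failure mode and the guard:

import numpy as np

person_ids = np.array([[42]]).squeeze()    # a batch of one collapses to a 0-d array
print(person_ids.ndim, person_ids.shape)   # -> 0 ()
# list(person_ids) would raise TypeError: iteration over a 0-d array

if person_ids.ndim == 0:                   # the same guard the runner now applies
    person_ids = np.asarray([person_ids])
print(person_ids.shape)                    # -> (1,)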
@@ -702,7 +596,7 @@ def do_predict(
     # Compute metrics and save results
     metrics = compute_metrics(
         references=test_prediction_pd.boolean_value,
-        probs=test_prediction_pd.boolean_prediction_probability,
+        probs=test_prediction_pd.predicted_boolean_probability,
     )
     metrics["test_loss"] = np.mean(test_losses)