cehrgpt 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
  2. cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
  3. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +99 -88
  4. cehrgpt/data/sample_packing_sampler.py +151 -0
  5. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  6. cehrgpt/models/config.py +10 -0
  7. cehrgpt/models/hf_cehrgpt.py +243 -73
  8. cehrgpt/models/tokenization_hf_cehrgpt.py +4 -0
  9. cehrgpt/runners/data_utils.py +243 -0
  10. cehrgpt/runners/gpt_runner_util.py +0 -10
  11. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +152 -279
  12. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +229 -105
  13. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +42 -0
  14. cehrgpt/runners/hyperparameter_search_util.py +4 -1
  15. cehrgpt/runners/sample_packing_trainer.py +168 -0
  16. cehrgpt/simulations/generate_plots.py +95 -0
  17. cehrgpt/simulations/run_simulation.sh +24 -0
  18. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  19. cehrgpt/simulations/time_token_simulation.py +177 -0
  20. cehrgpt/tools/linear_prob/__init__.py +0 -0
  21. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
  22. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  23. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +7 -5
  24. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +28 -26
  25. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
  26. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  27. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  28. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  29. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  30. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  31. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  32. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  33. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  34. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  35. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
  36. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
@@ -9,20 +9,16 @@ from pathlib import Path
9
9
  import numpy as np
10
10
  import pandas as pd
11
11
  import torch
12
- from cehrbert.data_generators.hf_data_generator.meds_utils import (
13
- create_dataset_from_meds_reader,
14
- )
12
+ import torch.distributed as dist
13
+ from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
15
14
  from cehrbert.runners.hf_cehrbert_finetune_runner import compute_metrics
16
15
  from cehrbert.runners.hf_runner_argument_dataclass import (
17
- DataTrainingArguments,
18
16
  FineTuneModelType,
19
17
  ModelArguments,
20
18
  )
21
19
  from cehrbert.runners.runner_util import (
22
20
  generate_prepared_ds_path,
23
21
  get_last_hf_checkpoint,
24
- get_meds_extension_path,
25
- load_parquet_as_dataset,
26
22
  )
27
23
  from datasets import DatasetDict, concatenate_datasets, load_from_disk
28
24
  from peft import LoraConfig, PeftModel, get_peft_model
@@ -38,12 +34,15 @@ from transformers import (
38
34
  TrainingArguments,
39
35
  set_seed,
40
36
  )
41
- from transformers.tokenization_utils_base import LARGE_INTEGER
37
+ from transformers.trainer_utils import is_main_process
42
38
  from transformers.utils import is_flash_attn_2_available, logging
43
39
 
44
40
  from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_finetuning_dataset
45
- from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
46
- from cehrgpt.data.hf_cehrgpt_dataset_mapping import MedToCehrGPTDatasetMapping
41
+ from cehrgpt.data.hf_cehrgpt_dataset_collator import (
42
+ CehrGptDataCollator,
43
+ SamplePackingCehrGptDataCollator,
44
+ )
45
+ from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
47
46
  from cehrgpt.models.hf_cehrgpt import (
48
47
  CEHRGPTConfig,
49
48
  CehrGptForClassification,
@@ -51,9 +50,12 @@ from cehrgpt.models.hf_cehrgpt import (
51
50
  )
52
51
  from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
53
52
  from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
53
+ from cehrgpt.runners.data_utils import prepare_finetune_dataset
54
54
  from cehrgpt.runners.gpt_runner_util import parse_runner_args
55
+ from cehrgpt.runners.hf_cehrgpt_pretrain_runner import tokenizer_exists
55
56
  from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
56
57
  from cehrgpt.runners.hyperparameter_search_util import perform_hyperparameter_search
58
+ from cehrgpt.runners.sample_packing_trainer import SamplePackingTrainer
57
59
 
58
60
  LOG = logging.get_logger("transformers")
59
61
 
@@ -156,140 +158,6 @@ def load_finetuned_model(
156
158
  raise ValueError(f"Can not load the finetuned model from {model_name_or_path}")
157
159
 
158
160
 
159
- def create_dataset_splits(data_args: DataTrainingArguments, seed: int):
160
- """
161
- Creates training, validation, and testing dataset splits based on specified splitting strategies.
162
-
163
- This function splits a dataset into training, validation, and test sets, using either chronological,
164
- patient-based, or random splitting strategies, depending on the parameters provided in `data_args`.
165
-
166
- - **Chronological split**: Sorts by a specified date and splits based on historical and future data.
167
- - **Patient-based split**: Splits by unique patient IDs to ensure that patients in each split are distinct.
168
- - **Random split**: Performs a straightforward random split of the dataset.
169
-
170
- If `data_args.test_data_folder` is provided, a test set is loaded directly from it. Otherwise,
171
- the test set is created by further splitting the validation set based on `test_eval_ratio`.
172
-
173
- Parameters:
174
- data_args (DataTrainingArguments): A configuration object containing data-related arguments, including:
175
- - `data_folder` (str): Path to the main dataset.
176
- - `test_data_folder` (str, optional): Path to an optional test dataset.
177
- - `chronological_split` (bool): Whether to split chronologically.
178
- - `split_by_patient` (bool): Whether to split by unique patient IDs.
179
- - `validation_split_percentage` (float): Percentage of data to use for validation.
180
- - `test_eval_ratio` (float): Ratio of test to validation data when creating a test set from validation.
181
- - `preprocessing_num_workers` (int): Number of processes for parallel data filtering.
182
- - `preprocessing_batch_size` (int): Batch size for batched operations.
183
- seed (int): Random seed for reproducibility of splits.
184
-
185
- Returns:
186
- Tuple[Dataset, Dataset, Dataset]: A tuple containing:
187
- - `train_set` (Dataset): Training split of the dataset.
188
- - `validation_set` (Dataset): Validation split of the dataset.
189
- - `test_set` (Dataset): Test split of the dataset.
190
-
191
- Raises:
192
- FileNotFoundError: If `data_args.data_folder` or `data_args.test_data_folder` does not exist.
193
- ValueError: If incompatible arguments are passed for splitting strategies.
194
-
195
- Example Usage:
196
- data_args = DataTrainingArguments(
197
- data_folder="data/",
198
- validation_split_percentage=0.1,
199
- test_eval_ratio=0.2,
200
- chronological_split=True
201
- )
202
- train_set, validation_set, test_set = create_dataset_splits(data_args, seed=42)
203
- """
204
- dataset = load_parquet_as_dataset(data_args.data_folder)
205
- test_set = (
206
- None
207
- if not data_args.test_data_folder
208
- else load_parquet_as_dataset(data_args.test_data_folder)
209
- )
210
-
211
- if data_args.chronological_split:
212
- # Chronological split by sorting on `index_date`
213
- dataset = dataset.sort("index_date")
214
- total_size = len(dataset)
215
- train_end = int((1 - data_args.validation_split_percentage) * total_size)
216
-
217
- # Perform the split
218
- train_set = dataset.select(range(0, train_end))
219
- validation_set = dataset.select(range(train_end, total_size))
220
-
221
- if test_set is None:
222
- test_valid_split = validation_set.train_test_split(
223
- test_size=data_args.test_eval_ratio, seed=seed
224
- )
225
- validation_set, test_set = (
226
- test_valid_split["train"],
227
- test_valid_split["test"],
228
- )
229
-
230
- elif data_args.split_by_patient:
231
- # Patient-based split
232
- LOG.info("Using the split_by_patient strategy")
233
- unique_patient_ids = dataset.unique("person_id")
234
- LOG.info(f"There are {len(unique_patient_ids)} patients in total")
235
-
236
- np.random.seed(seed)
237
- np.random.shuffle(unique_patient_ids)
238
-
239
- train_end = int(
240
- len(unique_patient_ids) * (1 - data_args.validation_split_percentage)
241
- )
242
- train_patient_ids = set(unique_patient_ids[:train_end])
243
-
244
- if test_set is None:
245
- validation_end = int(
246
- train_end
247
- + len(unique_patient_ids)
248
- * data_args.validation_split_percentage
249
- * data_args.test_eval_ratio
250
- )
251
- val_patient_ids = set(unique_patient_ids[train_end:validation_end])
252
- test_patient_ids = set(unique_patient_ids[validation_end:])
253
- else:
254
- val_patient_ids, test_patient_ids = (
255
- set(unique_patient_ids[train_end:]),
256
- None,
257
- )
258
-
259
- # Helper function to apply patient-based filtering
260
- def filter_by_patient_ids(patient_ids):
261
- return dataset.filter(
262
- lambda batch: [pid in patient_ids for pid in batch["person_id"]],
263
- num_proc=data_args.preprocessing_num_workers,
264
- batched=True,
265
- batch_size=data_args.preprocessing_batch_size,
266
- )
267
-
268
- # Generate splits
269
- train_set = filter_by_patient_ids(train_patient_ids)
270
- validation_set = filter_by_patient_ids(val_patient_ids)
271
- if test_set is None:
272
- test_set = filter_by_patient_ids(test_patient_ids)
273
-
274
- else:
275
- # Random split
276
- train_val = dataset.train_test_split(
277
- test_size=data_args.validation_split_percentage, seed=seed
278
- )
279
- train_set, validation_set = train_val["train"], train_val["test"]
280
-
281
- if test_set is None:
282
- test_valid_split = validation_set.train_test_split(
283
- test_size=data_args.test_eval_ratio, seed=seed
284
- )
285
- validation_set, test_set = (
286
- test_valid_split["train"],
287
- test_valid_split["test"],
288
- )
289
-
290
- return train_set, validation_set, test_set
291
-
292
-
293
161
  def model_init(
294
162
  model_args: ModelArguments,
295
163
  training_args: TrainingArguments,
@@ -364,16 +232,16 @@ def main():
364
232
  prepared_ds_path = generate_prepared_ds_path(
365
233
  data_args, model_args, data_folder=data_args.cohort_folder
366
234
  )
367
-
235
+ cache_file_collector = CacheFileCollector()
368
236
  processed_dataset = None
369
237
  if any(prepared_ds_path.glob("*")):
370
238
  LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
371
239
  processed_dataset = load_from_disk(str(prepared_ds_path))
372
240
  LOG.info("Prepared dataset loaded from disk...")
373
241
  if cehrgpt_args.expand_tokenizer:
374
- try:
242
+ if tokenizer_exists(training_args.output_dir):
375
243
  tokenizer = CehrGptTokenizer.from_pretrained(training_args.output_dir)
376
- except Exception:
244
+ else:
377
245
  LOG.warning(
378
246
  f"CehrGptTokenizer must exist in {training_args.output_dir} "
379
247
  f"when the dataset has been processed and expand_tokenizer is set to True. "
@@ -383,101 +251,77 @@ def main():
383
251
  shutil.rmtree(prepared_ds_path)
384
252
 
385
253
  if processed_dataset is None:
386
- # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
387
- if data_args.is_data_in_meds:
388
- meds_extension_path = get_meds_extension_path(
389
- data_folder=data_args.cohort_folder,
390
- dataset_prepared_path=data_args.dataset_prepared_path,
254
+ if is_main_process(training_args.local_rank):
255
+ final_splits = prepare_finetune_dataset(
256
+ data_args, training_args, cehrgpt_args, cache_file_collector
391
257
  )
392
- try:
393
- LOG.info(
394
- f"Trying to load the MEDS extension from disk at {meds_extension_path}..."
395
- )
396
- dataset = load_from_disk(meds_extension_path)
397
- if data_args.streaming:
398
- if isinstance(dataset, DatasetDict):
399
- dataset = {
400
- k: v.to_iterable_dataset(
401
- num_shards=training_args.dataloader_num_workers
402
- )
403
- for k, v in dataset.items()
404
- }
405
- else:
406
- dataset = dataset.to_iterable_dataset(
407
- num_shards=training_args.dataloader_num_workers
408
- )
409
- except Exception as e:
410
- LOG.exception(e)
411
- dataset = create_dataset_from_meds_reader(
412
- data_args=data_args,
413
- dataset_mappings=[
414
- MedToCehrGPTDatasetMapping(
415
- data_args=data_args,
416
- is_pretraining=False,
417
- include_inpatient_hour_token=cehrgpt_args.include_inpatient_hour_token,
258
+ if cehrgpt_args.expand_tokenizer:
259
+ new_tokenizer_path = os.path.expanduser(training_args.output_dir)
260
+ if tokenizer_exists(new_tokenizer_path):
261
+ tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
262
+ else:
263
+ # Try to use the defined pretrained embeddings if exists, Otherwise we default to the pretrained model
264
+ # embedded in the pretrained model
265
+ pretrained_concept_embedding_model = PretrainedEmbeddings(
266
+ cehrgpt_args.pretrained_embedding_path
267
+ )
268
+ if not pretrained_concept_embedding_model.exists:
269
+ pretrained_concept_embedding_model = (
270
+ tokenizer.pretrained_concept_embedding_model
418
271
  )
419
- ],
420
- )
421
- if not data_args.streaming:
422
- dataset.save_to_disk(str(meds_extension_path))
423
- stats = dataset.cleanup_cache_files()
424
- LOG.info(
425
- "Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
426
- stats,
272
+ tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
273
+ cehrgpt_tokenizer=tokenizer,
274
+ dataset=final_splits["train"],
275
+ data_args=data_args,
276
+ concept_name_mapping={},
277
+ pretrained_concept_embedding_model=pretrained_concept_embedding_model,
278
+ )
279
+ tokenizer.save_pretrained(
280
+ os.path.expanduser(training_args.output_dir)
427
281
  )
428
- dataset = load_from_disk(str(meds_extension_path))
429
282
 
430
- train_set = dataset["train"]
431
- validation_set = dataset["validation"]
432
- test_set = dataset["test"]
433
- else:
434
- train_set, validation_set, test_set = create_dataset_splits(
435
- data_args=data_args, seed=training_args.seed
283
+ # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
284
+ if not data_args.streaming:
285
+ all_columns = final_splits["train"].column_names
286
+ if "visit_concept_ids" in all_columns:
287
+ final_splits = final_splits.remove_columns(["visit_concept_ids"])
288
+
289
+ processed_dataset = create_cehrgpt_finetuning_dataset(
290
+ dataset=final_splits,
291
+ cehrgpt_tokenizer=tokenizer,
292
+ data_args=data_args,
293
+ cache_file_collector=cache_file_collector,
436
294
  )
437
- # Organize them into a single DatasetDict
438
- final_splits = DatasetDict(
439
- {"train": train_set, "validation": validation_set, "test": test_set}
440
- )
441
-
442
- if cehrgpt_args.expand_tokenizer:
443
- new_tokenizer_path = os.path.expanduser(training_args.output_dir)
444
- try:
445
- tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
446
- except Exception:
447
- # Try to use the defined pretrained embeddings if exists,
448
- # Otherwise we default to the pretrained model embedded in the pretrained model
449
- pretrained_concept_embedding_model = PretrainedEmbeddings(
450
- cehrgpt_args.pretrained_embedding_path
451
- )
452
- if not pretrained_concept_embedding_model.exists:
453
- pretrained_concept_embedding_model = (
454
- tokenizer.pretrained_concept_embedding_model
455
- )
456
- tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
457
- cehrgpt_tokenizer=tokenizer,
458
- dataset=final_splits["train"],
459
- data_args=data_args,
460
- concept_name_mapping={},
461
- pretrained_concept_embedding_model=pretrained_concept_embedding_model,
295
+ if not data_args.streaming:
296
+ processed_dataset.save_to_disk(str(prepared_ds_path))
297
+ stats = processed_dataset.cleanup_cache_files()
298
+ LOG.info(
299
+ "Clean up the cached files for the cehrgpt finetuning dataset : %s",
300
+ stats,
462
301
  )
463
- tokenizer.save_pretrained(os.path.expanduser(training_args.output_dir))
464
302
 
465
- processed_dataset = create_cehrgpt_finetuning_dataset(
466
- dataset=final_splits, cehrgpt_tokenizer=tokenizer, data_args=data_args
303
+ # Remove any cached files if there are any
304
+ cache_file_collector.remove_cache_files()
305
+
306
+ # After main-process-only operations, synchronize all processes to ensure consistency
307
+ if dist.is_available() and dist.is_initialized():
308
+ dist.barrier()
309
+
310
+ # Loading tokenizer in all processes in torch distributed training
311
+ tokenizer_name_or_path = os.path.expanduser(
312
+ training_args.output_dir
313
+ if cehrgpt_args.expand_tokenizer
314
+ else model_args.tokenizer_name_or_path
467
315
  )
468
- if not data_args.streaming:
469
- processed_dataset.save_to_disk(str(prepared_ds_path))
470
- stats = processed_dataset.cleanup_cache_files()
471
- LOG.info(
472
- "Clean up the cached files for the cehrgpt finetuning dataset : %s",
473
- stats,
474
- )
475
- processed_dataset = load_from_disk(str(prepared_ds_path))
316
+ tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
317
+ # Load the dataset from disk again to in torch distributed training
318
+ processed_dataset = load_from_disk(str(prepared_ds_path))
476
319
 
477
320
  # Set seed before initializing model.
478
321
  set_seed(training_args.seed)
479
322
 
480
- processed_dataset.set_format("pt")
323
+ if not data_args.streaming and not cehrgpt_args.sample_packing:
324
+ processed_dataset.set_format("pt")
481
325
 
482
326
  if cehrgpt_args.few_shot_predict:
483
327
  # At least we need two examples to have a validation set for early stopping
@@ -497,13 +341,40 @@ def main():
497
341
  config = CEHRGPTConfig.from_pretrained(model_args.model_name_or_path)
498
342
  if config.max_position_embeddings < model_args.max_position_embeddings:
499
343
  config.max_position_embeddings = model_args.max_position_embeddings
344
+
345
+ # persist this parameter in case this is overwritten by sample packing
346
+ per_device_eval_batch_size = training_args.per_device_eval_batch_size
347
+
348
+ if cehrgpt_args.sample_packing:
349
+ trainer_class = partial(
350
+ SamplePackingTrainer,
351
+ max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
352
+ max_position_embeddings=config.max_position_embeddings,
353
+ train_lengths=processed_dataset["train"]["num_of_concepts"],
354
+ validation_lengths=processed_dataset["validation"]["num_of_concepts"],
355
+ )
356
+ training_args.per_device_train_batch_size = 1
357
+ training_args.per_device_eval_batch_size = 1
358
+ data_collator_fn = partial(
359
+ SamplePackingCehrGptDataCollator,
360
+ cehrgpt_args.max_tokens_per_batch,
361
+ config.max_position_embeddings,
362
+ )
363
+ else:
364
+ trainer_class = Trainer
365
+ data_collator_fn = CehrGptDataCollator
366
+
500
367
  # We suppress the additional learning objectives in fine-tuning
501
- data_collator = CehrGptDataCollator(
368
+ data_collator = data_collator_fn(
502
369
  tokenizer=tokenizer,
503
370
  max_length=(
504
- config.max_position_embeddings - 1
505
- if config.causal_sfm
506
- else config.max_position_embeddings
371
+ cehrgpt_args.max_tokens_per_batch
372
+ if cehrgpt_args.sample_packing
373
+ else (
374
+ config.max_position_embeddings - 1
375
+ if config.causal_sfm
376
+ else config.max_position_embeddings
377
+ )
507
378
  ),
508
379
  include_values=model_args.include_values,
509
380
  pretraining=False,
@@ -514,8 +385,8 @@ def main():
514
385
 
515
386
  if training_args.do_train:
516
387
  if cehrgpt_args.hyperparameter_tuning:
517
- model_args.early_stopping_patience = LARGE_INTEGER
518
388
  training_args = perform_hyperparameter_search(
389
+ trainer_class,
519
390
  partial(model_init, model_args, training_args, tokenizer),
520
391
  processed_dataset,
521
392
  data_collator,
@@ -523,13 +394,20 @@ def main():
523
394
  model_args,
524
395
  cehrgpt_args,
525
396
  )
397
+
398
+ if cehrgpt_args.retrain_with_full:
526
399
  # Always retrain with the full set when hyperparameter tuning is set to true
527
400
  retrain_with_full_set(
528
- model_args, training_args, tokenizer, processed_dataset, data_collator
401
+ trainer_class,
402
+ model_args,
403
+ training_args,
404
+ tokenizer,
405
+ processed_dataset,
406
+ data_collator,
529
407
  )
530
408
  else:
531
409
  # Initialize Trainer for final training on the combined train+val set
532
- trainer = Trainer(
410
+ trainer = trainer_class(
533
411
  model=model_init(model_args, training_args, tokenizer),
534
412
  data_collator=data_collator,
535
413
  args=training_args,
@@ -552,45 +430,31 @@ def main():
552
430
  trainer.save_metrics("train", metrics)
553
431
  trainer.save_state()
554
432
 
555
- # Retrain the model with full set using the num of epoches before earlying stopping
556
- if cehrgpt_args.retrain_with_full:
557
- update_num_epoch_before_early_stopping_callback = None
558
- for callback in trainer.callback_handler.callbacks:
559
- if isinstance(callback, UpdateNumEpochsBeforeEarlyStoppingCallback):
560
- update_num_epoch_before_early_stopping_callback = callback
561
-
562
- if update_num_epoch_before_early_stopping_callback is None:
563
- raise RuntimeError(
564
- f"{UpdateNumEpochsBeforeEarlyStoppingCallback} must be included as a callback!"
565
- )
566
- final_num_epochs = (
567
- update_num_epoch_before_early_stopping_callback.num_epochs_before_early_stopping
568
- )
569
- training_args.num_train_epochs = final_num_epochs
570
- LOG.info(
571
- "Num Epochs before early stopping: %s",
572
- training_args.num_train_epochs,
573
- )
574
- retrain_with_full_set(
575
- model_args,
576
- training_args,
577
- tokenizer,
578
- processed_dataset,
579
- data_collator,
580
- )
581
-
582
433
  if training_args.do_predict:
434
+ if cehrgpt_args.sample_packing:
435
+ batch_sampler = SamplePackingBatchSampler(
436
+ lengths=processed_dataset["test"]["num_of_concepts"],
437
+ max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
438
+ max_position_embeddings=config.max_position_embeddings,
439
+ drop_last=training_args.dataloader_drop_last,
440
+ seed=training_args.seed,
441
+ )
442
+ per_device_eval_batch_size = 1
443
+ else:
444
+ batch_sampler = None
583
445
  test_dataloader = DataLoader(
584
446
  dataset=processed_dataset["test"],
585
- batch_size=training_args.per_device_eval_batch_size,
447
+ batch_size=per_device_eval_batch_size,
586
448
  num_workers=training_args.dataloader_num_workers,
587
449
  collate_fn=data_collator,
588
450
  pin_memory=training_args.dataloader_pin_memory,
451
+ batch_sampler=batch_sampler,
589
452
  )
590
453
  do_predict(test_dataloader, model_args, training_args, cehrgpt_args)
591
454
 
592
455
 
593
456
  def retrain_with_full_set(
457
+ trainer_class,
594
458
  model_args: ModelArguments,
595
459
  training_args: TrainingArguments,
596
460
  tokenizer: CehrGptTokenizer,
@@ -607,6 +471,7 @@ def retrain_with_full_set(
607
471
  and state information.
608
472
 
609
473
  Args:
474
+ trainer_class: Trainer or its subclass
610
475
  model_args (ModelArguments): Model configuration and hyperparameters.
611
476
  training_args (TrainingArguments): Training configuration, including output directory,
612
477
  evaluation strategy, and other training parameters.
@@ -628,7 +493,7 @@ def retrain_with_full_set(
628
493
  # Disable evaluation
629
494
  training_args.evaluation_strategy = "no"
630
495
  checkpoint = get_last_hf_checkpoint(training_args)
631
- final_trainer = Trainer(
496
+ final_trainer = trainer_class(
632
497
  model=model_init(model_args, training_args, tokenizer),
633
498
  data_collator=data_collator,
634
499
  args=training_args,
@@ -683,15 +548,15 @@ def do_predict(
683
548
  test_losses = []
684
549
  with torch.no_grad():
685
550
  for index, batch in enumerate(tqdm(test_dataloader, desc="Predicting")):
686
- person_ids = batch.pop("person_id").numpy().squeeze().astype(int)
687
- index_dates = (
688
- map(
689
- datetime.fromtimestamp,
690
- batch.pop("index_date").numpy().squeeze(axis=-1).tolist(),
691
- )
692
- if "index_date" in batch
693
- else None
694
- )
551
+ person_ids = batch.pop("person_id").numpy().astype(int).squeeze()
552
+ if person_ids.ndim == 0:
553
+ person_ids = np.asarray([person_ids])
554
+
555
+ index_dates = batch.pop("index_date").numpy().squeeze()
556
+ if index_dates.ndim == 0:
557
+ index_dates = np.asarray([index_dates])
558
+ index_dates = list(map(datetime.fromtimestamp, index_dates.tolist()))
559
+
695
560
  batch = {k: v.to(device) for k, v in batch.items()}
696
561
  # Forward pass
697
562
  output = model(**batch, output_attentions=False, output_hidden_states=False)
@@ -699,17 +564,25 @@ def do_predict(
699
564
 
700
565
  # Collect logits and labels for prediction
701
566
  logits = output.logits.float().cpu().numpy().squeeze()
567
+ if logits.ndim == 0:
568
+ logits = np.asarray([logits])
569
+ probabilities = sigmoid(logits)
570
+
702
571
  labels = (
703
- batch["classifier_label"].float().cpu().numpy().squeeze().astype(bool)
572
+ batch["classifier_label"].float().cpu().numpy().astype(bool).squeeze()
704
573
  )
705
- probabilities = sigmoid(logits)
574
+ if labels.ndim == 0:
575
+ labels = np.asarray([labels])
576
+
706
577
  # Save predictions to parquet file
707
578
  test_prediction_pd = pd.DataFrame(
708
579
  {
709
580
  "subject_id": person_ids,
710
581
  "prediction_time": index_dates,
711
- "boolean_prediction_probability": probabilities,
712
- "boolean_prediction": logits,
582
+ "predicted_boolean_probability": probabilities,
583
+ "predicted_boolean_value": pd.Series(
584
+ [None] * len(person_ids), dtype=bool
585
+ ),
713
586
  "boolean_value": labels,
714
587
  }
715
588
  )
@@ -723,7 +596,7 @@ def do_predict(
723
596
  # Compute metrics and save results
724
597
  metrics = compute_metrics(
725
598
  references=test_prediction_pd.boolean_value,
726
- probs=test_prediction_pd.boolean_prediction_probability,
599
+ probs=test_prediction_pd.predicted_boolean_probability,
727
600
  )
728
601
  metrics["test_loss"] = np.mean(test_losses)
729
602