cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. cehrgpt/analysis/irregularity.py +36 -0
  2. cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
  3. cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
  4. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
  5. cehrgpt/data/sample_packing_sampler.py +181 -0
  6. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  7. cehrgpt/generation/omop_converter_batch.py +32 -2
  8. cehrgpt/gpt_utils.py +20 -2
  9. cehrgpt/models/config.py +35 -0
  10. cehrgpt/models/hf_cehrgpt.py +470 -106
  11. cehrgpt/models/hf_modeling_outputs.py +1 -0
  12. cehrgpt/models/special_tokens.py +1 -0
  13. cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
  14. cehrgpt/runners/data_utils.py +358 -0
  15. cehrgpt/runners/gpt_runner_util.py +0 -10
  16. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
  17. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
  18. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
  19. cehrgpt/runners/hyperparameter_search_util.py +10 -8
  20. cehrgpt/runners/sample_packing_trainer.py +185 -0
  21. cehrgpt/simulations/generate_plots.py +95 -0
  22. cehrgpt/simulations/run_simulation.sh +24 -0
  23. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  24. cehrgpt/simulations/time_token_simulation.py +177 -0
  25. cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  26. cehrgpt/time_to_event/time_to_event_model.py +2 -13
  27. cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  28. cehrgpt/tools/linear_prob/__init__.py +0 -0
  29. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
  30. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  31. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
  32. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
  33. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
  34. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  35. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  36. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  37. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  38. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  39. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  40. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  41. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  42. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  43. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
  44. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
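The headline addition in 0.1.1 is sample packing: `sample_packing_sampler.py`, `sample_packing_trainer.py`, and a `SamplePackingCehrGptDataCollator` replace fixed-size batches with token-budget batches during fine-tuning. As a rough mental model only, here is a minimal greedy packing sketch; the constructor arguments mirror the `SamplePackingBatchSampler` call visible in the runner diff below, but the class name and logic here are illustrative, not the shipped implementation.

    from typing import Iterator, List

    from torch.utils.data import Sampler


    class GreedyPackingSampler(Sampler[List[int]]):
        """Illustrative only: groups example indices so the summed token
        count of each batch stays under a budget."""

        def __init__(self, lengths, max_tokens_per_batch,
                     max_position_embeddings, drop_last=False, seed=42):
            self.lengths = lengths
            self.max_tokens_per_batch = max_tokens_per_batch
            # No single example may exceed the model's context window.
            self.max_position_embeddings = max_position_embeddings
            self.drop_last = drop_last
            self.seed = seed

        def __iter__(self) -> Iterator[List[int]]:
            batch, budget = [], 0
            for idx, length in enumerate(self.lengths):
                # Cap the bookkeeping length at the context window.
                length = min(length, self.max_position_embeddings)
                if batch and budget + length > self.max_tokens_per_batch:
                    yield batch
                    batch, budget = [], 0
                batch.append(idx)
                budget += length
            if batch and not self.drop_last:
                yield batch

        def __len__(self) -> int:
            # Upper bound; the true count depends on how lengths pack.
            total = sum(min(l, self.max_position_embeddings) for l in self.lengths)
            return max(1, -(-total // self.max_tokens_per_batch))

The hunks that follow are from cehrgpt/runners/hf_cehrgpt_finetune_runner.py.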
--- cehrgpt/runners/hf_cehrgpt_finetune_runner.py (0.0.2)
+++ cehrgpt/runners/hf_cehrgpt_finetune_runner.py (0.1.1)
@@ -9,20 +9,16 @@ from pathlib import Path
  import numpy as np
  import pandas as pd
  import torch
- from cehrbert.data_generators.hf_data_generator.meds_utils import (
-     create_dataset_from_meds_reader,
- )
+ import torch.distributed as dist
+ from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
  from cehrbert.runners.hf_cehrbert_finetune_runner import compute_metrics
  from cehrbert.runners.hf_runner_argument_dataclass import (
-     DataTrainingArguments,
      FineTuneModelType,
      ModelArguments,
  )
  from cehrbert.runners.runner_util import (
      generate_prepared_ds_path,
      get_last_hf_checkpoint,
-     get_meds_extension_path,
-     load_parquet_as_dataset,
  )
  from datasets import DatasetDict, concatenate_datasets, load_from_disk
  from peft import LoraConfig, PeftModel, get_peft_model
@@ -38,12 +34,15 @@ from transformers import (
      TrainingArguments,
      set_seed,
  )
- from transformers.tokenization_utils_base import LARGE_INTEGER
+ from transformers.trainer_utils import is_main_process
  from transformers.utils import is_flash_attn_2_available, logging
  
  from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_finetuning_dataset
- from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
- from cehrgpt.data.hf_cehrgpt_dataset_mapping import MedToCehrGPTDatasetMapping
+ from cehrgpt.data.hf_cehrgpt_dataset_collator import (
+     CehrGptDataCollator,
+     SamplePackingCehrGptDataCollator,
+ )
+ from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
  from cehrgpt.models.hf_cehrgpt import (
      CEHRGPTConfig,
      CehrGptForClassification,
@@ -51,9 +50,16 @@ from cehrgpt.models.hf_cehrgpt import (
  )
  from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
  from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+ from cehrgpt.runners.data_utils import (
+     extract_cohort_sequences,
+     get_torch_dtype,
+     prepare_finetune_dataset,
+ )
  from cehrgpt.runners.gpt_runner_util import parse_runner_args
+ from cehrgpt.runners.hf_cehrgpt_pretrain_runner import tokenizer_exists
  from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
  from cehrgpt.runners.hyperparameter_search_util import perform_hyperparameter_search
+ from cehrgpt.runners.sample_packing_trainer import SamplePackingTrainer
  
  LOG = logging.get_logger("transformers")
  
@@ -140,11 +146,10 @@ def load_finetuned_model(
          raise ValueError(
              f"finetune_model_type can be one of the following types {FineTuneModelType.POOLING.value}"
          )
-
      attn_implementation = (
          "flash_attention_2" if is_flash_attn_2_available() else "eager"
      )
-     torch_dtype = torch.bfloat16 if training_args.bf16 else torch.float32
+     torch_dtype = get_torch_dtype(model_args.torch_dtype)
      # Try to create a new model based on the base model
      try:
          return finetune_model_cls.from_pretrained(
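The hard-coded bf16 check gives way to `get_torch_dtype(model_args.torch_dtype)` from `cehrgpt.runners.data_utils`, so the dtype now follows the configured `torch_dtype` string rather than a single flag. A plausible reading of that helper (hypothetical sketch; the real implementation is not shown in this diff):

    from typing import Optional

    import torch


    def get_torch_dtype(torch_dtype: Optional[str]) -> torch.dtype:
        # Resolve a string such as "bfloat16" or "float16" to a torch.dtype,
        # falling back to float32 when nothing usable is configured.
        if torch_dtype and hasattr(torch, torch_dtype):
            dtype = getattr(torch, torch_dtype)
            if isinstance(dtype, torch.dtype):
                return dtype
        return torch.float32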
@@ -156,148 +161,25 @@
      raise ValueError(f"Can not load the finetuned model from {model_name_or_path}")
  
  
- def create_dataset_splits(data_args: DataTrainingArguments, seed: int):
-     """
-     Creates training, validation, and testing dataset splits based on specified splitting strategies.
-
-     This function splits a dataset into training, validation, and test sets, using either chronological,
-     patient-based, or random splitting strategies, depending on the parameters provided in `data_args`.
-
-     - **Chronological split**: Sorts by a specified date and splits based on historical and future data.
-     - **Patient-based split**: Splits by unique patient IDs to ensure that patients in each split are distinct.
-     - **Random split**: Performs a straightforward random split of the dataset.
-
-     If `data_args.test_data_folder` is provided, a test set is loaded directly from it. Otherwise,
-     the test set is created by further splitting the validation set based on `test_eval_ratio`.
-
-     Parameters:
-         data_args (DataTrainingArguments): A configuration object containing data-related arguments, including:
-             - `data_folder` (str): Path to the main dataset.
-             - `test_data_folder` (str, optional): Path to an optional test dataset.
-             - `chronological_split` (bool): Whether to split chronologically.
-             - `split_by_patient` (bool): Whether to split by unique patient IDs.
-             - `validation_split_percentage` (float): Percentage of data to use for validation.
-             - `test_eval_ratio` (float): Ratio of test to validation data when creating a test set from validation.
-             - `preprocessing_num_workers` (int): Number of processes for parallel data filtering.
-             - `preprocessing_batch_size` (int): Batch size for batched operations.
-         seed (int): Random seed for reproducibility of splits.
-
-     Returns:
-         Tuple[Dataset, Dataset, Dataset]: A tuple containing:
-             - `train_set` (Dataset): Training split of the dataset.
-             - `validation_set` (Dataset): Validation split of the dataset.
-             - `test_set` (Dataset): Test split of the dataset.
-
-     Raises:
-         FileNotFoundError: If `data_args.data_folder` or `data_args.test_data_folder` does not exist.
-         ValueError: If incompatible arguments are passed for splitting strategies.
-
-     Example Usage:
-         data_args = DataTrainingArguments(
-             data_folder="data/",
-             validation_split_percentage=0.1,
-             test_eval_ratio=0.2,
-             chronological_split=True
-         )
-         train_set, validation_set, test_set = create_dataset_splits(data_args, seed=42)
-     """
-     dataset = load_parquet_as_dataset(data_args.data_folder)
-     test_set = (
-         None
-         if not data_args.test_data_folder
-         else load_parquet_as_dataset(data_args.test_data_folder)
-     )
-
-     if data_args.chronological_split:
-         # Chronological split by sorting on `index_date`
-         dataset = dataset.sort("index_date")
-         total_size = len(dataset)
-         train_end = int((1 - data_args.validation_split_percentage) * total_size)
-
-         # Perform the split
-         train_set = dataset.select(range(0, train_end))
-         validation_set = dataset.select(range(train_end, total_size))
-
-         if test_set is None:
-             test_valid_split = validation_set.train_test_split(
-                 test_size=data_args.test_eval_ratio, seed=seed
-             )
-             validation_set, test_set = (
-                 test_valid_split["train"],
-                 test_valid_split["test"],
-             )
-
-     elif data_args.split_by_patient:
-         # Patient-based split
-         LOG.info("Using the split_by_patient strategy")
-         unique_patient_ids = dataset.unique("person_id")
-         LOG.info(f"There are {len(unique_patient_ids)} patients in total")
-
-         np.random.seed(seed)
-         np.random.shuffle(unique_patient_ids)
-
-         train_end = int(
-             len(unique_patient_ids) * (1 - data_args.validation_split_percentage)
-         )
-         train_patient_ids = set(unique_patient_ids[:train_end])
-
-         if test_set is None:
-             validation_end = int(
-                 train_end
-                 + len(unique_patient_ids)
-                 * data_args.validation_split_percentage
-                 * data_args.test_eval_ratio
-             )
-             val_patient_ids = set(unique_patient_ids[train_end:validation_end])
-             test_patient_ids = set(unique_patient_ids[validation_end:])
-         else:
-             val_patient_ids, test_patient_ids = (
-                 set(unique_patient_ids[train_end:]),
-                 None,
-             )
-
-         # Helper function to apply patient-based filtering
-         def filter_by_patient_ids(patient_ids):
-             return dataset.filter(
-                 lambda batch: [pid in patient_ids for pid in batch["person_id"]],
-                 num_proc=data_args.preprocessing_num_workers,
-                 batched=True,
-                 batch_size=data_args.preprocessing_batch_size,
-             )
-
-         # Generate splits
-         train_set = filter_by_patient_ids(train_patient_ids)
-         validation_set = filter_by_patient_ids(val_patient_ids)
-         if test_set is None:
-             test_set = filter_by_patient_ids(test_patient_ids)
-
-     else:
-         # Random split
-         train_val = dataset.train_test_split(
-             test_size=data_args.validation_split_percentage, seed=seed
-         )
-         train_set, validation_set = train_val["train"], train_val["test"]
-
-         if test_set is None:
-             test_valid_split = validation_set.train_test_split(
-                 test_size=data_args.test_eval_ratio, seed=seed
-             )
-             validation_set, test_set = (
-                 test_valid_split["train"],
-                 test_valid_split["test"],
-             )
-
-     return train_set, validation_set, test_set
-
-
  def model_init(
      model_args: ModelArguments,
      training_args: TrainingArguments,
+     cehrgpt_args: CehrGPTArguments,
      tokenizer: CehrGptTokenizer,
  ):
      model = load_finetuned_model(
          model_args, training_args, model_args.model_name_or_path
      )
+
+     if cehrgpt_args.class_weights:
+         model.config.class_weights = cehrgpt_args.class_weights
+         LOG.info(f"Setting class_weights to {model.config.class_weights}")
+
+     # Enable position embeddings when position embeddings are disabled in pre-training
+     if not model_args.exclude_position_ids and model.cehrgpt.exclude_position_ids:
+         LOG.info(f"Enable the position_embeddings")
+         model.cehrgpt.enable_position_embeddings()
+
      if model.config.max_position_embeddings < model_args.max_position_embeddings:
          LOG.info(
              f"Increase model.config.max_position_embeddings to {model_args.max_position_embeddings}"
@@ -307,9 +189,6 @@ def model_init(
      # Enable include_values when include_values is set to be False during pre-training
      if model_args.include_values and not model.cehrgpt.include_values:
          model.cehrgpt.include_values = True
-     # Enable position embeddings when position embeddings are disabled in pre-training
-     if not model_args.exclude_position_ids and model.cehrgpt.exclude_position_ids:
-         model.cehrgpt.exclude_position_ids = False
      # Expand tokenizer to adapt to the finetuning dataset
      if model.config.vocab_size < tokenizer.vocab_size:
          model.resize_token_embeddings(tokenizer.vocab_size)
@@ -327,6 +206,7 @@ def model_init(
          model.cehrgpt.update_pretrained_embeddings(
              tokenizer.pretrained_token_ids, tokenizer.pretrained_embeddings
          )
+
      # Expand value tokenizer to adapt to the fine-tuning dataset
      if model.config.include_values:
          if model.config.value_vocab_size < tokenizer.value_vocab_size:
@@ -364,16 +244,16 @@ def main():
      prepared_ds_path = generate_prepared_ds_path(
          data_args, model_args, data_folder=data_args.cohort_folder
      )
-
+     cache_file_collector = CacheFileCollector()
      processed_dataset = None
      if any(prepared_ds_path.glob("*")):
          LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
          processed_dataset = load_from_disk(str(prepared_ds_path))
          LOG.info("Prepared dataset loaded from disk...")
          if cehrgpt_args.expand_tokenizer:
-             try:
+             if tokenizer_exists(training_args.output_dir):
                  tokenizer = CehrGptTokenizer.from_pretrained(training_args.output_dir)
-             except Exception:
+             else:
                  LOG.warning(
                      f"CehrGptTokenizer must exist in {training_args.output_dir} "
                      f"when the dataset has been processed and expand_tokenizer is set to True. "
@@ -383,101 +263,86 @@ def main():
                  shutil.rmtree(prepared_ds_path)
  
      if processed_dataset is None:
-         # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
-         if data_args.is_data_in_meds:
-             meds_extension_path = get_meds_extension_path(
-                 data_folder=data_args.cohort_folder,
-                 dataset_prepared_path=data_args.dataset_prepared_path,
-             )
-             try:
-                 LOG.info(
-                     f"Trying to load the MEDS extension from disk at {meds_extension_path}..."
+         if is_main_process(training_args.local_rank):
+             # If the full dataset has been tokenized, we don't want to tokenize the cohort containing
+             # the subset of the data. We should slice out the portion of the tokenized sequences for each sample
+             if cehrgpt_args.tokenized_full_dataset_path is not None:
+                 processed_dataset = extract_cohort_sequences(
+                     data_args, cehrgpt_args, cache_file_collector
                  )
-                 dataset = load_from_disk(meds_extension_path)
-                 if data_args.streaming:
-                     if isinstance(dataset, DatasetDict):
-                         dataset = {
-                             k: v.to_iterable_dataset(
-                                 num_shards=training_args.dataloader_num_workers
-                             )
-                             for k, v in dataset.items()
-                         }
+             else:
+                 final_splits = prepare_finetune_dataset(
+                     data_args, training_args, cehrgpt_args, cache_file_collector
+                 )
+             if cehrgpt_args.expand_tokenizer:
+                 new_tokenizer_path = os.path.expanduser(training_args.output_dir)
+                 if tokenizer_exists(new_tokenizer_path):
+                     tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
                  else:
-                     dataset = dataset.to_iterable_dataset(
-                         num_shards=training_args.dataloader_num_workers
+                     # Try to use the defined pretrained embeddings if exists, Otherwise we default to the pretrained model
+                     # embedded in the pretrained model
+                     pretrained_concept_embedding_model = PretrainedEmbeddings(
+                         cehrgpt_args.pretrained_embedding_path
                      )
-             except Exception as e:
-                 LOG.exception(e)
-                 dataset = create_dataset_from_meds_reader(
-                     data_args=data_args,
-                     dataset_mappings=[
-                         MedToCehrGPTDatasetMapping(
+                     if not pretrained_concept_embedding_model.exists:
+                         pretrained_concept_embedding_model = (
+                             tokenizer.pretrained_concept_embedding_model
+                         )
+                     tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
+                         cehrgpt_tokenizer=tokenizer,
+                         dataset=final_splits["train"],
                          data_args=data_args,
-                             is_pretraining=False,
-                             include_inpatient_hour_token=cehrgpt_args.include_inpatient_hour_token,
+                         concept_name_mapping={},
+                         pretrained_concept_embedding_model=pretrained_concept_embedding_model,
                      )
-                     ],
-                 )
+                     tokenizer.save_pretrained(
+                         os.path.expanduser(training_args.output_dir)
+                     )
+
+             # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
              if not data_args.streaming:
-                 dataset.save_to_disk(str(meds_extension_path))
-                 stats = dataset.cleanup_cache_files()
-                 LOG.info(
-                     "Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
-                     stats,
-                 )
-                 dataset = load_from_disk(str(meds_extension_path))
-
-             train_set = dataset["train"]
-             validation_set = dataset["validation"]
-             test_set = dataset["test"]
-         else:
-             train_set, validation_set, test_set = create_dataset_splits(
-                 data_args=data_args, seed=training_args.seed
-             )
-         # Organize them into a single DatasetDict
-         final_splits = DatasetDict(
-             {"train": train_set, "validation": validation_set, "test": test_set}
-         )
+                 all_columns = final_splits["train"].column_names
+                 if "visit_concept_ids" in all_columns:
+                     final_splits = final_splits.remove_columns(
+                         ["visit_concept_ids"]
+                     )
  
-         if cehrgpt_args.expand_tokenizer:
-             new_tokenizer_path = os.path.expanduser(training_args.output_dir)
-             try:
-                 tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
-             except Exception:
-                 # Try to use the defined pretrained embeddings if exists,
-                 # Otherwise we default to the pretrained model embedded in the pretrained model
-                 pretrained_concept_embedding_model = PretrainedEmbeddings(
-                     cehrgpt_args.pretrained_embedding_path
-                 )
-                 if not pretrained_concept_embedding_model.exists:
-                     pretrained_concept_embedding_model = (
-                         tokenizer.pretrained_concept_embedding_model
-                     )
-                 tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
+             processed_dataset = create_cehrgpt_finetuning_dataset(
+                 dataset=final_splits,
                  cehrgpt_tokenizer=tokenizer,
-                 dataset=final_splits["train"],
                  data_args=data_args,
-                 concept_name_mapping={},
-                 pretrained_concept_embedding_model=pretrained_concept_embedding_model,
+                 cache_file_collector=cache_file_collector,
              )
-             tokenizer.save_pretrained(os.path.expanduser(training_args.output_dir))
+             if not data_args.streaming:
+                 processed_dataset.save_to_disk(str(prepared_ds_path))
+                 stats = processed_dataset.cleanup_cache_files()
+                 LOG.info(
+                     "Clean up the cached files for the cehrgpt finetuning dataset : %s",
+                     stats,
+                 )
+
+             # Remove any cached files if there are any
+             cache_file_collector.remove_cache_files()
  
-         processed_dataset = create_cehrgpt_finetuning_dataset(
-             dataset=final_splits, cehrgpt_tokenizer=tokenizer, data_args=data_args
+         # After main-process-only operations, synchronize all processes to ensure consistency
+         if dist.is_available() and dist.is_initialized():
+             dist.barrier()
+
+         # Loading tokenizer in all processes in torch distributed training
+         tokenizer_name_or_path = os.path.expanduser(
+             training_args.output_dir
+             if cehrgpt_args.expand_tokenizer
+             else model_args.tokenizer_name_or_path
          )
-         if not data_args.streaming:
-             processed_dataset.save_to_disk(str(prepared_ds_path))
-             stats = processed_dataset.cleanup_cache_files()
-             LOG.info(
-                 "Clean up the cached files for the cehrgpt finetuning dataset : %s",
-                 stats,
-             )
-         processed_dataset = load_from_disk(str(prepared_ds_path))
+         tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
+         # Load the dataset from disk again to in torch distributed training
+         processed_dataset = load_from_disk(str(prepared_ds_path))
  
      # Set seed before initializing model.
      set_seed(training_args.seed)
  
-     processed_dataset.set_format("pt")
+     if not data_args.streaming and not cehrgpt_args.sample_packing:
+         processed_dataset.set_format("pt")
  
      if cehrgpt_args.few_shot_predict:
          # At least we need two examples to have a validation set for early stopping
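The hunk above also changes who does the preprocessing: only rank 0 builds and saves the prepared dataset, every rank then meets at `dist.barrier()` and reloads the same artifact from disk. A minimal sketch of that rank-0-builds, everyone-loads pattern, assuming a hypothetical `build_fn` that writes a `datasets` dataset to `prepared_ds_path`:

    import torch.distributed as dist
    from datasets import load_from_disk
    from transformers.trainer_utils import is_main_process


    def build_once_then_load(local_rank, prepared_ds_path, build_fn):
        # Only rank 0 runs the expensive tokenization and writes it out.
        if is_main_process(local_rank):
            build_fn(prepared_ds_path)
        # Every rank blocks here until rank 0 has finished writing...
        if dist.is_available() and dist.is_initialized():
            dist.barrier()
        # ...then all ranks (rank 0 included) load the identical artifact.
        return load_from_disk(prepared_ds_path)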
@@ -497,40 +362,76 @@ def main():
      config = CEHRGPTConfig.from_pretrained(model_args.model_name_or_path)
      if config.max_position_embeddings < model_args.max_position_embeddings:
          config.max_position_embeddings = model_args.max_position_embeddings
+
+     # persist this parameter in case this is overwritten by sample packing
+     per_device_eval_batch_size = training_args.per_device_eval_batch_size
+
+     if cehrgpt_args.sample_packing:
+         trainer_class = partial(
+             SamplePackingTrainer,
+             max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
+             max_position_embeddings=config.max_position_embeddings,
+             negative_sampling_probability=cehrgpt_args.negative_sampling_probability,
+         )
+         training_args.per_device_train_batch_size = 1
+         training_args.per_device_eval_batch_size = 1
+         data_collator_fn = partial(
+             SamplePackingCehrGptDataCollator,
+             cehrgpt_args.max_tokens_per_batch,
+             config.max_position_embeddings,
+             add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
+         )
+     else:
+         trainer_class = Trainer
+         data_collator_fn = CehrGptDataCollator
+
      # We suppress the additional learning objectives in fine-tuning
-     data_collator = CehrGptDataCollator(
+     data_collator = data_collator_fn(
          tokenizer=tokenizer,
          max_length=(
-             config.max_position_embeddings - 1
-             if config.causal_sfm
-             else config.max_position_embeddings
+             cehrgpt_args.max_tokens_per_batch
+             if cehrgpt_args.sample_packing
+             else (
+                 config.max_position_embeddings - 1
+                 if config.causal_sfm
+                 else config.max_position_embeddings
+             )
          ),
          include_values=model_args.include_values,
          pretraining=False,
          include_ttv_prediction=False,
          use_sub_time_tokenization=False,
          include_demographics=cehrgpt_args.include_demographics,
+         add_linear_prob_token=True,
      )
  
      if training_args.do_train:
          if cehrgpt_args.hyperparameter_tuning:
-             model_args.early_stopping_patience = LARGE_INTEGER
              training_args = perform_hyperparameter_search(
-                 partial(model_init, model_args, training_args, tokenizer),
+                 trainer_class,
+                 partial(model_init, model_args, training_args, cehrgpt_args, tokenizer),
                  processed_dataset,
                  data_collator,
                  training_args,
                  model_args,
                  cehrgpt_args,
              )
+
+         if cehrgpt_args.retrain_with_full:
              # Always retrain with the full set when hyperparameter tuning is set to true
              retrain_with_full_set(
-                 model_args, training_args, tokenizer, processed_dataset, data_collator
+                 trainer_class,
+                 model_args,
+                 training_args,
+                 cehrgpt_args,
+                 tokenizer,
+                 processed_dataset,
+                 data_collator,
              )
          else:
              # Initialize Trainer for final training on the combined train+val set
-             trainer = Trainer(
-                 model=model_init(model_args, training_args, tokenizer),
+             trainer = trainer_class(
+                 model=model_init(model_args, training_args, cehrgpt_args, tokenizer),
                  data_collator=data_collator,
                  args=training_args,
                  train_dataset=processed_dataset["train"],
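Binding the packing-specific keyword arguments with `functools.partial` is what lets the rest of `main()` call `trainer_class(...)` and `data_collator_fn(...)` identically in both modes. A minimal illustration with toy classes (not the real Trainer):

    from functools import partial


    class ToyTrainer:
        def __init__(self, model=None, args=None):
            self.model, self.args = model, args


    class ToyPackingTrainer(ToyTrainer):
        def __init__(self, *args, max_tokens_per_batch=0, **kwargs):
            super().__init__(*args, **kwargs)
            self.max_tokens_per_batch = max_tokens_per_batch


    # Packing-specific kwargs are bound once, up front...
    trainer_class = partial(ToyPackingTrainer, max_tokens_per_batch=16384)
    # ...so the downstream call site is the same for either variant:
    trainer = trainer_class(model="model", args="args")
    assert trainer.max_tokens_per_batch == 16384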
@@ -552,47 +453,34 @@ def main():
          trainer.save_metrics("train", metrics)
          trainer.save_state()
  
-         # Retrain the model with full set using the num of epoches before earlying stopping
-         if cehrgpt_args.retrain_with_full:
-             update_num_epoch_before_early_stopping_callback = None
-             for callback in trainer.callback_handler.callbacks:
-                 if isinstance(callback, UpdateNumEpochsBeforeEarlyStoppingCallback):
-                     update_num_epoch_before_early_stopping_callback = callback
-
-             if update_num_epoch_before_early_stopping_callback is None:
-                 raise RuntimeError(
-                     f"{UpdateNumEpochsBeforeEarlyStoppingCallback} must be included as a callback!"
-                 )
-             final_num_epochs = (
-                 update_num_epoch_before_early_stopping_callback.num_epochs_before_early_stopping
-             )
-             training_args.num_train_epochs = final_num_epochs
-             LOG.info(
-                 "Num Epochs before early stopping: %s",
-                 training_args.num_train_epochs,
-             )
-             retrain_with_full_set(
-                 model_args,
-                 training_args,
-                 tokenizer,
-                 processed_dataset,
-                 data_collator,
-             )
-
      if training_args.do_predict:
+         if cehrgpt_args.sample_packing:
+             batch_sampler = SamplePackingBatchSampler(
+                 lengths=processed_dataset["test"]["num_of_concepts"],
+                 max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
+                 max_position_embeddings=config.max_position_embeddings,
+                 drop_last=training_args.dataloader_drop_last,
+                 seed=training_args.seed,
+             )
+             per_device_eval_batch_size = 1
+         else:
+             batch_sampler = None
          test_dataloader = DataLoader(
              dataset=processed_dataset["test"],
-             batch_size=training_args.per_device_eval_batch_size,
+             batch_size=per_device_eval_batch_size,
              num_workers=training_args.dataloader_num_workers,
              collate_fn=data_collator,
              pin_memory=training_args.dataloader_pin_memory,
+             batch_sampler=batch_sampler,
          )
          do_predict(test_dataloader, model_args, training_args, cehrgpt_args)
  
  
  def retrain_with_full_set(
+     trainer_class,
      model_args: ModelArguments,
      training_args: TrainingArguments,
+     cehrgpt_args: CehrGPTArguments,
      tokenizer: CehrGptTokenizer,
      dataset: DatasetDict,
      data_collator: CehrGptDataCollator,
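Note the interplay between `batch_sampler` and `batch_size` in the prediction DataLoader above: PyTorch treats the two options as mutually exclusive, so `per_device_eval_batch_size` must collapse to 1 (the DataLoader default) whenever the packing sampler is active. A self-contained illustration:

    import torch
    from torch.utils.data import DataLoader, Dataset


    class ToyDataset(Dataset):
        def __len__(self):
            return 5

        def __getitem__(self, i):
            return torch.tensor([i])


    # A batch_sampler yields whole index lists; batch_size must stay at its
    # default of 1 or DataLoader raises a ValueError.
    loader = DataLoader(ToyDataset(), batch_sampler=[[0, 1, 2], [3, 4]])
    for batch in loader:
        print(batch.shape)  # torch.Size([3, 1]) then torch.Size([2, 1])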
@@ -607,9 +495,11 @@ def retrain_with_full_set(
      and state information.
  
      Args:
+         trainer_class: Trainer or its subclass
          model_args (ModelArguments): Model configuration and hyperparameters.
          training_args (TrainingArguments): Training configuration, including output directory,
              evaluation strategy, and other training parameters.
+         cehrgpt_args (CehrGPTArguments): CehrGPT specific parameters.
          tokenizer (CehrGptTokenizer): Tokenizer instance specific to CEHR-GPT.
          dataset (DatasetDict): A dictionary containing the 'train' and 'validation' datasets.
          data_collator (CehrGptDataCollator): Data collator for handling data batching and tokenization.
@@ -628,8 +518,8 @@ def retrain_with_full_set(
      # Disable evaluation
      training_args.evaluation_strategy = "no"
      checkpoint = get_last_hf_checkpoint(training_args)
-     final_trainer = Trainer(
-         model=model_init(model_args, training_args, tokenizer),
+     final_trainer = trainer_class(
+         model=model_init(model_args, training_args, cehrgpt_args, tokenizer),
          data_collator=data_collator,
          args=training_args,
          train_dataset=full_dataset,
@@ -683,15 +573,15 @@ def do_predict(
      test_losses = []
      with torch.no_grad():
          for index, batch in enumerate(tqdm(test_dataloader, desc="Predicting")):
-             person_ids = batch.pop("person_id").numpy().squeeze().astype(int)
-             index_dates = (
-                 map(
-                     datetime.fromtimestamp,
-                     batch.pop("index_date").numpy().squeeze(axis=-1).tolist(),
-                 )
-                 if "index_date" in batch
-                 else None
-             )
+             person_ids = batch.pop("person_id").numpy().astype(int).squeeze()
+             if person_ids.ndim == 0:
+                 person_ids = np.asarray([person_ids])
+
+             index_dates = batch.pop("index_date").numpy().squeeze()
+             if index_dates.ndim == 0:
+                 index_dates = np.asarray([index_dates])
+             index_dates = list(map(datetime.fromtimestamp, index_dates.tolist()))
+
              batch = {k: v.to(device) for k, v in batch.items()}
              # Forward pass
              output = model(**batch, output_attentions=False, output_hidden_states=False)
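The old code squeezed `person_id` and `index_date` unconditionally, which collapses a batch of one to a 0-d array and breaks downstream iteration; the new `ndim == 0` guards re-wrap such scalars. The pitfall in isolation (plain NumPy, nothing cehrgpt-specific):

    import numpy as np

    batch_of_one = np.array([[7]])       # e.g. person_id with batch_size == 1
    squeezed = batch_of_one.squeeze()    # shape (), a 0-d array, not iterable
    assert squeezed.ndim == 0
    # The guard restores a 1-element vector so that list(), map(), and
    # pd.DataFrame keep working for single-example batches:
    restored = np.asarray([squeezed]) if squeezed.ndim == 0 else squeezed
    assert restored.shape == (1,)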
@@ -699,17 +589,25 @@ def do_predict(
  
              # Collect logits and labels for prediction
              logits = output.logits.float().cpu().numpy().squeeze()
+             if logits.ndim == 0:
+                 logits = np.asarray([logits])
+             probabilities = sigmoid(logits)
+
              labels = (
-                 batch["classifier_label"].float().cpu().numpy().squeeze().astype(bool)
+                 batch["classifier_label"].float().cpu().numpy().astype(bool).squeeze()
              )
-             probabilities = sigmoid(logits)
+             if labels.ndim == 0:
+                 labels = np.asarray([labels])
+
              # Save predictions to parquet file
              test_prediction_pd = pd.DataFrame(
                  {
                      "subject_id": person_ids,
                      "prediction_time": index_dates,
-                     "boolean_prediction_probability": probabilities,
-                     "boolean_prediction": logits,
+                     "predicted_boolean_probability": probabilities,
+                     "predicted_boolean_value": pd.Series(
+                         [None] * len(person_ids), dtype=bool
+                     ),
                      "boolean_value": labels,
                  }
              )
@@ -723,7 +621,7 @@ def do_predict(
      # Compute metrics and save results
      metrics = compute_metrics(
          references=test_prediction_pd.boolean_value,
-         probs=test_prediction_pd.boolean_prediction_probability,
+         probs=test_prediction_pd.predicted_boolean_probability,
      )
      metrics["test_loss"] = np.mean(test_losses)