cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. cehrgpt/analysis/irregularity.py +36 -0
  2. cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
  3. cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
  4. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
  5. cehrgpt/data/sample_packing_sampler.py +181 -0
  6. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  7. cehrgpt/generation/omop_converter_batch.py +32 -2
  8. cehrgpt/gpt_utils.py +20 -2
  9. cehrgpt/models/config.py +35 -0
  10. cehrgpt/models/hf_cehrgpt.py +470 -106
  11. cehrgpt/models/hf_modeling_outputs.py +1 -0
  12. cehrgpt/models/special_tokens.py +1 -0
  13. cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
  14. cehrgpt/runners/data_utils.py +358 -0
  15. cehrgpt/runners/gpt_runner_util.py +0 -10
  16. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
  17. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
  18. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
  19. cehrgpt/runners/hyperparameter_search_util.py +10 -8
  20. cehrgpt/runners/sample_packing_trainer.py +185 -0
  21. cehrgpt/simulations/generate_plots.py +95 -0
  22. cehrgpt/simulations/run_simulation.sh +24 -0
  23. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  24. cehrgpt/simulations/time_token_simulation.py +177 -0
  25. cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  26. cehrgpt/time_to_event/time_to_event_model.py +2 -13
  27. cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  28. cehrgpt/tools/linear_prob/__init__.py +0 -0
  29. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
  30. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  31. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
  32. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
  33. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
  34. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  35. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  36. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  37. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  38. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  39. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  40. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  41. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  42. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  43. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
  44. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
cehrgpt/runners/hf_cehrgpt_pretrain_runner.py
@@ -1,8 +1,12 @@
 import os
+from functools import partial
 from typing import Optional, Union

+import numpy as np
 import torch
+import torch.distributed as dist
 from cehrbert.data_generators.hf_data_generator.meds_utils import (
+    CacheFileCollector,
     create_dataset_from_meds_reader,
 )
 from cehrbert.runners.hf_runner_argument_dataclass import (
@@ -16,22 +20,42 @@ from cehrbert.runners.runner_util import (
     load_parquet_as_dataset,
 )
 from datasets import Dataset, DatasetDict, IterableDatasetDict, load_from_disk
-from transformers import AutoConfig, Trainer, TrainingArguments, set_seed
+from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
+from transformers.trainer_utils import is_main_process
 from transformers.utils import is_flash_attn_2_available, logging

 from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_pretraining_dataset
-from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
+from cehrgpt.data.hf_cehrgpt_dataset_collator import (
+    CehrGptDataCollator,
+    SamplePackingCehrGptDataCollator,
+)
 from cehrgpt.data.hf_cehrgpt_dataset_mapping import MedToCehrGPTDatasetMapping
 from cehrgpt.models.config import CEHRGPTConfig
 from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
 from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+from cehrgpt.runners.data_utils import get_torch_dtype
 from cehrgpt.runners.gpt_runner_util import parse_runner_args
 from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
+from cehrgpt.runners.sample_packing_trainer import SamplePackingTrainer

 LOG = logging.get_logger("transformers")


+class CustomEarlyStoppingCallback(EarlyStoppingCallback):
+    def check_metric_value(self, args, state, control, metric_value):
+        # best_metric is set by code for load_best_model
+        operator = np.greater if args.greater_is_better else np.less
+        if state.best_metric is None or (
+            operator(metric_value, state.best_metric)
+            and abs(metric_value - state.best_metric) / state.best_metric
+            > self.early_stopping_threshold
+        ):
+            self.early_stopping_patience_counter = 0
+        else:
+            self.early_stopping_patience_counter += 1
+
+
 def tokenizer_exists(tokenizer_name_or_path: str) -> bool:
     # Try to load the pretrained tokenizer
     try:
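
The new CustomEarlyStoppingCallback swaps the absolute-difference test in transformers.EarlyStoppingCallback for a relative one: patience resets only when the metric beats the best value by more than early_stopping_threshold expressed as a fraction of the current best. A standalone illustration of that check (not part of the diff), assuming greater_is_better=False as for eval_loss:

import numpy as np

def improves(metric_value: float, best_metric: float, threshold: float = 0.01) -> bool:
    # Mirrors the overridden check_metric_value: the metric must beat the best
    # AND move by more than `threshold` as a fraction of the current best.
    operator = np.less  # greater_is_better=False
    return bool(
        operator(metric_value, best_metric)
        and abs(metric_value - best_metric) / best_metric > threshold
    )

print(improves(0.498, 0.500))  # False: a 0.4% relative improvement is not enough
print(improves(0.490, 0.500))  # True: a 2% relative improvement resets patience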
@@ -48,6 +72,36 @@ def load_and_create_tokenizer(
     cehrgpt_args: CehrGPTArguments,
     dataset: Optional[Union[Dataset, DatasetDict]] = None,
 ) -> CehrGptTokenizer:
+
+    concept_name_mapping = {}
+    allowed_motor_codes = list()
+    if cehrgpt_args.concept_dir:
+        import pandas as pd
+        from cehrbert_data.const.artificial_tokens import DEATH_TOKEN
+        from meds.schema import death_code
+
+        LOG.info("Loading concept data from disk at %s", cehrgpt_args.concept_dir)
+        concept_pd = pd.read_parquet(cehrgpt_args.concept_dir)
+        LOG.info(
+            "Creating concept name mapping and motor_time_to_event_codes from disk at %s",
+            cehrgpt_args.concept_dir,
+        )
+        for row in concept_pd.itertuples():
+            concept_name_mapping[str(getattr(row, "concept_id"))] = getattr(
+                row, "concept_name"
+            )
+            if (
+                cehrgpt_args.include_motor_time_to_event
+                and getattr(row, "domain_id")
+                in ["Condition", "Procedure", "Drug", "Visit"]
+                and getattr(row, "standard_concept") == "S"
+            ):
+                allowed_motor_codes.append(str(getattr(row, "concept_id")))
+        LOG.info(
+            "Adding death codes for MOTOR TTE predictions: %s",
+            [DEATH_TOKEN, death_code],
+        )
+        allowed_motor_codes.extend([DEATH_TOKEN, death_code])
     # Try to load the pretrained tokenizer
     tokenizer_abspath = os.path.expanduser(model_args.tokenizer_name_or_path)
     try:
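
The concept table read in this hunk follows the OMOP vocabulary layout: the loop relies on the columns concept_id, concept_name, domain_id, and standard_concept. A minimal sketch of the expected shape and the filtering rule (the concept IDs below are real OMOP standard concepts, but the frame itself is illustrative, not part of the diff):

import pandas as pd

concept_pd = pd.DataFrame(
    {
        "concept_id": [316866, 4329847, 9201],
        "concept_name": ["Hypertensive disorder", "Myocardial infarction", "Inpatient Visit"],
        "domain_id": ["Condition", "Condition", "Visit"],
        "standard_concept": ["S", "S", "S"],
    }
)
# Same rule as above: standard concepts from the four clinical-event domains
allowed_motor_codes = [
    str(row.concept_id)
    for row in concept_pd.itertuples()
    if row.domain_id in ["Condition", "Procedure", "Drug", "Visit"]
    and row.standard_concept == "S"
]
print(allowed_motor_codes)  # ['316866', '4329847', '9201']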
@@ -59,13 +113,24 @@ def load_and_create_tokenizer(
                 f"Failed to load the tokenizer from {tokenizer_abspath} with the error \n{e}\n"
                 f"Tried to create the tokenizer, however the dataset is not provided."
             )
+        LOG.info("Started training the tokenizer ...")
         tokenizer = CehrGptTokenizer.train_tokenizer(
             dataset,
-            {},
+            concept_name_mapping,
             data_args,
             PretrainedEmbeddings(cehrgpt_args.pretrained_embedding_path),
+            allowed_motor_codes if cehrgpt_args.include_motor_time_to_event else None,
+            (
+                cehrgpt_args.num_motor_tasks
+                if cehrgpt_args.include_motor_time_to_event
+                else None
+            ),
+            apply_entropy_filter=cehrgpt_args.apply_entropy_filter,
+            min_prevalence=cehrgpt_args.min_prevalence,
         )
+        LOG.info("Finished training the tokenizer ...")
         tokenizer.save_pretrained(tokenizer_abspath)
+        LOG.info("Saved the tokenizer to %s", tokenizer_abspath)

     return tokenizer

@@ -73,13 +138,12 @@ def load_and_create_tokenizer(
 def load_and_create_model(
     model_args: ModelArguments,
     cehrgpt_args: CehrGPTArguments,
-    training_args: TrainingArguments,
     tokenizer: CehrGptTokenizer,
 ) -> CEHRGPT2LMHeadModel:
     attn_implementation = (
         "flash_attention_2" if is_flash_attn_2_available() else "eager"
     )
-    torch_dtype = torch.bfloat16 if training_args.bf16 else torch.float32
+    torch_dtype = get_torch_dtype(model_args.torch_dtype)
     model_abspath = os.path.expanduser(model_args.model_name_or_path)
     if cehrgpt_args.continue_pretrain:
         try:
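
load_and_create_model now derives the dtype from model_args.torch_dtype via get_torch_dtype instead of inferring it from the bf16 training flag, which is why training_args drops out of the signature. The helper's body lives in cehrgpt/runners/data_utils.py and is not shown in this diff; a minimal sketch consistent with the call site might be:

from typing import Optional

import torch

def get_torch_dtype(torch_dtype: Optional[str]) -> torch.dtype:
    # Map a string such as "bfloat16", "float16", or "float32" to the
    # corresponding torch dtype, falling back to float32 when unset.
    if torch_dtype is not None and hasattr(torch, torch_dtype):
        dtype = getattr(torch, torch_dtype)
        if isinstance(dtype, torch.dtype):
            return dtype
    return torch.float32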
@@ -120,6 +184,9 @@ def load_and_create_model(
         pretrained_embedding_dim = tokenizer.pretrained_embeddings.shape[1]
     else:
         pretrained_embedding_dim = model_args.hidden_size
+
+    model_args_cehrgpt = model_args.as_dict()
+    model_args_cehrgpt.pop("attn_implementation")
     model_config = CEHRGPTConfig(
         vocab_size=tokenizer.vocab_size,
         value_vocab_size=tokenizer.value_vocab_size,
@@ -131,15 +198,28 @@ def load_and_create_model(
         attn_implementation=attn_implementation,
         causal_sfm=cehrgpt_args.causal_sfm,
         demographics_size=cehrgpt_args.demographics_size,
+        next_token_prediction_loss_weight=cehrgpt_args.next_token_prediction_loss_weight,
         lab_token_penalty=cehrgpt_args.lab_token_penalty,
         lab_token_loss_weight=cehrgpt_args.lab_token_loss_weight,
+        value_prediction_loss_weight=cehrgpt_args.value_prediction_loss_weight,
         entropy_penalty=cehrgpt_args.entropy_penalty,
         entropy_penalty_alpha=cehrgpt_args.entropy_penalty_alpha,
         n_pretrained_embeddings_layers=cehrgpt_args.n_pretrained_embeddings_layers,
         use_pretrained_embeddings=len(tokenizer.pretrained_token_ids) > 0,
         pretrained_embedding_dim=pretrained_embedding_dim,
-        **model_args.as_dict(),
+        sample_packing_max_positions=(
+            cehrgpt_args.max_tokens_per_batch
+            if cehrgpt_args.sample_packing
+            else model_args.max_position_embeddings
+        ),
+        include_motor_time_to_event=cehrgpt_args.include_motor_time_to_event,
+        motor_tte_vocab_size=tokenizer.motor_tte_vocab_size,
+        motor_time_to_event_weight=cehrgpt_args.motor_time_to_event_weight,
+        motor_num_time_pieces=cehrgpt_args.motor_num_time_pieces,
+        ve_token_id=tokenizer.ve_token_id,
+        **model_args_cehrgpt,
     )
+
     model = CEHRGPT2LMHeadModel(model_config)
     if tokenizer.pretrained_token_ids:
         model.cehrgpt.update_pretrained_embeddings(
@@ -156,6 +236,11 @@
 def main():
     cehrgpt_args, data_args, model_args, training_args = parse_runner_args()

+    if cehrgpt_args.sample_packing and data_args.streaming:
+        raise RuntimeError(
+            f"sample_packing is not supported when streaming is enabled, please set streaming to False"
+        )
+
     if data_args.streaming:
         # This is for disabling the warning message https://github.com/huggingface/transformers/issues/5486
         # This happens only when streaming is enabled
@@ -165,6 +250,8 @@ def main():
         training_args.dataloader_num_workers = 0
         training_args.dataloader_prefetch_factor = None

+    processed_dataset: Optional[DatasetDict] = None
+    cache_file_collector = CacheFileCollector()
     prepared_ds_path = generate_prepared_ds_path(data_args, model_args)
     if os.path.exists(os.path.join(data_args.data_folder, "dataset_dict.json")):
         LOG.info(f"Loading prepared dataset from disk at {data_args.data_folder}...")
@@ -200,118 +287,160 @@ def main():
         )
         cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
     else:
-        # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
-        if data_args.is_data_in_meds:
-            meds_extension_path = get_meds_extension_path(
-                data_folder=data_args.data_folder,
-                dataset_prepared_path=data_args.dataset_prepared_path,
-            )
-            try:
-                LOG.info(
-                    "Trying to load the MEDS extension from disk at %s...",
-                    meds_extension_path,
+        # Only run tokenization and data transformation in the main process in torch distributed training
+        # otherwise the multiple processes will create tokenizers at the same time
+        if is_main_process(training_args.local_rank):
+            # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
+            if data_args.is_data_in_meds:
+                meds_extension_path = get_meds_extension_path(
+                    data_folder=data_args.data_folder,
+                    dataset_prepared_path=data_args.dataset_prepared_path,
                 )
-            dataset = load_from_disk(meds_extension_path)
-            if data_args.streaming:
-                if isinstance(dataset, DatasetDict):
-                    dataset = {
-                        k: v.to_iterable_dataset(
+                try:
+                    LOG.info(
+                        "Trying to load the MEDS extension from disk at %s...",
+                        meds_extension_path,
+                    )
+                    dataset = load_from_disk(meds_extension_path)
+                    if data_args.streaming:
+                        if isinstance(dataset, DatasetDict):
+                            dataset = {
+                                k: v.to_iterable_dataset(
                                    num_shards=training_args.dataloader_num_workers
                                )
-                        for k, v in dataset.items()
-                    }
-                else:
-                    dataset = dataset.to_iterable_dataset(
-                        num_shards=training_args.dataloader_num_workers
-                    )
-            except FileNotFoundError as e:
-                LOG.exception(e)
-                dataset = create_dataset_from_meds_reader(
-                    data_args=data_args,
-                    dataset_mappings=[
-                        MedToCehrGPTDatasetMapping(
-                            data_args=data_args,
-                            is_pretraining=True,
-                            include_inpatient_hour_token=cehrgpt_args.include_inpatient_hour_token,
+                                for k, v in dataset.items()
+                            }
+                        else:
+                            dataset = dataset.to_iterable_dataset(
+                                num_shards=training_args.dataloader_num_workers
+                            )
+                except FileNotFoundError as e:
+                    LOG.warning(e)
+                    dataset = create_dataset_from_meds_reader(
+                        data_args=data_args,
+                        dataset_mappings=[
+                            MedToCehrGPTDatasetMapping(
+                                data_args=data_args,
+                                include_inpatient_hour_token=cehrgpt_args.include_inpatient_hour_token,
+                            )
+                        ],
+                        cache_file_collector=cache_file_collector,
+                    )
+                    if not data_args.streaming:
+                        dataset.save_to_disk(str(meds_extension_path))
+                        stats = dataset.cleanup_cache_files()
+                        LOG.info(
+                            "Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
+                            stats,
                        )
-                ],
+                        # Clean up the files created from the data generator
+                        cache_file_collector.remove_cache_files()
+                        dataset = load_from_disk(str(meds_extension_path))
+            else:
+                # Load the dataset from the parquet files
+                dataset = load_parquet_as_dataset(
+                    os.path.expanduser(data_args.data_folder),
+                    split="train",
+                    streaming=data_args.streaming,
                )
-            if not data_args.streaming:
-                dataset.save_to_disk(str(meds_extension_path))
-                stats = dataset.cleanup_cache_files()
-                LOG.info(
-                    "Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
-                    stats,
+            # If streaming is enabled, we need to manually split the data into train/val
+            if data_args.streaming and data_args.validation_split_num:
+                dataset = dataset.shuffle(
+                    buffer_size=10_000, seed=training_args.seed
                )
-                dataset = load_from_disk(str(meds_extension_path))
-        else:
-            # Load the dataset from the parquet files
-            dataset = load_parquet_as_dataset(
-                os.path.expanduser(data_args.data_folder),
-                split="train",
-                streaming=data_args.streaming,
+                train_set = dataset.skip(data_args.validation_split_num)
+                val_set = dataset.take(data_args.validation_split_num)
+                dataset = DatasetDict({"train": train_set, "validation": val_set})
+            elif data_args.validation_split_percentage:
+                dataset = dataset.train_test_split(
+                    test_size=data_args.validation_split_percentage,
+                    seed=training_args.seed,
+                )
+                dataset = DatasetDict(
+                    {"train": dataset["train"], "validation": dataset["test"]}
+                )
+            else:
+                raise RuntimeError(
+                    f"Can not split the data. If streaming is enabled, validation_split_num needs to be "
+                    f"defined, otherwise validation_split_percentage needs to be provided. "
+                    f"The current values are:\n"
+                    f"validation_split_percentage: {data_args.validation_split_percentage}\n"
+                    f"validation_split_num: {data_args.validation_split_num}\n"
+                    f"streaming: {data_args.streaming}"
+                )
+
+            # Create the CEHR-GPT tokenizer if it's not available in the output folder
+            cehrgpt_tokenizer = load_and_create_tokenizer(
+                data_args=data_args,
+                model_args=model_args,
+                cehrgpt_args=cehrgpt_args,
+                dataset=dataset,
            )
-        # If streaming is enabled, we need to manually split the data into train/val
-        if data_args.streaming and data_args.validation_split_num:
-            dataset = dataset.shuffle(buffer_size=10_000, seed=training_args.seed)
-            train_set = dataset.skip(data_args.validation_split_num)
-            val_set = dataset.take(data_args.validation_split_num)
-            dataset = DatasetDict({"train": train_set, "test": val_set})
-        elif data_args.validation_split_percentage:
-            dataset = dataset.train_test_split(
-                test_size=data_args.validation_split_percentage,
-                seed=training_args.seed,
-            )
-        else:
-            raise RuntimeError(
-                f"Can not split the data. If streaming is enabled, validation_split_num needs to be "
-                f"defined, otherwise validation_split_percentage needs to be provided. "
-                f"The current values are:\n"
-                f"validation_split_percentage: {data_args.validation_split_percentage}\n"
-                f"validation_split_num: {data_args.validation_split_num}\n"
-                f"streaming: {data_args.streaming}"
-            )

-        # Create the CEHR-GPT tokenizer if it's not available in the output folder
-        cehrgpt_tokenizer = load_and_create_tokenizer(
-            data_args=data_args,
-            model_args=model_args,
-            cehrgpt_args=cehrgpt_args,
-            dataset=dataset,
-        )
-        # Retrain the tokenizer in case we want to pretrain the model further using different datasets
-        if cehrgpt_args.expand_tokenizer:
-            new_tokenizer_path = os.path.expanduser(training_args.output_dir)
-            try:
-                cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
-            except Exception:
-                cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
-                    cehrgpt_tokenizer=cehrgpt_tokenizer,
-                    dataset=dataset["train"],
-                    data_args=data_args,
-                    concept_name_mapping={},
-                    pretrained_concept_embedding_model=PretrainedEmbeddings(
-                        cehrgpt_args.pretrained_embedding_path
-                    ),
-                )
-            cehrgpt_tokenizer.save_pretrained(
-                os.path.expanduser(training_args.output_dir)
+            # Retrain the tokenizer in case we want to pretrain the model further using different datasets
+            if cehrgpt_args.expand_tokenizer:
+                new_tokenizer_path = os.path.expanduser(training_args.output_dir)
+                try:
+                    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
+                        new_tokenizer_path
+                    )
+                except Exception:
+                    cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
+                        cehrgpt_tokenizer=cehrgpt_tokenizer,
+                        dataset=dataset["train"],
+                        data_args=data_args,
+                        concept_name_mapping={},
+                        pretrained_concept_embedding_model=PretrainedEmbeddings(
+                            cehrgpt_args.pretrained_embedding_path
+                        ),
+                        apply_entropy_filter=cehrgpt_args.apply_entropy_filter,
+                        min_prevalence=cehrgpt_args.min_prevalence,
+                    )
+                cehrgpt_tokenizer.save_pretrained(
+                    os.path.expanduser(training_args.output_dir)
+                )
+
+            # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
+            if not data_args.streaming:
+                all_columns = dataset["train"].column_names
+                if "visit_concept_ids" in all_columns:
+                    dataset = dataset.remove_columns(["visit_concept_ids"])
+
+            # sort the patient features chronologically and tokenize the data
+            processed_dataset = create_cehrgpt_pretraining_dataset(
+                dataset=dataset,
+                cehrgpt_tokenizer=cehrgpt_tokenizer,
+                data_args=data_args,
+                cache_file_collector=cache_file_collector,
+            )
+            # only save the data to the disk if it is not streaming
+            if not data_args.streaming:
+                processed_dataset.save_to_disk(str(prepared_ds_path))
+                stats = processed_dataset.cleanup_cache_files()
+                LOG.info(
+                    "Clean up the cached files for the cehrgpt pretraining dataset: %s",
+                    stats,
                )
+            cache_file_collector.remove_cache_files()
+
+        # After main-process-only operations, synchronize all processes to ensure consistency
+        if dist.is_available() and dist.is_initialized():
+            dist.barrier()

-    # sort the patient features chronologically and tokenize the data
-    processed_dataset = create_cehrgpt_pretraining_dataset(
-        dataset=dataset, cehrgpt_tokenizer=cehrgpt_tokenizer, data_args=data_args
+        # Loading tokenizer in all processes in torch distributed training
+        tokenizer_name_or_path = os.path.expanduser(
+            training_args.output_dir
+            if cehrgpt_args.expand_tokenizer
+            else model_args.tokenizer_name_or_path
        )
-    # only save the data to the disk if it is not streaming
+        cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
+        # Load the dataset from disk again in torch distributed training
        if not data_args.streaming:
-        processed_dataset.save_to_disk(str(prepared_ds_path))
-        stats = processed_dataset.cleanup_cache_files()
-        LOG.info(
-            "Clean up the cached files for the cehrgpt pretraining dataset: %s",
-            stats,
-        )
            processed_dataset = load_from_disk(str(prepared_ds_path))

+    if processed_dataset is None:
+        raise RuntimeError("The processed dataset cannot be None")
+
     def filter_func(examples):
         if cehrgpt_args.drop_long_sequences:
             return [
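
This restructure is the usual rank-0-prepares pattern for torch distributed training: only the main process trains the tokenizer and materializes the tokenized dataset, every rank then meets at dist.barrier(), and afterwards all ranks (including rank 0) reload the same artifacts from disk. A generic sketch of the pattern, with illustrative function names not taken from the diff:

import torch.distributed as dist
from transformers.trainer_utils import is_main_process

def prepare_once_then_load(local_rank, build_fn, load_fn):
    # build_fn: e.g. train the tokenizer and save the tokenized dataset
    # load_fn: read those artifacts back from disk
    if is_main_process(local_rank):
        build_fn()
    # Every rank waits here, so none reads a half-written artifact
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    return load_fn()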
@@ -339,9 +468,11 @@ def main():
     else:
         processed_dataset = processed_dataset.filter(filter_func, **filter_args)

-    model = load_and_create_model(
-        model_args, cehrgpt_args, training_args, cehrgpt_tokenizer
-    )
+    model = load_and_create_model(model_args, cehrgpt_args, cehrgpt_tokenizer)
+
+    # Try to update motor tte vocab size if the new configuration is different from the existing one
+    if cehrgpt_args.include_motor_time_to_event:
+        model.update_motor_tte_vocab_size(cehrgpt_tokenizer.motor_tte_vocab_size)

     # Expand tokenizer to adapt to the new pretraining dataset
     if model.config.vocab_size < cehrgpt_tokenizer.vocab_size:
@@ -369,22 +500,67 @@ def main():
     # Set seed before initializing model.
     set_seed(training_args.seed)

-    if not data_args.streaming:
+    if not data_args.streaming and not cehrgpt_args.sample_packing:
         processed_dataset.set_format("pt")

-    trainer = Trainer(
+    callbacks = []
+    if cehrgpt_args.use_early_stopping:
+        callbacks.append(
+            CustomEarlyStoppingCallback(
+                model_args.early_stopping_patience,
+                cehrgpt_args.early_stopping_threshold,
+            )
+        )
+
+    if cehrgpt_args.sample_packing:
+        trainer_class = partial(
+            SamplePackingTrainer,
+            max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
+            max_position_embeddings=model_args.max_position_embeddings,
+            train_lengths=processed_dataset["train"]["num_of_concepts"],
+            validation_lengths=(
+                processed_dataset["validation"]
+                if "validation" in processed_dataset
+                else processed_dataset["test"]
+            )["num_of_concepts"],
+        )
+        training_args.per_device_train_batch_size = 1
+        training_args.per_device_eval_batch_size = 1
+        data_collator_fn = partial(
+            SamplePackingCehrGptDataCollator,
+            cehrgpt_args.max_tokens_per_batch,
+            model_args.max_position_embeddings,
+            add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
+        )
+    else:
+        trainer_class = Trainer
+        data_collator_fn = CehrGptDataCollator
+
+    trainer = trainer_class(
         model=model,
-        data_collator=CehrGptDataCollator(
+        data_collator=data_collator_fn(
             tokenizer=cehrgpt_tokenizer,
-            max_length=model_args.max_position_embeddings,
+            max_length=(
+                cehrgpt_args.max_tokens_per_batch
+                if cehrgpt_args.sample_packing
+                else model_args.max_position_embeddings
+            ),
             shuffle_records=data_args.shuffle_records,
             include_ttv_prediction=model_args.include_ttv_prediction,
             use_sub_time_tokenization=model_args.use_sub_time_tokenization,
             include_values=model_args.include_values,
+            include_motor_time_to_event=cehrgpt_args.include_motor_time_to_event,
+            motor_tte_vocab_size=model.config.motor_tte_vocab_size,
+            motor_num_time_pieces=cehrgpt_args.motor_num_time_pieces,
         ),
         train_dataset=processed_dataset["train"],
-        eval_dataset=processed_dataset["test"],
+        eval_dataset=(
+            processed_dataset["validation"]
+            if "validation" in processed_dataset
+            else processed_dataset["test"]
+        ),
         args=training_args,
+        callbacks=callbacks,
     )

     checkpoint = None
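
The functools.partial construction lets both branches expose one constructor interface: packing-specific arguments are bound early, so the trainer_class(...) and data_collator_fn(...) call sites stay identical whether or not sample packing is on. A minimal illustration of the pattern (class and argument names are illustrative, not from the diff):

from functools import partial

class Collator:
    def __init__(self, tokenizer, max_length, packed=False):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.packed = packed

# Packing branch: bind the extra argument ahead of time
data_collator_fn = partial(Collator, packed=True)
# Non-packing branch would simply be: data_collator_fn = Collator

# The shared call site no longer cares which branch was taken
collator = data_collator_fn(tokenizer="tok", max_length=16384)
print(collator.packed, collator.max_length)  # True 16384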
cehrgpt/runners/hf_gpt_runner_argument_dataclass.py
@@ -6,6 +6,12 @@ from typing import List, Optional
 class CehrGPTArguments:
     """Arguments pertaining to what data we are going to input our model for training and eval."""

+    tokenized_full_dataset_path: Optional[str] = dataclasses.field(
+        default=None,
+        metadata={
+            "help": "The path to the tokenized dataset created for the full population"
+        },
+    )
     include_inpatient_hour_token: Optional[bool] = dataclasses.field(
         default=True,
         metadata={"help": "Include inpatient hour token"},
@@ -115,6 +121,9 @@ class CehrGPTArguments:
             "help": "The lower bound of the learning rate range for hyperparameter tuning."
         },
     )
+    next_token_prediction_loss_weight: float = dataclasses.field(
+        default=1.0, metadata={"help": "The weight of the next token prediction loss"}
+    )
     lab_token_penalty: Optional[bool] = dataclasses.field(
         default=False,
         metadata={
@@ -125,6 +134,10 @@ class CehrGPTArguments:
         default=1.0,
         metadata={"help": "lab_token_loss_weight penalty co-efficient"},
     )
+    value_prediction_loss_weight: Optional[float] = dataclasses.field(
+        default=1.0,
+        metadata={"help": "The weight of the value prediction loss"},
+    )
     entropy_penalty: Optional[bool] = dataclasses.field(
         default=False,
         metadata={"help": "A flag to indicate whether we want to use entropy penalty."},
@@ -139,3 +152,80 @@ class CehrGPTArguments:
             "help": "The number of feed forward layers for transforming pretrained embeddings to internal embeddings"
         },
     )
+    meds_repartition: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to repartition the meds train tune sets"
+        },
+    )
+    use_early_stopping: Optional[bool] = dataclasses.field(
+        default=True,
+        metadata={"help": "A flag to indicate whether we want to use early stopping."},
+    )
+    early_stopping_threshold: Optional[float] = dataclasses.field(
+        default=0.01,
+        metadata={
+            "help": "A threshold to denote how much the specified metric must improve to satisfy early stopping conditions."
+        },
+    )
+    sample_packing: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to use sample packing for efficient training."
+        },
+    )
+    max_tokens_per_batch: int = dataclasses.field(
+        default=16384, metadata={"help": "Maximum number of tokens in each batch"}
+    )
+    add_end_token_in_sample_packing: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to add end token in sample packing"
+        },
+    )
+    include_motor_time_to_event: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to include the motor time to events"
+        },
+    )
+    num_motor_tasks: Optional[int] = dataclasses.field(
+        default=10000,
+        metadata={"help": "The number of max MOTOR tasks"},
+    )
+    motor_time_to_event_weight: Optional[float] = dataclasses.field(
+        default=1.0,
+        metadata={"help": "The MOTOR time to event loss weight"},
+    )
+    motor_num_time_pieces: Optional[int] = dataclasses.field(
+        default=8,
+        metadata={
+            "help": "The number of times each motor_num_time_pieces piece has to be"
+        },
+    )
+    concept_dir: Optional[str] = dataclasses.field(
+        default=None,
+        metadata={"help": "The directory where the concept data is stored."},
+    )
+    average_over_sequence: bool = dataclasses.field(
+        default=False,
+        metadata={"help": "Whether or not to average tokens per sequence"},
+    )
+    apply_entropy_filter: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={"help": "A flag to indicate whether we want to use entropy filter."},
+    )
+    min_prevalence: Optional[float] = dataclasses.field(
+        default=1 / 1000,
+        metadata={"help": "The min_prevalence to keep the concepts in the tokenizer"},
+    )
+    class_weights: Optional[List[int]] = dataclasses.field(
+        default=None,
+        metadata={"help": "The class weights for training"},
+    )
+    negative_sampling_probability: Optional[float] = dataclasses.field(
+        default=None,
+        metadata={
+            "help": "The probability of negative samples will be included in the training data"
+        },
+    )