cehrgpt 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,10 @@
+ import os
+ from datetime import datetime
+ from typing import Dict, List, Optional, Union
+
  import numpy as np
+ import polars as pl
+ import torch
  from cehrbert.data_generators.hf_data_generator.cache_util import CacheFileCollector
  from cehrbert.data_generators.hf_data_generator.meds_utils import (
      create_dataset_from_meds_reader,
@@ -12,12 +18,31 @@ from datasets import DatasetDict, concatenate_datasets, load_from_disk
  from transformers import TrainingArguments
  from transformers.utils import logging

- from cehrgpt.data.hf_cehrgpt_dataset_mapping import MedToCehrGPTDatasetMapping
+ from cehrgpt.data.hf_cehrgpt_dataset_mapping import (
+     ExtractTokenizedSequenceDataMapping,
+     MedToCehrGPTDatasetMapping,
+ )
  from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments

  LOG = logging.get_logger("transformers")


+ def get_torch_dtype(torch_dtype: Optional[str] = None) -> Union[torch.dtype, str]:
+     if torch_dtype and hasattr(torch, torch_dtype):
+         return getattr(torch, torch_dtype)
+     return torch.float
+
+
+ def data_collate_fn(features, model_type: torch.dtype, collator):
+     batch = collator(features)
+     if model_type != torch.float32:
+         for key, value in batch.items():
+             # Only convert float32 tensors to the target dtype (typically bfloat16)
+             if isinstance(value, torch.Tensor) and value.dtype == torch.float32:
+                 batch[key] = value.to(model_type)
+     return batch
+
+
  def prepare_finetune_dataset(
      data_args: DataTrainingArguments,
      training_args: TrainingArguments,
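For context, here is a minimal, illustrative sketch (not part of the package) of how the two new helpers above could be wired together. The helpers are imported from cehrgpt.runners.data_utils as in the runners below; toy_collator and the "value_attention_mask" field are made-up stand-ins for the real CEHR-GPT collator and batch keys.

from functools import partial

import torch

from cehrgpt.runners.data_utils import data_collate_fn, get_torch_dtype


def toy_collator(features):
    # Stand-in for a real collator: returns a dict of tensors, some of them float32
    return {"value_attention_mask": torch.ones(len(features), 4, dtype=torch.float32)}


target_dtype = get_torch_dtype("bfloat16")  # resolves to torch.bfloat16; unknown names fall back to torch.float
collate = partial(data_collate_fn, model_type=target_dtype, collator=toy_collator)
batch = collate([{}, {}])
assert batch["value_attention_mask"].dtype == torch.bfloat16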
@@ -219,7 +244,7 @@ def create_dataset_splits(data_args: DataTrainingArguments, seed: int):
      )

      # Generate splits
-     train_set = filter_by_patient_ids(train_patient_ids)
+     train_set = filter_by_patient_ids(train_patient_ids).shuffle(seed=seed)
      validation_set = filter_by_patient_ids(val_patient_ids)
      if test_set is None:
          test_set = filter_by_patient_ids(test_patient_ids)
@@ -241,3 +266,93 @@ def create_dataset_splits(data_args: DataTrainingArguments, seed: int):
      )

      return train_set, validation_set, test_set
+
+
+ def extract_cohort_sequences(
+     data_args: DataTrainingArguments,
+     cehrgpt_args: CehrGPTArguments,
+     cache_file_collector: CacheFileCollector,
+ ) -> DatasetDict:
+     """
+     Extract and process cohort-specific tokenized sequences from a pre-tokenized dataset.
+
+     The extraction is based on the provided cohort Parquet files and observation window constraints.
+
+     This function performs the following steps:
+     1. Loads cohort definitions from Parquet files located in `data_args.cohort_folder`.
+     2. Renames relevant columns if the data originates from a Meds format.
+     3. Filters a pre-tokenized dataset (loaded from `cehrgpt_args.tokenized_full_dataset_path`)
+        to include only patients present in the cohort.
+     4. Aggregates each person's index date and label into a mapping.
+     5. Checks for consistency to ensure all cohort person_ids are present in the tokenized dataset.
+     6. Applies a transformation (`ExtractTokenizedSequenceDataMapping`) to generate
+        observation-window-constrained patient sequences.
+     7. Caches both the filtered and processed datasets using the provided `cache_file_collector`.
+
+     Args:
+         data_args (DataTrainingArguments): Configuration parameters for data processing,
+             including cohort folder, observation window, batch size, and parallelism.
+         cehrgpt_args (CehrGPTArguments): Contains paths to pre-tokenized datasets and CEHR-GPT-specific arguments.
+         cache_file_collector (CacheFileCollector): Utility to register and manage dataset cache files.
+
+     Returns:
+         DatasetDict: A Hugging Face `DatasetDict` containing the processed datasets (e.g., train/validation/test),
+             where each entry includes sequences filtered and truncated by the observation window.
+
+     Raises:
+         RuntimeError: If any `person_id` in the cohort is missing from the tokenized dataset.
+     """
+
+     cohort = pl.read_parquet(os.path.join(data_args.cohort_folder, "*.parquet"))
+     if data_args.is_data_in_meds:
+         cohort = cohort.rename(
+             mapping={
+                 "prediction_time": "index_date",
+                 "subject_id": "person_id",
+             }
+         )
+     all_person_ids = cohort["person_id"].unique().to_list()
+     # data_args.observation_window
+     tokenized_dataset = load_from_disk(cehrgpt_args.tokenized_full_dataset_path)
+     filtered_tokenized_dataset = tokenized_dataset.filter(
+         lambda batch: [person_id in all_person_ids for person_id in batch["person_id"]],
+         batched=True,
+         batch_size=data_args.preprocessing_batch_size,
+         num_proc=data_args.preprocessing_num_workers,
+     )
+     person_index_date_agg = cohort.group_by("person_id").agg(
+         pl.struct("index_date", "label").alias("index_date_label")
+     )
+     # Convert to dictionary
+     person_index_date_map: Dict[int, List[datetime]] = dict(
+         zip(
+             person_index_date_agg["person_id"].to_list(),
+             person_index_date_agg["index_date_label"].to_list(),
+         )
+     )
+     LOG.info(f"person_index_date_agg: {person_index_date_agg}")
+     tokenized_person_ids = []
+     for _, dataset in filtered_tokenized_dataset.items():
+         tokenized_person_ids.extend(dataset["person_id"])
+     missing_person_ids = [
+         person_id
+         for person_id in person_index_date_map.keys()
+         if person_id not in tokenized_person_ids
+     ]
+     if missing_person_ids:
+         raise RuntimeError(
+             f"There are {len(missing_person_ids)} person_ids missing from the tokenized dataset. "
+             f"The list contains: {missing_person_ids}"
+         )
+     processed_dataset = filtered_tokenized_dataset.map(
+         ExtractTokenizedSequenceDataMapping(
+             person_index_date_map, data_args.observation_window
+         ).batch_transform,
+         batched=True,
+         batch_size=data_args.preprocessing_batch_size,
+         num_proc=data_args.preprocessing_num_workers,
+         remove_columns=filtered_tokenized_dataset["train"].column_names,
+     )
+     cache_file_collector.add_cache_files(filtered_tokenized_dataset)
+     cache_file_collector.add_cache_files(processed_dataset)
+     return processed_dataset
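As a reading aid, the following self-contained sketch (values made up, not part of the diff) shows the cohort schema extract_cohort_sequences expects after the optional MEDS renaming, and the polars aggregation it uses to build the person_id-to-(index_date, label) mapping.

from datetime import datetime

import polars as pl

# Illustrative cohort with the columns used above: person_id, index_date, label
cohort = pl.DataFrame(
    {
        "person_id": [1, 1, 2],
        "index_date": [datetime(2020, 1, 1), datetime(2021, 6, 1), datetime(2019, 3, 15)],
        "label": [0, 1, 1],
    }
)
person_index_date_agg = cohort.group_by("person_id").agg(
    pl.struct("index_date", "label").alias("index_date_label")
)
# One row per person_id; each value is a list of {"index_date": ..., "label": ...} structs
person_index_date_map = dict(
    zip(
        person_index_date_agg["person_id"].to_list(),
        person_index_date_agg["index_date_label"].to_list(),
    )
)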
@@ -50,7 +50,11 @@ from cehrgpt.models.hf_cehrgpt import (
  )
  from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
  from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
- from cehrgpt.runners.data_utils import prepare_finetune_dataset
+ from cehrgpt.runners.data_utils import (
+     extract_cohort_sequences,
+     get_torch_dtype,
+     prepare_finetune_dataset,
+ )
  from cehrgpt.runners.gpt_runner_util import parse_runner_args
  from cehrgpt.runners.hf_cehrgpt_pretrain_runner import tokenizer_exists
  from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
@@ -142,11 +146,10 @@ def load_finetuned_model(
          raise ValueError(
              f"finetune_model_type can be one of the following types {FineTuneModelType.POOLING.value}"
          )
-
      attn_implementation = (
          "flash_attention_2" if is_flash_attn_2_available() else "eager"
      )
-     torch_dtype = torch.bfloat16 if training_args.bf16 else torch.float32
+     torch_dtype = get_torch_dtype(model_args.torch_dtype)
      # Try to create a new model based on the base model
      try:
          return finetune_model_cls.from_pretrained(
@@ -161,11 +164,22 @@ def load_finetuned_model(
  def model_init(
      model_args: ModelArguments,
      training_args: TrainingArguments,
+     cehrgpt_args: CehrGPTArguments,
      tokenizer: CehrGptTokenizer,
  ):
      model = load_finetuned_model(
          model_args, training_args, model_args.model_name_or_path
      )
+
+     if cehrgpt_args.class_weights:
+         model.config.class_weights = cehrgpt_args.class_weights
+         LOG.info(f"Setting class_weights to {model.config.class_weights}")
+
+     # Enable position embeddings when they were disabled during pre-training
+     if not model_args.exclude_position_ids and model.cehrgpt.exclude_position_ids:
+         LOG.info("Enabling position embeddings")
+         model.cehrgpt.enable_position_embeddings()
+
      if model.config.max_position_embeddings < model_args.max_position_embeddings:
          LOG.info(
              f"Increase model.config.max_position_embeddings to {model_args.max_position_embeddings}"
@@ -175,9 +189,6 @@ def model_init(
      # Enable include_values when include_values is set to be False during pre-training
      if model_args.include_values and not model.cehrgpt.include_values:
          model.cehrgpt.include_values = True
-     # Enable position embeddings when position embeddings are disabled in pre-training
-     if not model_args.exclude_position_ids and model.cehrgpt.exclude_position_ids:
-         model.cehrgpt.exclude_position_ids = False
      # Expand tokenizer to adapt to the finetuning dataset
      if model.config.vocab_size < tokenizer.vocab_size:
          model.resize_token_embeddings(tokenizer.vocab_size)
@@ -195,6 +206,7 @@ def model_init(
          model.cehrgpt.update_pretrained_embeddings(
              tokenizer.pretrained_token_ids, tokenizer.pretrained_embeddings
          )
+
      # Expand value tokenizer to adapt to the fine-tuning dataset
      if model.config.include_values:
          if model.config.value_vocab_size < tokenizer.value_vocab_size:
@@ -252,46 +264,55 @@ def main():

      if processed_dataset is None:
          if is_main_process(training_args.local_rank):
-             final_splits = prepare_finetune_dataset(
-                 data_args, training_args, cehrgpt_args, cache_file_collector
-             )
-             if cehrgpt_args.expand_tokenizer:
-                 new_tokenizer_path = os.path.expanduser(training_args.output_dir)
-                 if tokenizer_exists(new_tokenizer_path):
-                     tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
-                 else:
-                     # Try to use the defined pretrained embeddings if exists, Otherwise we default to the pretrained model
-                     # embedded in the pretrained model
-                     pretrained_concept_embedding_model = PretrainedEmbeddings(
-                         cehrgpt_args.pretrained_embedding_path
-                     )
-                     if not pretrained_concept_embedding_model.exists:
-                         pretrained_concept_embedding_model = (
-                             tokenizer.pretrained_concept_embedding_model
+             # If the full dataset has already been tokenized, we don't want to re-tokenize the cohort,
+             # which contains a subset of the data. Instead, we slice out the portion of the tokenized sequences for each sample.
+             if cehrgpt_args.tokenized_full_dataset_path is not None:
+                 processed_dataset = extract_cohort_sequences(
+                     data_args, cehrgpt_args, cache_file_collector
+                 )
+             else:
+                 final_splits = prepare_finetune_dataset(
+                     data_args, training_args, cehrgpt_args, cache_file_collector
+                 )
+                 if cehrgpt_args.expand_tokenizer:
+                     new_tokenizer_path = os.path.expanduser(training_args.output_dir)
+                     if tokenizer_exists(new_tokenizer_path):
+                         tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
+                     else:
+                         # Try to use the defined pretrained embeddings if exists, Otherwise we default to the pretrained model
+                         # embedded in the pretrained model
+                         pretrained_concept_embedding_model = PretrainedEmbeddings(
+                             cehrgpt_args.pretrained_embedding_path
                          )
-                 tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
-                     cehrgpt_tokenizer=tokenizer,
-                     dataset=final_splits["train"],
-                     data_args=data_args,
-                     concept_name_mapping={},
-                     pretrained_concept_embedding_model=pretrained_concept_embedding_model,
-                 )
-                 tokenizer.save_pretrained(
-                     os.path.expanduser(training_args.output_dir)
-                 )
-
-             # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
-             if not data_args.streaming:
-                 all_columns = final_splits["train"].column_names
-                 if "visit_concept_ids" in all_columns:
-                     final_splits = final_splits.remove_columns(["visit_concept_ids"])
-
-             processed_dataset = create_cehrgpt_finetuning_dataset(
-                 dataset=final_splits,
-                 cehrgpt_tokenizer=tokenizer,
-                 data_args=data_args,
-                 cache_file_collector=cache_file_collector,
-             )
+                         if not pretrained_concept_embedding_model.exists:
+                             pretrained_concept_embedding_model = (
+                                 tokenizer.pretrained_concept_embedding_model
+                             )
+                         tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
+                             cehrgpt_tokenizer=tokenizer,
+                             dataset=final_splits["train"],
+                             data_args=data_args,
+                             concept_name_mapping={},
+                             pretrained_concept_embedding_model=pretrained_concept_embedding_model,
+                         )
+                         tokenizer.save_pretrained(
+                             os.path.expanduser(training_args.output_dir)
+                         )
+
+                 # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
+                 if not data_args.streaming:
+                     all_columns = final_splits["train"].column_names
+                     if "visit_concept_ids" in all_columns:
+                         final_splits = final_splits.remove_columns(
+                             ["visit_concept_ids"]
+                         )
+
+                 processed_dataset = create_cehrgpt_finetuning_dataset(
+                     dataset=final_splits,
+                     cehrgpt_tokenizer=tokenizer,
+                     data_args=data_args,
+                     cache_file_collector=cache_file_collector,
+                 )
          if not data_args.streaming:
              processed_dataset.save_to_disk(str(prepared_ds_path))
              stats = processed_dataset.cleanup_cache_files()
@@ -350,8 +371,7 @@ def main():
              SamplePackingTrainer,
              max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
              max_position_embeddings=config.max_position_embeddings,
-             train_lengths=processed_dataset["train"]["num_of_concepts"],
-             validation_lengths=processed_dataset["validation"]["num_of_concepts"],
+             negative_sampling_probability=cehrgpt_args.negative_sampling_probability,
          )
          training_args.per_device_train_batch_size = 1
          training_args.per_device_eval_batch_size = 1
@@ -359,6 +379,7 @@ def main():
              SamplePackingCehrGptDataCollator,
              cehrgpt_args.max_tokens_per_batch,
              config.max_position_embeddings,
+             add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
          )
      else:
          trainer_class = Trainer
@@ -381,13 +402,14 @@ def main():
          include_ttv_prediction=False,
          use_sub_time_tokenization=False,
          include_demographics=cehrgpt_args.include_demographics,
+         add_linear_prob_token=True,
      )

      if training_args.do_train:
          if cehrgpt_args.hyperparameter_tuning:
              training_args = perform_hyperparameter_search(
                  trainer_class,
-                 partial(model_init, model_args, training_args, tokenizer),
+                 partial(model_init, model_args, training_args, cehrgpt_args, tokenizer),
                  processed_dataset,
                  data_collator,
                  training_args,
@@ -401,6 +423,7 @@ def main():
                  trainer_class,
                  model_args,
                  training_args,
+                 cehrgpt_args,
                  tokenizer,
                  processed_dataset,
                  data_collator,
@@ -408,7 +431,7 @@ def main():
          else:
              # Initialize Trainer for final training on the combined train+val set
              trainer = trainer_class(
-                 model=model_init(model_args, training_args, tokenizer),
+                 model=model_init(model_args, training_args, cehrgpt_args, tokenizer),
                  data_collator=data_collator,
                  args=training_args,
                  train_dataset=processed_dataset["train"],
@@ -457,6 +480,7 @@ def retrain_with_full_set(
      trainer_class,
      model_args: ModelArguments,
      training_args: TrainingArguments,
+     cehrgpt_args: CehrGPTArguments,
      tokenizer: CehrGptTokenizer,
      dataset: DatasetDict,
      data_collator: CehrGptDataCollator,
@@ -475,6 +499,7 @@ def retrain_with_full_set(
          model_args (ModelArguments): Model configuration and hyperparameters.
          training_args (TrainingArguments): Training configuration, including output directory,
              evaluation strategy, and other training parameters.
+         cehrgpt_args (CehrGPTArguments): CEHR-GPT-specific parameters.
          tokenizer (CehrGptTokenizer): Tokenizer instance specific to CEHR-GPT.
          dataset (DatasetDict): A dictionary containing the 'train' and 'validation' datasets.
          data_collator (CehrGptDataCollator): Data collator for handling data batching and tokenization.
@@ -494,7 +519,7 @@ def retrain_with_full_set(
      training_args.evaluation_strategy = "no"
      checkpoint = get_last_hf_checkpoint(training_args)
      final_trainer = trainer_class(
-         model=model_init(model_args, training_args, tokenizer),
+         model=model_init(model_args, training_args, cehrgpt_args, tokenizer),
          data_collator=data_collator,
          args=training_args,
          train_dataset=full_dataset,
@@ -34,6 +34,7 @@ from cehrgpt.models.config import CEHRGPTConfig
  from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
  from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
  from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+ from cehrgpt.runners.data_utils import get_torch_dtype
  from cehrgpt.runners.gpt_runner_util import parse_runner_args
  from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
  from cehrgpt.runners.sample_packing_trainer import SamplePackingTrainer
@@ -71,6 +72,36 @@ def load_and_create_tokenizer(
      cehrgpt_args: CehrGPTArguments,
      dataset: Optional[Union[Dataset, DatasetDict]] = None,
  ) -> CehrGptTokenizer:
+
+     concept_name_mapping = {}
+     allowed_motor_codes = list()
+     if cehrgpt_args.concept_dir:
+         import pandas as pd
+         from cehrbert_data.const.artificial_tokens import DEATH_TOKEN
+         from meds.schema import death_code
+
+         LOG.info("Loading concept data from disk at %s", cehrgpt_args.concept_dir)
+         concept_pd = pd.read_parquet(cehrgpt_args.concept_dir)
+         LOG.info(
+             "Creating concept name mapping and motor_time_to_event_codes from disk at %s",
+             cehrgpt_args.concept_dir,
+         )
+         for row in concept_pd.itertuples():
+             concept_name_mapping[str(getattr(row, "concept_id"))] = getattr(
+                 row, "concept_name"
+             )
+             if (
+                 cehrgpt_args.include_motor_time_to_event
+                 and getattr(row, "domain_id")
+                 in ["Condition", "Procedure", "Drug", "Visit"]
+                 and getattr(row, "standard_concept") == "S"
+             ):
+                 allowed_motor_codes.append(str(getattr(row, "concept_id")))
+         LOG.info(
+             "Adding death codes for MOTOR TTE predictions: %s",
+             [DEATH_TOKEN, death_code],
+         )
+         allowed_motor_codes.extend([DEATH_TOKEN, death_code])
      # Try to load the pretrained tokenizer
      tokenizer_abspath = os.path.expanduser(model_args.tokenizer_name_or_path)
      try:
@@ -85,9 +116,17 @@
          LOG.info("Started training the tokenizer ...")
          tokenizer = CehrGptTokenizer.train_tokenizer(
              dataset,
-             {},
+             concept_name_mapping,
              data_args,
              PretrainedEmbeddings(cehrgpt_args.pretrained_embedding_path),
+             allowed_motor_codes if cehrgpt_args.include_motor_time_to_event else None,
+             (
+                 cehrgpt_args.num_motor_tasks
+                 if cehrgpt_args.include_motor_time_to_event
+                 else None
+             ),
+             apply_entropy_filter=cehrgpt_args.apply_entropy_filter,
+             min_prevalence=cehrgpt_args.min_prevalence,
          )
          LOG.info("Finished training the tokenizer ...")
          tokenizer.save_pretrained(tokenizer_abspath)
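For orientation, the concept data read from cehrgpt_args.concept_dir is expected to look like an OMOP-style concept table with at least the columns accessed above. A tiny illustrative frame (all values made up, not taken from the package) and the same itertuples access pattern:

import pandas as pd

concept_pd = pd.DataFrame(
    {
        "concept_id": [1001, 1002],
        "concept_name": ["Example condition", "Example procedure"],
        "domain_id": ["Condition", "Procedure"],
        "standard_concept": ["S", "S"],
    }
)

concept_name_mapping = {}
for row in concept_pd.itertuples():
    # Same access pattern as load_and_create_tokenizer
    concept_name_mapping[str(getattr(row, "concept_id"))] = getattr(row, "concept_name")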
@@ -99,13 +138,12 @@
  def load_and_create_model(
      model_args: ModelArguments,
      cehrgpt_args: CehrGPTArguments,
-     training_args: TrainingArguments,
      tokenizer: CehrGptTokenizer,
  ) -> CEHRGPT2LMHeadModel:
      attn_implementation = (
          "flash_attention_2" if is_flash_attn_2_available() else "eager"
      )
-     torch_dtype = torch.bfloat16 if training_args.bf16 else torch.float32
+     torch_dtype = get_torch_dtype(model_args.torch_dtype)
      model_abspath = os.path.expanduser(model_args.model_name_or_path)
      if cehrgpt_args.continue_pretrain:
          try:
@@ -147,6 +185,8 @@
      else:
          pretrained_embedding_dim = model_args.hidden_size

+     model_args_cehrgpt = model_args.as_dict()
+     model_args_cehrgpt.pop("attn_implementation")
      model_config = CEHRGPTConfig(
          vocab_size=tokenizer.vocab_size,
          value_vocab_size=tokenizer.value_vocab_size,
@@ -172,7 +212,12 @@
              if cehrgpt_args.sample_packing
              else model_args.max_position_embeddings
          ),
-         **model_args.as_dict(),
+         include_motor_time_to_event=cehrgpt_args.include_motor_time_to_event,
+         motor_tte_vocab_size=tokenizer.motor_tte_vocab_size,
+         motor_time_to_event_weight=cehrgpt_args.motor_time_to_event_weight,
+         motor_num_time_pieces=cehrgpt_args.motor_num_time_pieces,
+         ve_token_id=tokenizer.ve_token_id,
+         **model_args_cehrgpt,
      )

      model = CEHRGPT2LMHeadModel(model_config)
@@ -348,6 +393,8 @@ def main():
              pretrained_concept_embedding_model=PretrainedEmbeddings(
                  cehrgpt_args.pretrained_embedding_path
              ),
+             apply_entropy_filter=cehrgpt_args.apply_entropy_filter,
+             min_prevalence=cehrgpt_args.min_prevalence,
          )
          cehrgpt_tokenizer.save_pretrained(
              os.path.expanduser(training_args.output_dir)
@@ -421,9 +468,11 @@ def main():
      else:
          processed_dataset = processed_dataset.filter(filter_func, **filter_args)

-     model = load_and_create_model(
-         model_args, cehrgpt_args, training_args, cehrgpt_tokenizer
-     )
+     model = load_and_create_model(model_args, cehrgpt_args, cehrgpt_tokenizer)
+
+     # Try to update motor tte vocab size if the new configuration is different from the existing one
+     if cehrgpt_args.include_motor_time_to_event:
+         model.update_motor_tte_vocab_size(cehrgpt_tokenizer.motor_tte_vocab_size)

      # Expand tokenizer to adapt to the new pretraining dataset
      if model.config.vocab_size < cehrgpt_tokenizer.vocab_size:
@@ -500,6 +549,9 @@ def main():
              include_ttv_prediction=model_args.include_ttv_prediction,
              use_sub_time_tokenization=model_args.use_sub_time_tokenization,
              include_values=model_args.include_values,
+             include_motor_time_to_event=cehrgpt_args.include_motor_time_to_event,
+             motor_tte_vocab_size=model.config.motor_tte_vocab_size,
+             motor_num_time_pieces=cehrgpt_args.motor_num_time_pieces,
          ),
          train_dataset=processed_dataset["train"],
          eval_dataset=(
@@ -6,6 +6,12 @@ from typing import List, Optional
  class CehrGPTArguments:
      """Arguments pertaining to what data we are going to input our model for training and eval."""

+     tokenized_full_dataset_path: Optional[str] = dataclasses.field(
+         default=None,
+         metadata={
+             "help": "The path to the tokenized dataset created for the full population"
+         },
+     )
      include_inpatient_hour_token: Optional[bool] = dataclasses.field(
          default=True,
          metadata={"help": "Include inpatient hour token"},
@@ -177,7 +183,49 @@ class CehrGPTArguments:
              "help": "A flag to indicate whether we want to add end token in sample packing"
          },
      )
+     include_motor_time_to_event: Optional[bool] = dataclasses.field(
+         default=False,
+         metadata={
+             "help": "A flag to indicate whether we want to include the MOTOR time-to-event task"
+         },
+     )
+     num_motor_tasks: Optional[int] = dataclasses.field(
+         default=10000,
+         metadata={"help": "The maximum number of MOTOR tasks"},
+     )
+     motor_time_to_event_weight: Optional[float] = dataclasses.field(
+         default=1.0,
+         metadata={"help": "The MOTOR time-to-event loss weight"},
+     )
+     motor_num_time_pieces: Optional[int] = dataclasses.field(
+         default=8,
+         metadata={
+             "help": "The number of time pieces the MOTOR time-to-event horizon is divided into"
+         },
+     )
+     concept_dir: Optional[str] = dataclasses.field(
+         default=None,
+         metadata={"help": "The directory where the concept data is stored."},
+     )
      average_over_sequence: bool = dataclasses.field(
          default=False,
          metadata={"help": "Whether or not to average tokens per sequence"},
      )
+     apply_entropy_filter: Optional[bool] = dataclasses.field(
+         default=False,
+         metadata={"help": "A flag to indicate whether we want to use the entropy filter."},
+     )
+     min_prevalence: Optional[float] = dataclasses.field(
+         default=1 / 1000,
+         metadata={"help": "The minimum prevalence required to keep a concept in the tokenizer"},
+     )
+     class_weights: Optional[List[int]] = dataclasses.field(
+         default=None,
+         metadata={"help": "The class weights for training"},
+     )
+     negative_sampling_probability: Optional[float] = dataclasses.field(
+         default=None,
+         metadata={
+             "help": "The probability that negative samples will be included in the training data"
+         },
+     )
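For reference, a hedged example of setting the new fields above by constructing the dataclass directly (all values and paths are placeholders, and this assumes every other field keeps its default; in the runners these fields are normally populated from the command line via HfArgumentParser):

cehrgpt_args = CehrGPTArguments(
    tokenized_full_dataset_path="/data/tokenized_full_dataset",  # placeholder path
    include_motor_time_to_event=True,
    num_motor_tasks=5000,
    motor_time_to_event_weight=1.0,
    motor_num_time_pieces=8,
    concept_dir="/data/omop/concept",  # placeholder path
    apply_entropy_filter=True,
    min_prevalence=1 / 1000,
    class_weights=[1, 5],
    negative_sampling_probability=0.3,
)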
@@ -4,12 +4,7 @@ from typing import Callable, Tuple
  import optuna
  from cehrbert.runners.hf_runner_argument_dataclass import ModelArguments
  from datasets import Dataset, DatasetDict
- from transformers import (
-     EarlyStoppingCallback,
-     Trainer,
-     TrainerCallback,
-     TrainingArguments,
- )
+ from transformers import EarlyStoppingCallback, TrainerCallback, TrainingArguments
  from transformers.utils import logging

  from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
@@ -85,7 +80,9 @@
              "per_device_train_batch_size", batch_sizes
          ),
          "weight_decay": trial.suggest_float("weight_decay", *weight_decays, log=True),
-         "num_train_epochs": trial.suggest_int("num_train_epochs", *num_train_epochs),
+         "num_train_epochs": trial.suggest_categorical(
+             "num_train_epochs", num_train_epochs
+         ),
      }


@@ -217,6 +214,8 @@ def perform_hyperparameter_search(
          backend="optuna",
          n_trials=cehrgpt_args.n_trials,
          compute_objective=lambda m: m["optuna_best_metric"],
+         # Ensure reproducibility
+         sampler=optuna.samplers.TPESampler(seed=training_args.seed),
      )
      LOG.info("Best hyperparameters: %s", best_trial.hyperparameters)
      # Update training arguments with best hyperparameters and set epochs based on adjusted effective epochs
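The seeded sampler added above is what makes repeated hyperparameter searches comparable. As a standalone illustration, independent of the Trainer integration, seeding Optuna's TPESampler makes the sequence of suggested trials deterministic:

import optuna


def objective(trial):
    x = trial.suggest_float("x", -10, 10)
    return (x - 2) ** 2


study = optuna.create_study(sampler=optuna.samplers.TPESampler(seed=42))
study.optimize(objective, n_trials=10)
print(study.best_params)  # identical across runs with the same seed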
@@ -35,6 +35,13 @@ class SamplePackingTrainer(Trainer):
              self.max_tokens_per_batch,
          )

+         self.negative_sampling_probability = kwargs.pop(
+             "negative_sampling_probability", None
+         )
+         if self.negative_sampling_probability:
+             LOG.info(
+                 "negative_sampling_probability: %s", self.negative_sampling_probability
+             )
          self.train_lengths = kwargs.pop("train_lengths", None)
          self.validation_lengths = kwargs.pop("validation_lengths", None)
          super().__init__(*args, **kwargs)
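The negative_sampling_probability popped above is forwarded to the batch sampler in the following hunks. As a rough, illustrative sketch only, and explicitly not the SamplePackingBatchSampler implementation, probability-based negative sampling typically keeps every positive example and keeps each negative example with the given probability:

import random


def subsample_negative_indices(labels, keep_probability, seed=42):
    # Keep all positives (label == 1); keep each negative with probability keep_probability
    rng = random.Random(seed)
    return [
        idx
        for idx, label in enumerate(labels)
        if label == 1 or rng.random() < keep_probability
    ]


kept = subsample_negative_indices([0, 1, 0, 0, 1, 0], keep_probability=0.5)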
@@ -70,6 +77,14 @@
          data_collator = self._get_collator_with_removed_columns(
              data_collator, description="training"
          )
+
+         labels = None
+         if (
+             self.negative_sampling_probability is not None
+             and "classifier_label" in train_dataset.column_names
+         ):
+             labels = train_dataset["classifier_label"]
+
          # Create our custom batch sampler
          batch_sampler = SamplePackingBatchSampler(
              lengths=lengths,
@@ -77,6 +92,8 @@
              max_position_embeddings=self.max_position_embeddings,
              drop_last=self.args.dataloader_drop_last,
              seed=self.args.seed,
+             negative_sampling_probability=self.negative_sampling_probability,
+             labels=labels,
          )
          dataloader_params = {
              "collate_fn": data_collator,
@@ -0,0 +1,23 @@
+ task_name: "cabg_prediction"
+ outcome_events: [
+   "43528001",
+   "43528003",
+   "43528004",
+   "43528002",
+   "4305852",
+   "4168831",
+   "2107250",
+   "2107216",
+   "2107222",
+   "2107231",
+   "4336464",
+   "4231998",
+   "4284104",
+   "2100873",
+ ]
+ future_visit_start: 0
+ future_visit_end: -1
+ prediction_window_start: 0
+ prediction_window_end: 365
+ max_new_tokens: 1024
+ include_descendants: true
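One plausible reading of the window fields above, the exact semantics of which live in the task runner and are not part of this diff, is that they are day offsets relative to each patient's index date:

from datetime import datetime, timedelta

index_date = datetime(2024, 1, 1)   # example index date
prediction_window_start = 0         # values from the config above
prediction_window_end = 365

window = (
    index_date + timedelta(days=prediction_window_start),
    index_date + timedelta(days=prediction_window_end),
)
# -> (2024-01-01, 2024-12-31): roughly a one-year window after the index date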