cehrgpt 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. cehrgpt/analysis/irregularity.py +36 -0
  2. cehrgpt/data/hf_cehrgpt_dataset.py +1 -0
  3. cehrgpt/data/hf_cehrgpt_dataset_collator.py +454 -68
  4. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +232 -17
  5. cehrgpt/data/sample_packing_sampler.py +36 -6
  6. cehrgpt/generation/cehrgpt_conditional_generation.py +314 -0
  7. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +15 -3
  8. cehrgpt/generation/omop_converter_batch.py +32 -2
  9. cehrgpt/gpt_utils.py +20 -2
  10. cehrgpt/models/config.py +25 -0
  11. cehrgpt/models/hf_cehrgpt.py +244 -39
  12. cehrgpt/models/hf_modeling_outputs.py +1 -0
  13. cehrgpt/models/special_tokens.py +1 -0
  14. cehrgpt/models/tokenization_hf_cehrgpt.py +354 -71
  15. cehrgpt/runners/data_utils.py +131 -5
  16. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +84 -51
  17. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +59 -7
  18. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +60 -0
  19. cehrgpt/runners/hyperparameter_search_util.py +6 -7
  20. cehrgpt/runners/sample_packing_trainer.py +17 -0
  21. cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  22. cehrgpt/time_to_event/time_to_event_model.py +2 -13
  23. cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  24. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +80 -62
  25. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/METADATA +102 -7
  26. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/RECORD +29 -26
  27. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/WHEEL +1 -1
  28. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/licenses/LICENSE +0 -0
  29. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/top_level.txt +0 -0
cehrgpt/generation/cehrgpt_conditional_generation.py ADDED
@@ -0,0 +1,314 @@
+ import datetime
+ import os
+ import random
+ import shutil
+ from pathlib import Path
+ from typing import Any, Dict
+
+ import numpy as np
+ import polars as pl
+ import torch
+ import torch.distributed as dist
+ from cehrbert.runners.runner_util import generate_prepared_ds_path
+ from datasets import load_from_disk
+ from meds import held_out_split, train_split, tuning_split
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+ from transformers.trainer_utils import is_main_process
+ from transformers.utils import is_flash_attn_2_available, logging
+
+ from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_finetuning_dataset
+ from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
+ from cehrgpt.generation.generate_batch_hf_gpt_sequence import (
+     generate_single_batch,
+     normalize_value,
+ )
+ from cehrgpt.gpt_utils import (
+     extract_time_interval_in_days,
+     extract_time_interval_in_hours,
+     is_att_token,
+     is_inpatient_hour_token,
+     is_visit_end,
+     is_visit_start,
+ )
+ from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
+ from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+ from cehrgpt.runners.data_utils import (
+     extract_cohort_sequences,
+     prepare_finetune_dataset,
+ )
+ from cehrgpt.runners.gpt_runner_util import parse_runner_args
+ from cehrgpt.runners.hf_cehrgpt_pretrain_runner import tokenizer_exists
+
+ LOG = logging.get_logger("transformers")
+
+
+ def map_data_split_name(split: str) -> str:
+     if split == "train":
+         return train_split
+     elif split == "validation":
+         return tuning_split
+     elif split == "test":
+         return held_out_split
+     raise ValueError(f"Unknown split: {split}")
+
+
+ def seed_all(seed: int = 42):
+     """Set seed for Python, NumPy, and PyTorch (CPU & CUDA)."""
+     random.seed(seed)  # Python random
+     np.random.seed(seed)  # NumPy
+     torch.manual_seed(seed)  # PyTorch CPU
+     torch.cuda.manual_seed(seed)  # Current GPU
+     torch.cuda.manual_seed_all(seed)  # All GPUs
+
+     # For reproducibility in dataloader workers
+     os.environ["PYTHONHASHSEED"] = str(seed)
+
+
+ def generate_trajectories_per_batch(
+     batch: Dict[str, Any],
+     cehrgpt_tokenizer: CehrGptTokenizer,
+     cehrgpt_model: CEHRGPT2LMHeadModel,
+     device,
+     data_output_path: Path,
+     max_length: int,
+ ):
+     subject_ids = batch["person_id"].squeeze().detach().cpu().tolist()
+     prediction_times = batch["index_date"].squeeze().detach().cpu().tolist()
+     batched_epoch_times = batch["epoch_times"].detach().cpu().tolist()
+     batched_input_ids = batch["input_ids"]
+     batched_value_indicators = batch["value_indicators"]
+     batched_values = batch["values"]
+     # Make sure the batch does not exceed batch_size
+     batch_sequences = generate_single_batch(
+         cehrgpt_model,
+         cehrgpt_tokenizer,
+         batched_input_ids,
+         values=batched_values,
+         value_indicators=batched_value_indicators,
+         max_length=max_length,
+         top_p=1.0,
+         top_k=cehrgpt_tokenizer.vocab_size,
+         device=device,
+     )
+     # Clear the cache
+     torch.cuda.empty_cache()
+
+     trajectories = []
+     for sample_i, (concept_ids, value_indicators, values) in enumerate(
+         zip(
+             batch_sequences["sequences"],
+             batch_sequences["value_indicators"],
+             batch_sequences["values"],
+         )
+     ):
+         (
+             concept_ids,
+             is_numeric_types,
+             number_as_values,
+             concept_as_values,
+             units,
+         ) = normalize_value(concept_ids, values, cehrgpt_tokenizer)
+
+         epoch_times = batched_epoch_times[sample_i]
+         input_length = len(epoch_times)
+         # Getting the last observed event time from the token before the prediction time
+         window_last_observed = epoch_times[input_length - 1]
+         current_cursor = epoch_times[-1]
+         generated_epoch_times = []
+         valid_indices = []
+
+         for i in range(input_length, len(concept_ids)):
+             concept_id = concept_ids[i]
+             # We use the left padding strategy in the data collator
+             if concept_id in [cehrgpt_tokenizer.pad_token, cehrgpt_tokenizer.end_token]:
+                 continue
+             # We need to construct the time stamp
+             if is_att_token(concept_id):
+                 current_cursor += extract_time_interval_in_days(concept_id) * 24 * 3600
+             elif is_inpatient_hour_token(concept_id):
+                 current_cursor += extract_time_interval_in_hours(concept_id) * 3600
+             elif is_visit_start(concept_id) or is_visit_end(concept_id):
+                 continue
+             else:
+                 valid_indices.append(i)
+                 generated_epoch_times.append(
+                     datetime.datetime.utcfromtimestamp(current_cursor).replace(
+                         tzinfo=None
+                     )
+                 )
+
+         trajectories.append(
+             {
+                 "subject_id": subject_ids[sample_i],
+                 "prediction_time": datetime.datetime.utcfromtimestamp(
+                     prediction_times[sample_i]
+                 ).replace(tzinfo=None),
+                 "window_last_observed_time": datetime.datetime.utcfromtimestamp(
+                     window_last_observed
+                 ).replace(tzinfo=None),
+                 "times": generated_epoch_times,
+                 "concept_ids": np.asarray(concept_ids)[valid_indices].tolist(),
+                 "numeric_values": np.asarray(number_as_values)[valid_indices].tolist(),
+                 "text_value": np.asarray(concept_as_values)[valid_indices].tolist(),
+                 "units": np.asarray(units)[valid_indices].tolist(),
+             }
+         )
+
+     trajectories = (
+         pl.DataFrame(trajectories)
+         .explode(["times", "concept_ids", "numeric_values", "text_value", "units"])
+         .rename(
+             {
+                 "times": "time",
+                 "concept_ids": "code",
+                 "numeric_values": "numeric_value",
+                 "units": "unit",
+             }
+         )
+         .select(
+             "subject_id",
+             "prediction_time",
+             "window_last_observed_time",
+             "time",
+             "code",
+             "numeric_value",
+             "text_value",
+             "unit",
+         )
+     )
+     trajectories.write_parquet(data_output_path)
+
+
+ def main():
+     cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
+     if torch.cuda.is_available():
+         device = torch.device("cuda")
+     else:
+         device = torch.device("cpu")
+     cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
+         model_args.tokenizer_name_or_path
+     )
+     cehrgpt_model = (
+         CEHRGPT2LMHeadModel.from_pretrained(
+             model_args.model_name_or_path,
+             attn_implementation=(
+                 "flash_attention_2" if is_flash_attn_2_available() else "eager"
+             ),
+         )
+         .eval()
+         .to(device)
+     )
+     cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
+     cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
+     cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id
+
+     if not os.path.exists(training_args.output_dir):
+         os.makedirs(training_args.output_dir)
+
+     prepared_ds_path = generate_prepared_ds_path(
+         data_args, model_args, data_folder=data_args.cohort_folder
+     )
+
+     processed_dataset = None
+     if any(prepared_ds_path.glob("*")):
+         LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
+         processed_dataset = load_from_disk(str(prepared_ds_path))
+         LOG.info("Prepared dataset loaded from disk...")
+         if cehrgpt_args.expand_tokenizer:
+             if tokenizer_exists(training_args.output_dir):
+                 cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
+                     training_args.output_dir
+                 )
+             else:
+                 LOG.warning(
+                     f"CehrGptTokenizer must exist in {training_args.output_dir} "
+                     f"when the dataset has been processed and expand_tokenizer is set to True. "
+                     f"Please delete the processed dataset at {prepared_ds_path}."
+                 )
+                 processed_dataset = None
+                 shutil.rmtree(prepared_ds_path)
+
+     if processed_dataset is None and is_main_process(training_args.local_rank):
+         # If the full dataset has been tokenized, we don't want to tokenize the cohort containing
+         # the subset of the data. We should slice out the portion of the tokenized sequences for each sample
+         if cehrgpt_args.tokenized_full_dataset_path is not None:
+             processed_dataset = extract_cohort_sequences(data_args, cehrgpt_args)
+         else:
+             # Organize them into a single DatasetDict
+             final_splits = prepare_finetune_dataset(
+                 data_args, training_args, cehrgpt_args
+             )
+             # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
+             if not data_args.streaming:
+                 all_columns = final_splits["train"].column_names
+                 if "visit_concept_ids" in all_columns:
+                     final_splits = final_splits.remove_columns(["visit_concept_ids"])
+
+             processed_dataset = create_cehrgpt_finetuning_dataset(
+                 dataset=final_splits,
+                 cehrgpt_tokenizer=cehrgpt_tokenizer,
+                 data_args=data_args,
+             )
+             if not data_args.streaming:
+                 processed_dataset.save_to_disk(prepared_ds_path)
+                 processed_dataset.cleanup_cache_files()
+
+     # After main-process-only operations, synchronize all processes to ensure consistency
+     if dist.is_available() and dist.is_initialized():
+         dist.barrier()
+
+     # We suppress the additional learning objectives in fine-tuning
+     data_collator = CehrGptDataCollator(
+         tokenizer=cehrgpt_tokenizer,
+         max_length=cehrgpt_args.generation_input_length,
+         include_values=cehrgpt_model.config.include_values,
+         pretraining=False,
+         include_ttv_prediction=False,
+         use_sub_time_tokenization=False,
+         include_demographics=False,
+         add_linear_prob_token=False,
+     )
+
+     LOG.info(
+         "Generating %s trajectories per sample",
+         cehrgpt_args.num_of_trajectories_per_sample,
+     )
+     for sample_i in range(cehrgpt_args.num_of_trajectories_per_sample):
+         for split, dataset in processed_dataset.items():
+             meds_split = map_data_split_name(split)
+             dataloader = DataLoader(
+                 dataset=dataset,
+                 batch_size=training_args.per_device_eval_batch_size,
+                 num_workers=training_args.dataloader_num_workers,
+                 collate_fn=data_collator,
+                 pin_memory=training_args.dataloader_pin_memory,
+             )
+             sample_output_dir = (
+                 Path(training_args.output_dir) / meds_split / f"{sample_i}"
+             )
+             sample_output_dir.mkdir(exist_ok=True, parents=True)
+             for batch_i, batch in tqdm(
+                 enumerate(dataloader),
+                 desc=f"Generating Trajectories for split {meds_split} with trajectory {sample_i + 1}",
+             ):
+                 output_parquet_file = sample_output_dir / f"{batch_i}.parquet"
+                 if output_parquet_file.exists():
+                     LOG.info("%s already exists, skip...", output_parquet_file)
+                     continue
+
+                 generate_trajectories_per_batch(
+                     batch,
+                     cehrgpt_tokenizer,
+                     cehrgpt_model,
+                     device,
+                     sample_output_dir / f"{batch_i}.parquet",
+                     cehrgpt_args.generation_max_new_tokens
+                     + cehrgpt_args.generation_input_length,
+                 )
+
+
+ if __name__ == "__main__":
+     # ✅ Call first thing inside main()
+     seed_all(42)
+     main()
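
For orientation, a minimal sketch (not part of the diff) of reading back one generated shard with polars; the path layout {output_dir}/{split}/{sample_i}/{batch_i}.parquet and the column names come from generate_trajectories_per_batch above, while the concrete path is hypothetical:

import polars as pl

# Hypothetical shard: split "train", trajectory 0, batch 0.
df = pl.read_parquet("output/train/0/0.parquet")
# One row per generated event after the explode(); columns written above:
# subject_id, prediction_time, window_last_observed_time, time, code,
# numeric_value, text_value, unit.
print(df.select("subject_id", "time", "code", "numeric_value", "unit").head())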
cehrgpt/generation/generate_batch_hf_gpt_sequence.py CHANGED
@@ -74,7 +74,10 @@ def generate_single_batch(
      model: CEHRGPT2LMHeadModel,
      tokenizer: CehrGptTokenizer,
      prompts: List[List[int]],
-     max_new_tokens=512,
+     max_length: int,
+     values: Optional[torch.Tensor] = None,
+     value_indicators: Optional[torch.Tensor] = None,
+     max_new_tokens: Optional[int] = None,
      mini_num_of_concepts=1,
      top_p=0.95,
      top_k=50,
@@ -88,7 +91,8 @@ def generate_single_batch(
      with torch.no_grad():
          generation_config = GenerationConfig(
              repetition_penalty=repetition_penalty,
-             max_length=max_new_tokens,
+             max_new_tokens=max_new_tokens,
+             max_length=max_length,
              min_length=mini_num_of_concepts,
              temperature=temperature,
              top_p=top_p,
@@ -107,9 +111,17 @@ def generate_single_batch(
              num_beam_groups=num_beam_groups,
              epsilon_cutoff=epsilon_cutoff,
          )
+
          batched_prompts = torch.tensor(prompts).to(device)
+         if values is not None:
+             values = values.to(device)
+         if value_indicators is not None:
+             value_indicators = value_indicators.to(device)
+
          results = model.generate(
              inputs=batched_prompts,
+             values=values,
+             value_indicators=value_indicators,
              generation_config=generation_config,
              lab_token_ids=tokenizer.lab_token_ids,
          )
@@ -226,7 +238,7 @@ def main(args):
          cehrgpt_model,
          cehrgpt_tokenizer,
          random_prompts[: args.batch_size],
-         max_new_tokens=args.context_window,
+         max_length=args.context_window,
          mini_num_of_concepts=args.min_num_of_concepts,
          top_p=args.top_p,
          top_k=args.top_k,
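
The hunks above replace the old max_new_tokens=512 default with an explicit max_length and forward both knobs to GenerationConfig. As a reminder of the Hugging Face semantics the new code relies on (a sketch, not part of the diff): max_length caps prompt plus continuation, while max_new_tokens caps only the continuation and takes precedence when both are set.

from transformers import GenerationConfig

# max_length bounds prompt + generated tokens; max_new_tokens bounds only the
# newly generated tokens and wins when both are given to generate().
config = GenerationConfig(max_length=1024, max_new_tokens=None)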
cehrgpt/generation/omop_converter_batch.py CHANGED
@@ -60,6 +60,24 @@ OOV_CONCEPT_MAP = {
  }


+ def extract_gender_concept_id(gender_token: str) -> int:
+     if gender_token.startswith("Gender/"):
+         return int(gender_token[len("Gender/") :])
+     elif gender_token.isnumeric():
+         return int(gender_token)
+     else:
+         return 0
+
+
+ def extract_race_concept_id(race_token: str) -> int:
+     if race_token.startswith("Race/"):
+         return int(race_token[len("Race/") :])
+     elif race_token.isnumeric():
+         return int(race_token)
+     else:
+         return 0
+
+
  def create_folder_if_not_exists(output_folder, table_name):
      if not os.path.isdir(Path(output_folder) / table_name):
          os.mkdir(Path(output_folder) / table_name)
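
The two helpers above let the converter accept both MEDS-style prefixed tokens and bare numeric OMOP concept ids, falling back to 0 for anything unparseable. Illustrative inputs (8507, 8532, and 8527 are standard OMOP concept ids, used here only as examples):

assert extract_gender_concept_id("Gender/8507") == 8507  # prefixed token
assert extract_gender_concept_id("8532") == 8532         # bare numeric token
assert extract_gender_concept_id("UNKNOWN") == 0         # fallback
assert extract_race_concept_id("Race/8527") == 8527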
@@ -288,7 +306,13 @@ def gpt_to_omop_converter_batch(
          if int(birth_year) < 1900 or int(birth_year) > datetime.date.today().year:
              continue

-         p = Person(person_id, start_gender, birth_year, start_race)
+         p = Person(
+             person_id=person_id,
+             gender_concept_id=extract_gender_concept_id(start_gender),
+             year_of_birth=birth_year,
+             race_concept_id=extract_race_concept_id(start_race),
+         )
+
          append_to_dict(omop_export_dict, p, person_id)
          id_mappings_dict["person"][person_id] = person_id
          pt_seq_dict[person_id] = " ".join(concept_ids)
@@ -316,7 +340,12 @@ def gpt_to_omop_converter_batch(
                  id_mappings_dict["death"][person_id] = person_id
              else:
                  try:
-                     visit_concept_id = int(clinical_events[event_idx + 1])
+                     if clinical_events[event_idx + 1].startswith("Visit/"):
+                         visit_concept_id = int(
+                             clinical_events[event_idx + 1][len("Visit/") :]
+                         )
+                     else:
+                         visit_concept_id = int(clinical_events[event_idx + 1])
                      inpatient_visit_indicator = visit_concept_id in [
                          9201,
                          262,
@@ -349,6 +378,7 @@ def gpt_to_omop_converter_batch(
                          visit_occurrence_id
                      ] = person_id
                      visit_occurrence_id += 1
+
                  elif event in ATT_TIME_TOKENS:
                      if event[0] == "D":
                          att_date_delta = int(event[1:])
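
The Visit/ branch above applies the same convention to visit tokens; a condensed, hypothetical restatement of the rule it implements (parse_visit_concept_id is illustrative, not part of the module):

def parse_visit_concept_id(token: str) -> int:
    # Strip an optional MEDS-style "Visit/" prefix before parsing the id.
    if token.startswith("Visit/"):
        return int(token[len("Visit/") :])
    return int(token)

assert parse_visit_concept_id("Visit/9201") == parse_visit_concept_id("9201") == 9201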
cehrgpt/gpt_utils.py CHANGED
@@ -11,6 +11,7 @@ from cehrgpt.models.special_tokens import (
  )

  # Regular expression pattern to match inpatient attendance tokens
+ MEDS_CODE_PATTERN = re.compile(r".*/.*")
  INPATIENT_ATT_PATTERN = re.compile(r"(?:VS-|i-)D(\d+)(?:-VE)?")
  DEMOGRAPHIC_PROMPT_SIZE = 4

@@ -194,8 +195,12 @@ def get_cehrgpt_output_folder(args, cehrgpt_tokenizer) -> str:
      return folder_name


- def is_clinical_event(token: str) -> bool:
-     return token.isnumeric()
+ def is_clinical_event(token: str, meds: bool = False) -> bool:
+     if token.isnumeric():
+         return True
+     if meds:
+         return bool(MEDS_CODE_PATTERN.match(token))
+     return False


  def is_visit_start(token: str):
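
With the new meds flag, is_clinical_event accepts slash-delimited MEDS codes in addition to bare numeric OMOP ids. Behavior implied by the body above (the LOINC code is only an illustrative input):

assert is_clinical_event("320128")                    # numeric OMOP concept id
assert is_clinical_event("LOINC/4548-4", meds=True)   # matches MEDS_CODE_PATTERN
assert not is_clinical_event("LOINC/4548-4")          # meds defaults to False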
@@ -212,6 +217,18 @@ def is_visit_end(token: str) -> bool:
      return token in ["VE", "[VE]"]


+ def is_inpatient_hour_token(token: str) -> bool:
+     return token.startswith("i-H")
+
+
+ def extract_time_interval_in_hours(token: str) -> int:
+     try:
+         hour = int(token[3:])
+         return hour
+     except ValueError:
+         return 0
+
+
  def is_att_token(token: str):
      """
      Check if the token is an attention token.
@@ -251,6 +268,7 @@ def is_artificial_token(token: str) -> bool:
          return True
      if token == END_TOKEN:
          return True
+
      return False

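The new inpatient-hour helpers parse tokens of the form i-H<n>: token[3:] skips the "i-H" prefix, and a non-numeric remainder falls back to 0. Behavior implied by the bodies above:

assert is_inpatient_hour_token("i-H12")
assert extract_time_interval_in_hours("i-H12") == 12  # int("12")
assert extract_time_interval_in_hours("i-Hxx") == 0   # ValueError -> 0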
 
cehrgpt/models/config.py CHANGED
@@ -121,6 +121,7 @@ class CEHRGPTConfig(PretrainedConfig):
          bos_token_id=50256,
          eos_token_id=50256,
          lab_token_ids=None,
+         ve_token_id=None,
          scale_attn_by_inverse_layer_idx=False,
          reorder_and_upcast_attn=False,
          exclude_position_ids=False,
@@ -128,6 +129,10 @@ class CEHRGPTConfig(PretrainedConfig):
          value_vocab_size=None,
          include_ttv_prediction=False,
          use_sub_time_tokenization=True,
+         include_motor_time_to_event=True,
+         motor_tte_vocab_size=None,
+         motor_time_to_event_weight=1.0,
+         motor_num_time_pieces=16,
          token_to_time_token_mapping: Dict[int, List] = None,
          use_pretrained_embeddings=False,
          n_pretrained_embeddings_layers=2,
@@ -144,6 +149,7 @@ class CEHRGPTConfig(PretrainedConfig):
          entropy_penalty=False,
          entropy_penalty_alpha=0.01,
          sample_packing_max_positions=None,
+         class_weights=None,
          **kwargs,
      ):
          if token_to_time_token_mapping is None:
@@ -192,6 +198,22 @@ class CEHRGPTConfig(PretrainedConfig):
          self._token_to_time_token_mapping = token_to_time_token_mapping
          self.time_token_loss_weight = time_token_loss_weight
          self.time_to_visit_loss_weight = time_to_visit_loss_weight
+
+         # MOTOR TTE configuration
+         self.motor_tte_vocab_size = motor_tte_vocab_size
+         self.include_motor_time_to_event = (
+             include_motor_time_to_event
+             and self.motor_tte_vocab_size
+             and self.motor_tte_vocab_size > 0
+         )
+         if self.include_motor_time_to_event and not ve_token_id:
+             raise RuntimeError(
+                 f"ve_token_id must be provided when include_motor_time_to_event is True"
+             )
+         self.ve_token_id = ve_token_id
+         self.motor_time_to_event_weight = motor_time_to_event_weight
+         self.motor_num_time_pieces = motor_num_time_pieces
+
          self.causal_sfm = causal_sfm
          self.demographics_size = demographics_size
          self.use_pretrained_embeddings = use_pretrained_embeddings
@@ -206,6 +228,9 @@ class CEHRGPTConfig(PretrainedConfig):
          self.entropy_penalty_alpha = entropy_penalty_alpha
          self.value_prediction_loss_weight = value_prediction_loss_weight

+         # Class weights for fine-tuning
+         self.class_weights = class_weights
+
          kwargs["tie_word_embeddings"] = not use_pretrained_embeddings

          super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
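
A minimal construction sketch for the new MOTOR options (all values hypothetical): per the __init__ logic above, the time-to-event head is enabled only when motor_tte_vocab_size is positive, and a truthy ve_token_id is then required, otherwise __init__ raises RuntimeError.

from cehrgpt.models.config import CEHRGPTConfig

config = CEHRGPTConfig(
    include_motor_time_to_event=True,
    motor_tte_vocab_size=512,  # hypothetical vocabulary size for the TTE head
    motor_num_time_pieces=16,  # number of time pieces (the new default)
    ve_token_id=100,           # hypothetical id of the visit-end token
    class_weights=None,        # optional per-class weights for fine-tuning
)
assert config.include_motor_time_to_event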