cehrgpt 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. cehrgpt/analysis/htn_treatment_pathway.py +546 -0
  2. cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
  3. cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
  4. cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
  5. cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
  6. cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
  7. cehrgpt/data/cehrgpt_data_processor.py +549 -0
  8. cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
  9. cehrgpt/data/hf_cehrgpt_dataset_collator.py +285 -652
  10. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +38 -5
  11. cehrgpt/generation/cehrgpt_conditional_generation.py +2 -0
  12. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +20 -12
  13. cehrgpt/generation/omop_converter_batch.py +11 -4
  14. cehrgpt/gpt_utils.py +73 -3
  15. cehrgpt/models/activations.py +27 -0
  16. cehrgpt/models/config.py +6 -2
  17. cehrgpt/models/gpt2.py +560 -0
  18. cehrgpt/models/hf_cehrgpt.py +183 -460
  19. cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
  20. cehrgpt/omop/ontology.py +154 -0
  21. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -78
  22. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
  23. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -34
  24. cehrgpt/runners/hyperparameter_search_util.py +180 -69
  25. cehrgpt/runners/sample_packing_trainer.py +11 -2
  26. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +8 -2
  27. cehrgpt-0.1.3.dist-info/METADATA +238 -0
  28. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +32 -22
  29. cehrgpt-0.1.2.dist-info/METADATA +0 -209
  30. /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
  31. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
  32. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
  33. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
@@ -28,6 +28,12 @@ from datasets.formatting.formatting import LazyBatch
28
28
  from dateutil.relativedelta import relativedelta
29
29
  from pandas import Series
30
30
 
31
+ from cehrgpt.gpt_utils import (
32
+ construct_age_sequence,
33
+ construct_time_sequence,
34
+ encode_demographics,
35
+ multiple_of_10,
36
+ )
31
37
  from cehrgpt.models.tokenization_hf_cehrgpt import (
32
38
  NONE_BIN,
33
39
  UNKNOWN_BIN,
@@ -43,6 +49,7 @@ CEHRGPT_COLUMNS = [
43
49
  "concept_values",
44
50
  "units",
45
51
  "epoch_times",
52
+ "ages",
46
53
  ]
47
54
 
48
55
 
@@ -121,6 +128,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
121
128
  cehrgpt_record: Dict[str, Any],
122
129
  code: str,
123
130
  time: datetime.datetime,
131
+ age: int,
124
132
  concept_value_mask: int = 0,
125
133
  number_as_value: float = 0.0,
126
134
  concept_as_value: str = "0",
@@ -128,6 +136,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
128
136
  unit: str = NA,
129
137
  ) -> None:
130
138
  cehrgpt_record["concept_ids"].append(replace_escape_chars(code))
139
+ cehrgpt_record["ages"].append(age)
131
140
  cehrgpt_record["concept_value_masks"].append(concept_value_mask)
132
141
  cehrgpt_record["number_as_values"].append(number_as_value)
133
142
  cehrgpt_record["concept_as_values"].append(concept_as_value)
@@ -141,6 +150,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
141
150
  cehrgpt_record = {
142
151
  "person_id": record["patient_id"],
143
152
  "concept_ids": [],
153
+ "ages": [],
144
154
  "concept_value_masks": [],
145
155
  "number_as_values": [],
146
156
  "concept_as_values": [],
@@ -168,14 +178,21 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
168
178
  first_visit_start_datetime: datetime.datetime = get_value(
169
179
  first_visit, "visit_start_datetime"
170
180
  )
181
+ starting_age = relativedelta(first_visit_start_datetime, birth_datetime).years
171
182
  year_str = f"year:{str(first_visit_start_datetime.year)}"
172
- age_str = f"age:{str(relativedelta(first_visit_start_datetime, birth_datetime).years)}"
183
+ age_str = f"age:{starting_age}"
184
+ self._update_cehrgpt_record(
185
+ cehrgpt_record, year_str, first_visit_start_datetime, starting_age
186
+ )
187
+ self._update_cehrgpt_record(
188
+ cehrgpt_record, age_str, first_visit_start_datetime, starting_age
189
+ )
173
190
  self._update_cehrgpt_record(
174
- cehrgpt_record, year_str, first_visit_start_datetime
191
+ cehrgpt_record, gender, first_visit_start_datetime, starting_age
192
+ )
193
+ self._update_cehrgpt_record(
194
+ cehrgpt_record, race, first_visit_start_datetime, starting_age
175
195
  )
176
- self._update_cehrgpt_record(cehrgpt_record, age_str, first_visit_start_datetime)
177
- self._update_cehrgpt_record(cehrgpt_record, gender, first_visit_start_datetime)
178
- self._update_cehrgpt_record(cehrgpt_record, race, first_visit_start_datetime)
179
196
 
180
197
  # Use a data cursor to keep track of time
181
198
  datetime_cursor: Optional[datetime.datetime] = None
@@ -211,6 +228,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
211
228
  cehrgpt_record,
212
229
  code=self._time_token_function(time_delta),
213
230
  time=visit_start_datetime,
231
+ age=relativedelta(datetime_cursor, birth_datetime).years,
214
232
  )
215
233
 
216
234
  datetime_cursor = visit_start_datetime
@@ -219,12 +237,14 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
219
237
  cehrgpt_record,
220
238
  code="[VS]",
221
239
  time=datetime_cursor,
240
+ age=relativedelta(datetime_cursor, birth_datetime).years,
222
241
  )
223
242
  # Add a visit type token
224
243
  self._update_cehrgpt_record(
225
244
  cehrgpt_record,
226
245
  code=visit_type,
227
246
  time=datetime_cursor,
247
+ age=relativedelta(datetime_cursor, birth_datetime).years,
228
248
  )
229
249
  # We need to insert an inpatient hour token right after the visit type, we calculate the hour interval
230
250
  # with respect to the midnight of the day
@@ -235,6 +255,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
235
255
  cehrgpt_record,
236
256
  code=f"i-H{datetime_cursor.hour}",
237
257
  time=datetime_cursor,
258
+ age=relativedelta(datetime_cursor, birth_datetime).years,
238
259
  )
239
260
 
240
261
  # Keep track of the existing outpatient events, we don't want to add them again
@@ -281,6 +302,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
281
302
  cehrgpt_record,
282
303
  code=f"i-{self._inpatient_time_token_function(time_diff_days)}",
283
304
  time=event_time,
305
+ age=relativedelta(event_time, birth_datetime).years,
284
306
  )
285
307
 
286
308
  if self._include_inpatient_hour_token:
@@ -300,6 +322,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
300
322
  cehrgpt_record,
301
323
  code=f"i-H{time_diff_hours}",
302
324
  time=event_time,
325
+ age=relativedelta(event_time, birth_datetime).years,
303
326
  )
304
327
 
305
328
  if event_identity in existing_duplicate_events:
@@ -309,6 +332,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
309
332
  cehrgpt_record,
310
333
  code=code,
311
334
  time=event_time,
335
+ age=relativedelta(event_time, birth_datetime).years,
312
336
  concept_value_mask=concept_value_mask,
313
337
  unit=unit,
314
338
  number_as_value=numeric_value if numeric_value else 0.0,
@@ -348,6 +372,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
348
372
  cehrgpt_record,
349
373
  code=discharge_facility,
350
374
  time=datetime_cursor,
375
+ age=relativedelta(datetime_cursor, birth_datetime).years,
351
376
  )
352
377
 
353
378
  # Reuse the age and date calculated for the last event in the patient timeline
@@ -355,6 +380,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
355
380
  cehrgpt_record,
356
381
  code="[VE]",
357
382
  time=datetime_cursor,
383
+ age=relativedelta(datetime_cursor, birth_datetime).years,
358
384
  )
359
385
 
360
386
  # Generate the orders of the concepts that the cehrbert dataset mapping function expects
@@ -428,6 +454,13 @@ class HFCehrGptTokenizationMapping(DatasetMappingDecorator):
428
454
  return record
429
455
 
430
456
  def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
457
+ # Reconstruct the ages input before the filter is applied
458
+ record["ages"] = construct_age_sequence(
459
+ record["concept_ids"], record.get("ages", None)
460
+ )
461
+ record["epoch_times"] = construct_time_sequence(
462
+ record["concept_ids"], record.get("epoch_times", None)
463
+ )
431
464
  # Remove the tokens from patient sequences that do not exist in the tokenizer
432
465
  record = self.filter_out_invalid_tokens(record)
433
466
  # If any concept has a value associated with it, we normalize the value
@@ -77,6 +77,7 @@ def generate_trajectories_per_batch(
77
77
  prediction_times = batch["index_date"].squeeze().detach().cpu().tolist()
78
78
  batched_epoch_times = batch["epoch_times"].detach().cpu().tolist()
79
79
  batched_input_ids = batch["input_ids"]
80
+ batched_ages = batch["ages"]
80
81
  batched_value_indicators = batch["value_indicators"]
81
82
  batched_values = batch["values"]
82
83
  # Make sure the batch does not exceed batch_size
@@ -84,6 +85,7 @@ def generate_trajectories_per_batch(
84
85
  cehrgpt_model,
85
86
  cehrgpt_tokenizer,
86
87
  batched_input_ids,
88
+ ages=batched_ages,
87
89
  values=batched_values,
88
90
  value_indicators=batched_value_indicators,
89
91
  max_length=max_length,
@@ -2,7 +2,7 @@ import datetime
2
2
  import os
3
3
  import random
4
4
  import uuid
5
- from typing import Any, Dict, List, Optional, Sequence, Tuple
5
+ from typing import Any, Dict, Optional, Sequence, Tuple
6
6
 
7
7
  import numpy as np
8
8
  import pandas as pd
@@ -13,7 +13,7 @@ from transformers.utils import is_flash_attn_2_available, logging
13
13
 
14
14
  from cehrgpt.cehrgpt_args import create_inference_base_arg_parser
15
15
  from cehrgpt.generation.omop_converter_batch import START_TOKEN_SIZE
16
- from cehrgpt.gpt_utils import get_cehrgpt_output_folder
16
+ from cehrgpt.gpt_utils import construct_age_sequence, get_cehrgpt_output_folder
17
17
  from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
18
18
  from cehrgpt.models.special_tokens import END_TOKEN
19
19
  from cehrgpt.models.tokenization_hf_cehrgpt import (
@@ -72,9 +72,10 @@ def normalize_value(
72
72
 
73
73
  def generate_single_batch(
74
74
  model: CEHRGPT2LMHeadModel,
75
- tokenizer: CehrGptTokenizer,
76
- prompts: List[List[int]],
75
+ cehrgpt_tokenizer: CehrGptTokenizer,
76
+ prompts: torch.Tensor,
77
77
  max_length: int,
78
+ ages: Optional[torch.Tensor] = None,
78
79
  values: Optional[torch.Tensor] = None,
79
80
  value_indicators: Optional[torch.Tensor] = None,
80
81
  max_new_tokens: Optional[int] = None,
@@ -112,7 +113,9 @@ def generate_single_batch(
112
113
  epsilon_cutoff=epsilon_cutoff,
113
114
  )
114
115
 
115
- batched_prompts = torch.tensor(prompts).to(device)
116
+ batched_prompts = prompts.to(device)
117
+ if ages is not None:
118
+ ages = ages.to(device)
116
119
  if values is not None:
117
120
  values = values.to(device)
118
121
  if value_indicators is not None:
@@ -120,19 +123,22 @@ def generate_single_batch(
120
123
 
121
124
  results = model.generate(
122
125
  inputs=batched_prompts,
126
+ ages=ages,
123
127
  values=values,
124
128
  value_indicators=value_indicators,
125
129
  generation_config=generation_config,
126
- lab_token_ids=tokenizer.lab_token_ids,
130
+ cehrgpt_tokenizer=cehrgpt_tokenizer,
127
131
  )
128
132
 
129
133
  sequences = [
130
- tokenizer.decode(seq.cpu().numpy(), skip_special_tokens=False)
134
+ cehrgpt_tokenizer.decode(seq.cpu().numpy(), skip_special_tokens=False)
131
135
  for seq in results.sequences
132
136
  ]
133
137
  if results.sequence_vals is not None:
134
138
  values = [
135
- tokenizer.decode_value(values.cpu().numpy(), skip_special_tokens=False)
139
+ cehrgpt_tokenizer.decode_value(
140
+ values.cpu().numpy(), skip_special_tokens=False
141
+ )
136
142
  for values in results.sequence_vals
137
143
  ]
138
144
  else:
@@ -214,6 +220,7 @@ def main(args):
214
220
 
215
221
  # Randomly pick demographics from the existing population
216
222
  random_prompts = []
223
+ random_prompt_ages = []
217
224
  iter = 0
218
225
  while len(random_prompts) < args.batch_size:
219
226
  for row in dataset.select(
@@ -224,9 +231,9 @@ def main(args):
224
231
  <= len(row["concept_ids"])
225
232
  <= max_seq_allowed
226
233
  ):
227
- random_prompts.append(
228
- cehrgpt_tokenizer.encode(row["concept_ids"][:prompt_size])
229
- )
234
+ prompt = row["concept_ids"][:prompt_size]
235
+ random_prompts.append(cehrgpt_tokenizer.encode(prompt))
236
+ random_prompt_ages.append(construct_age_sequence(prompt))
230
237
  iter += 1
231
238
  if not random_prompts and iter > 10:
232
239
  raise RuntimeError(
@@ -237,7 +244,8 @@ def main(args):
237
244
  batch_sequences = generate_single_batch(
238
245
  cehrgpt_model,
239
246
  cehrgpt_tokenizer,
240
- random_prompts[: args.batch_size],
247
+ torch.tensor(random_prompts[: args.batch_size]),
248
+ ages=torch.tensor(random_prompt_ages[: args.batch_size]),
241
249
  max_length=args.context_window,
242
250
  mini_num_of_concepts=args.min_num_of_concepts,
243
251
  top_p=args.top_p,
@@ -270,20 +270,24 @@ def gpt_to_omop_converter_batch(
270
270
 
271
271
  is_numeric_types = (
272
272
  is_numeric_types[START_TOKEN_SIZE:]
273
- if is_numeric_types is not None
273
+ if is_numeric_types is not None and not np.all(pd.isna(is_numeric_types))
274
274
  else None
275
275
  )
276
276
  number_as_values = (
277
277
  number_as_values[START_TOKEN_SIZE:]
278
- if number_as_values is not None
278
+ if number_as_values is not None and not np.all(pd.isna(number_as_values))
279
279
  else None
280
280
  )
281
281
  concept_as_values = (
282
282
  concept_as_values[START_TOKEN_SIZE:]
283
- if concept_as_values is not None
283
+ if concept_as_values is not None and not np.all(pd.isna(concept_as_values))
284
+ else None
285
+ )
286
+ units = (
287
+ units[START_TOKEN_SIZE:]
288
+ if units is not None and not np.all(pd.isna(units))
284
289
  else None
285
290
  )
286
- units = units[START_TOKEN_SIZE:] if units is not None else None
287
291
 
288
292
  # TODO:Need to decode if the input is tokenized
289
293
  [start_year, start_age, start_gender, start_race] = concept_ids[
@@ -441,6 +445,9 @@ def gpt_to_omop_converter_batch(
441
445
  ]:
442
446
  # If it's a start token, skip it
443
447
  pass
448
+ elif event.endswith("/0"):
449
+ # This should capture the concept such as Visit/0, Discharge/0
450
+ pass
444
451
  else:
445
452
  try:
446
453
  concept_id = int(event)
cehrgpt/gpt_utils.py CHANGED
@@ -1,7 +1,12 @@
1
1
  import random
2
2
  import re
3
- from datetime import date, timedelta
4
- from typing import List, Sequence, Tuple
3
+ from datetime import date, datetime, timedelta, timezone
4
+ from typing import List, Optional, Sequence, Tuple, Union
5
+
6
+ import numpy as np
7
+ from cehrbert_data.const.artificial_tokens import DEATH_TOKEN
8
+ from meds import death_code
9
+ from transformers.utils import logging
5
10
 
6
11
  from cehrgpt.cehrgpt_args import SamplingStrategy
7
12
  from cehrgpt.models.special_tokens import (
@@ -14,6 +19,7 @@ from cehrgpt.models.special_tokens import (
14
19
  MEDS_CODE_PATTERN = re.compile(r".*/.*")
15
20
  INPATIENT_ATT_PATTERN = re.compile(r"(?:VS-|i-)D(\d+)(?:-VE)?")
16
21
  DEMOGRAPHIC_PROMPT_SIZE = 4
22
+ logger = logging.get_logger("transformers")
17
23
 
18
24
 
19
25
  class RandomSampleCache:
@@ -62,6 +68,68 @@ class RandomSampleCache:
62
68
  return self._cache.pop()
63
69
 
64
70
 
71
def construct_time_sequence(
    concept_ids: List[str], epoch_times: Optional[List[Union[int, float]]] = None
) -> List[float]:
    """Return one epoch timestamp per token, rebuilding them when missing.

    When ``epoch_times`` is provided it is returned unchanged. Otherwise a
    timeline is reconstructed from the tokens themselves: a UTC cursor starts
    at midnight on January 1st of the year parsed from the first token (when
    that token starts with "year", case-insensitively; 1985 otherwise), and is
    moved forward by ``extract_time_interval_in_days(token)`` days for every
    token accepted by ``is_att_token``. Each token is assigned the cursor's
    POSIX timestamp at that point.

    Args:
        concept_ids: The patient sequence tokens.
        epoch_times: Pre-computed timestamps; returned as-is when not None.

    Returns:
        A list with one timestamp per token (empty for an empty sequence).

    Raises:
        IndexError: If the year token contains no ":" separator.
        ValueError: If the text after ":" in the year token is not an integer.
    """
    if epoch_times is not None:
        return epoch_times

    # Nothing to rebuild for an empty sequence; this also avoids indexing the
    # (non-existent) first token below.
    if len(concept_ids) == 0:
        return []

    first_token = concept_ids[0]
    if first_token.lower().startswith("year"):
        year_str = first_token.split(":")[1]
    else:
        # No year token available, fall back to a fixed reference year
        year_str = "1985"

    datetime_cursor = datetime(int(year_str), month=1, day=1, tzinfo=timezone.utc)
    epoch_times = []
    for concept_id in concept_ids:
        if is_att_token(concept_id):
            # The time token itself already carries the advanced timestamp
            datetime_cursor += timedelta(
                days=extract_time_interval_in_days(concept_id)
            )
        epoch_times.append(datetime_cursor.timestamp())
    return epoch_times
92
+
93
+
94
def construct_age_sequence(
    concept_ids: List[str], ages: Optional[List[int]] = None
) -> List[int]:
    """Return one age per token, rebuilding the ages when they are missing.

    When ``ages`` is provided it is returned unchanged. Otherwise, if the
    second token starts with "age" (case-insensitively), the starting age is
    parsed from the text after ":" and every token is assigned
    ``starting_age + elapsed_days // 365``, where ``elapsed_days`` accumulates
    ``extract_time_interval_in_days(token)`` over the tokens accepted by
    ``is_att_token``. If there is no usable age token (including sequences
    with fewer than two tokens), a warning is logged and a list of zeros of
    the same length is returned.

    Args:
        concept_ids: The patient sequence tokens.
        ages: Pre-computed ages; returned as-is when not None.

    Returns:
        A list with one age per token (empty for an empty sequence).

    Raises:
        AssertionError: If the text after ":" in the age token is not numeric.
    """
    if ages is not None:
        return ages

    # An empty sequence trivially maps to an empty age vector
    if len(concept_ids) == 0:
        return []

    # Guard the length so short sequences reach the fallback instead of
    # failing with an IndexError on concept_ids[1]
    if len(concept_ids) > 1 and concept_ids[1].lower().startswith("age"):
        age_str = concept_ids[1].split(":")[1]
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type stays the same.
        if not age_str.isnumeric():
            raise AssertionError(f"age_str: {age_str}")
        starting_age = int(age_str)
        ages = []
        elapsed_days = 0
        for concept_id in concept_ids:
            if is_att_token(concept_id):
                elapsed_days += extract_time_interval_in_days(concept_id)
            ages.append(starting_age + elapsed_days // 365)
        return ages

    logger.warning(
        "The second token is not a valid age token. The first 4 tokens are: %s. "
        "Trying to fall back to ages, but it is not valid either %s. "
        "Fall back to a zero vector [0, 0, 0, ...., 0]",
        concept_ids[:4],
        ages,
    )
    return [0] * len(concept_ids)
118
+
119
+
120
def multiple_of_10(n: int) -> int:
    """Return the smallest multiple of 10 that is strictly greater than ``n``.

    Note that an exact multiple is bumped to the next one (10 -> 20).
    """
    tens, _ = divmod(n, 10)
    return (tens + 1) * 10
122
+
123
+
124
def encode_demographics(
    age: int, gender: int, race: int, max_age=200, max_gender=10, max_race=10
) -> int:
    """Pack age, gender and race into a single integer code.

    The code is a mixed-radix number with ``age`` as the least significant
    digit: ``age + max_age * gender + max_age * max_gender * race``. Distinct
    in-range triples therefore map to distinct codes.

    Args:
        age: Value in ``[0, max_age)``.
        gender: Value in ``[0, max_gender)``.
        race: Value in ``[0, max_race)``.
        max_age: Exclusive upper bound (radix) for ``age``.
        max_gender: Exclusive upper bound (radix) for ``gender``.
        max_race: Exclusive upper bound for ``race``.

    Returns:
        The combined integer code.

    Raises:
        AssertionError: If any value falls outside its allowed range.
    """
    # Raised explicitly (rather than via ``assert``) so the range checks are
    # not stripped under ``python -O``, where out-of-range values would
    # otherwise silently produce colliding codes. Type and message are the
    # same as the ``assert`` form would produce.
    bounds = (
        ("age", age, max_age),
        ("gender", gender, max_gender),
        ("race", race, max_race),
    )
    for label, value, upper_bound in bounds:
        if not 0 <= value < upper_bound:
            raise AssertionError(f"{label}: {value}")
    return age + max_age * gender + max_age * max_gender * race
131
+
132
+
65
133
  def collect_demographic_prompts_at_visits(patient_history: List[str]):
66
134
  demographic_prompts_at_visits = []
67
135
  start_year, start_age, start_gender, start_race = patient_history[
@@ -156,7 +224,7 @@ def random_slice_gpt_sequence(concept_ids, max_seq_len):
156
224
  )
157
225
  ):
158
226
  current_token = concept_ids[i]
159
- if current_token == "VE":
227
+ if is_visit_end(current_token):
160
228
  random_end_index = i
161
229
  break
162
230
  return random_starting_index, random_end_index, demographic_tokens
@@ -198,6 +266,8 @@ def get_cehrgpt_output_folder(args, cehrgpt_tokenizer) -> str:
198
266
  def is_clinical_event(token: str, meds: bool = False) -> bool:
199
267
  if token.isnumeric():
200
268
  return True
269
+ if token in [DEATH_TOKEN, death_code]:
270
+ return True
201
271
  if meds:
202
272
  return bool(MEDS_CODE_PATTERN.match(token))
203
273
  return False
@@ -0,0 +1,27 @@
1
+ # From https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py
2
+ # coding=utf-8
3
+
4
+ from __future__ import absolute_import, division, print_function
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import transformers.pytorch_utils
9
+
10
+
11
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
12
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """Create the scale vector (initialized to ones) and store ``eps``.

        Args:
            hidden_size: Size passed to ``torch.ones`` for the ``weight`` parameter.
            eps: Constant added to the mean square before the reciprocal square root.
        """
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Normalize ``hidden_states`` by its root mean square and apply ``weight``.

        The statistics are computed in float32; the normalized tensor is cast
        back to the input dtype before being multiplied by ``weight``.
        """
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)
25
+
26
+
27
+ transformers.pytorch_utils.ALL_LAYERNORM_LAYERS.extend([RMSNorm])
cehrgpt/models/config.py CHANGED
@@ -106,6 +106,8 @@ class CEHRGPTConfig(PretrainedConfig):
106
106
  n_head=12,
107
107
  n_inner=None,
108
108
  activation_function="gelu_new",
109
+ decoder_mlp="GPT2MLP",
110
+ mlp_bias=False,
109
111
  resid_pdrop=0.1,
110
112
  embd_pdrop=0.1,
111
113
  attn_pdrop=0.1,
@@ -124,7 +126,7 @@ class CEHRGPTConfig(PretrainedConfig):
124
126
  ve_token_id=None,
125
127
  scale_attn_by_inverse_layer_idx=False,
126
128
  reorder_and_upcast_attn=False,
127
- exclude_position_ids=False,
129
+ apply_rotary=False,
128
130
  include_values=False,
129
131
  value_vocab_size=None,
130
132
  include_ttv_prediction=False,
@@ -169,6 +171,8 @@ class CEHRGPTConfig(PretrainedConfig):
169
171
  self.n_head = n_head
170
172
  self.n_inner = n_inner
171
173
  self.activation_function = activation_function
174
+ self.decoder_mlp = decoder_mlp
175
+ self.mlp_bias = mlp_bias
172
176
  self.resid_pdrop = resid_pdrop
173
177
  self.embd_pdrop = embd_pdrop
174
178
  self.attn_pdrop = attn_pdrop
@@ -188,7 +192,7 @@ class CEHRGPTConfig(PretrainedConfig):
188
192
  self.eos_token_id = eos_token_id
189
193
  self.lab_token_ids = lab_token_ids
190
194
 
191
- self.exclude_position_ids = exclude_position_ids
195
+ self.apply_rotary = apply_rotary
192
196
  self.include_values = include_values
193
197
  self.value_vocab_size = value_vocab_size
194
198