cehrgpt 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
  2. cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
  3. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +279 -2
  4. cehrgpt/data/sample_packing_sampler.py +151 -0
  5. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  6. cehrgpt/generation/omop_converter_batch.py +3 -0
  7. cehrgpt/models/config.py +10 -0
  8. cehrgpt/models/hf_cehrgpt.py +244 -73
  9. cehrgpt/models/tokenization_hf_cehrgpt.py +6 -2
  10. cehrgpt/runners/data_utils.py +243 -0
  11. cehrgpt/runners/gpt_runner_util.py +0 -10
  12. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +154 -260
  13. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +250 -90
  14. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -0
  15. cehrgpt/runners/hyperparameter_search_util.py +4 -1
  16. cehrgpt/runners/sample_packing_trainer.py +168 -0
  17. cehrgpt/simulations/__init__.py +0 -0
  18. cehrgpt/simulations/generate_plots.py +95 -0
  19. cehrgpt/simulations/run_simulation.sh +24 -0
  20. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  21. cehrgpt/simulations/time_token_simulation.py +177 -0
  22. cehrgpt/tools/generate_causal_patient_split_by_age.py +146 -0
  23. cehrgpt/tools/linear_prob/__init__.py +0 -0
  24. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
  25. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  26. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +57 -9
  27. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +30 -18
  28. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
  29. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
  30. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
cehrgpt/data/hf_cehrgpt_dataset_mapping.py CHANGED
@@ -1,8 +1,23 @@
  import datetime
- from typing import Any, Dict
+ from typing import Any, Dict, Generator, Optional

  import numpy as np
- from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import DatasetMapping
+ import pandas as pd
+ from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
+     ED_VISIT_TYPE_CODES,
+     INPATIENT_VISIT_TYPE_CODES,
+     INPATIENT_VISIT_TYPES,
+     DatasetMapping,
+     VisitObject,
+     get_value,
+     has_events_and_get_events,
+     replace_escape_chars,
+ )
+ from cehrbert.med_extension.schema_extension import Event
+ from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
+ from cehrbert_data.const.common import NA
+ from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
+ from dateutil.relativedelta import relativedelta

  from cehrgpt.models.tokenization_hf_cehrgpt import (
      NONE_BIN,
@@ -17,6 +32,268 @@ def convert_date_to_posix_time(index_date: datetime.date) -> float:
      ).timestamp()


+ class MedToCehrGPTDatasetMapping(DatasetMapping):
+     """
+     This mapping function converts the MED (https://github.com/Medical-Event-Data-Standard/meds/tree/main) extension
+     to the CehrGPT format. We make several assumptions:
+     - The first event contains the demographic information.
+     - From the second event onward:
+         - the time of the event is visit_start_datetime;
+         - the first measurement contains the code indicating a standard OMOP Visit concept_id (e.g., 9201, 9202);
+         - in case of inpatient visits, the last measurement is assumed to
+           contain the standard OMOP concept id for discharge facilities (e.g., 8536);
+         - in case of inpatient visits, datetime_value of the last measurement stores visit_end_datetime.
+     """
+
+     def __init__(
+         self,
+         data_args: DataTrainingArguments,
+         include_inpatient_hour_token: bool = True,
+     ):
+         self._time_token_function = get_att_function(data_args.att_function_type)
+         self._include_auxiliary_token = data_args.include_auxiliary_token
+         self._inpatient_time_token_function = get_att_function(
+             data_args.inpatient_att_function_type
+         )
+         self._include_demographic_prompt = data_args.include_demographic_prompt
+         self._include_inpatient_hour_token = include_inpatient_hour_token
+
+     def remove_columns(self):
+         return ["patient_id", "visits", "birth_datetime"]
+
+     @staticmethod
+     def _update_cehrgpt_record(
+         cehrgpt_record: Dict[str, Any],
+         code: str,
+         concept_value_mask: int = 0,
+         number_as_value: float = 0.0,
+         concept_as_value: str = "0",
+         is_numeric_type: int = 0,
+         unit: str = NA,
+     ) -> None:
+         cehrgpt_record["concept_ids"].append(replace_escape_chars(code))
+         cehrgpt_record["concept_value_masks"].append(concept_value_mask)
+         cehrgpt_record["number_as_values"].append(number_as_value)
+         cehrgpt_record["concept_as_values"].append(concept_as_value)
+         cehrgpt_record["units"].append(unit)
+         cehrgpt_record["is_numeric_types"].append(is_numeric_type)
+
+     def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
+         cehrgpt_record = {
+             "person_id": record["patient_id"],
+             "concept_ids": [],
+             "concept_value_masks": [],
+             "number_as_values": [],
+             "concept_as_values": [],
+             "units": [],
+             "is_numeric_types": [],
+         }
+         # Extract the demographic information
+         birth_datetime = record["birth_datetime"]
+         if isinstance(birth_datetime, pd.Timestamp):
+             birth_datetime = birth_datetime.to_pydatetime()
+         gender = record["gender"]
+         race = record["race"]
+         visits = record["visits"]
+         # This indicates the visits are in a columnar format
+         if isinstance(visits, dict):
+             visits = sorted(self.convert_visit_columnar_to_python(visits))
+         else:
+             visits = sorted(visits, key=lambda _: get_value(_, "visit_start_datetime"))
+
+         # Add the demographic tokens
+         first_visit = visits[0]
+         first_visit_start_datetime: datetime.datetime = get_value(
+             first_visit, "visit_start_datetime"
+         )
+         year_str = f"year:{str(first_visit_start_datetime.year)}"
+         age_str = f"age:{str(relativedelta(first_visit_start_datetime, birth_datetime).years)}"
+         self._update_cehrgpt_record(cehrgpt_record, year_str)
+         self._update_cehrgpt_record(cehrgpt_record, age_str)
+         self._update_cehrgpt_record(cehrgpt_record, gender)
+         self._update_cehrgpt_record(cehrgpt_record, race)
+
+         # Use a datetime cursor to keep track of time
+         datetime_cursor: Optional[datetime.datetime] = None
+         visit: VisitObject
+         # Loop through all the visits
+         for i, visit in enumerate(visits):
+             events: Generator[Event, None, None] = get_value(visit, "events")
+             has_events, events = has_events_and_get_events(events)
+             if not has_events:
+                 continue
+
+             visit_start_datetime: datetime.datetime = get_value(
+                 visit, "visit_start_datetime"
+             )
+             # If visit_end_datetime is populated for the inpatient visit, we update the datetime_cursor
+             visit_end_datetime: Optional[datetime.datetime] = get_value(
+                 visit, "visit_end_datetime"
+             )
+
+             # We assume the first measurement to be the visit type of the current visit
+             visit_type = get_value(visit, "visit_type")
+             is_er_or_inpatient = (
+                 visit_type in INPATIENT_VISIT_TYPES
+                 or visit_type in INPATIENT_VISIT_TYPE_CODES
+                 or visit_type in ED_VISIT_TYPE_CODES
+             )
+
+             # Add artificial time tokens to the patient timeline if a time delta exists
+             if datetime_cursor is not None:
+                 time_delta = max((visit_start_datetime - datetime_cursor).days, 0)
+                 # This generates an artificial time token depending on the choice of the time token functions
+                 self._update_cehrgpt_record(
+                     cehrgpt_record,
+                     code=self._time_token_function(time_delta),
+                 )
+
+             datetime_cursor = visit_start_datetime
+             # Add a [VS] token
+             self._update_cehrgpt_record(
+                 cehrgpt_record,
+                 code="[VS]",
+             )
+             # Add a visit type token
+             self._update_cehrgpt_record(
+                 cehrgpt_record,
+                 code=visit_type,
+             )
+             # We need to insert an inpatient hour token right after the visit type;
+             # we calculate the hour interval with respect to the midnight of the day
+             if is_er_or_inpatient and self._include_inpatient_hour_token:
+                 if datetime_cursor.hour > 0:
+                     # This generates an artificial time token depending on the choice of the time token functions
+                     self._update_cehrgpt_record(
+                         cehrgpt_record,
+                         code=f"i-H{datetime_cursor.hour}",
+                     )
+
+             # Keep track of the existing events so we don't add duplicates again
+             existing_duplicate_events = list()
+             for e in events:
+                 # If the event doesn't have a time stamp, we skip it
+                 event_time: datetime.datetime = e["time"]
+                 if not event_time:
+                     continue
+
+                 # If numeric_value exists, this is a concept/value tuple; we indicate this using a concept_value_mask
+                 numeric_value = e.get("numeric_value", None)
+                 text_value = e.get("text_value", None)
+                 # The unit might be populated with a None value
+                 unit = e.get("unit", NA) if e.get("unit", NA) else NA
+                 concept_value_mask = int(
+                     numeric_value is not None or text_value is not None
+                 )
+                 is_numeric_type = int(numeric_value is not None)
+                 code = replace_escape_chars(e["code"])
+
+                 # Create the event identity
+                 event_identity = (
+                     (event_time, code, text_value, unit)
+                     if is_er_or_inpatient
+                     else (event_time.date(), code, text_value, unit)
+                 )
+
+                 # Add a medical token to the patient timeline
+                 # If this is an inpatient visit, we use the event time stamps to calculate age and date
+                 # because the patient can stay in the hospital for a period of time.
+                 if is_er_or_inpatient:
+                     # Calculate the time diff in days w.r.t. the previous measurement
+                     time_diff_days = (event_time - datetime_cursor).days
+                     # Update the datetime_cursor if the time diff between two neighboring measurements is
+                     # greater than or equal to 1 day
+                     if self._inpatient_time_token_function and time_diff_days > 0:
+                         # This generates an artificial time token depending on the choice of the time token functions
+                         self._update_cehrgpt_record(
+                             cehrgpt_record,
+                             code=f"i-{self._inpatient_time_token_function(time_diff_days)}",
+                         )
+
+                     if self._include_inpatient_hour_token:
+                         # If the time difference in days is greater than 0, we calculate the hour interval
+                         # with respect to the midnight of the day
+                         time_diff_hours = (
+                             event_time.hour
+                             if time_diff_days > 0
+                             else int(
+                                 (event_time - datetime_cursor).total_seconds() // 3600
+                             )
+                         )
+
+                         if time_diff_hours > 0:
+                             # This generates an artificial time token depending on the choice of the time token functions
+                             self._update_cehrgpt_record(
+                                 cehrgpt_record,
+                                 code=f"i-H{time_diff_hours}",
+                             )
+
+                 if event_identity in existing_duplicate_events:
+                     continue
+
+                 self._update_cehrgpt_record(
+                     cehrgpt_record,
+                     code=code,
+                     concept_value_mask=concept_value_mask,
+                     unit=unit,
+                     number_as_value=numeric_value if numeric_value else 0.0,
+                     concept_as_value=(
+                         replace_escape_chars(text_value) if text_value else "0"
+                     ),
+                     is_numeric_type=is_numeric_type,
+                 )
+                 existing_duplicate_events.append(event_identity)
+                 # We only want to update the time stamp when the datetime_cursor is behind the event time
+                 if datetime_cursor is None or datetime_cursor < event_time:
+                     datetime_cursor = event_time
+                     # We need to bound the datetime_cursor if the current visit is an admission type of visit,
+                     # as the associated events could be generated after the visit is complete
+                     if is_er_or_inpatient and visit_end_datetime is not None:
+                         datetime_cursor = min(datetime_cursor, visit_end_datetime)
+
+             # For inpatient or ER visits, we want to move the discharge_facility event to the end of the visit
+             if is_er_or_inpatient:
+                 # If visit_end_datetime is populated for the inpatient visit, we update the datetime_cursor
+                 if visit_end_datetime is not None:
+                     datetime_cursor = visit_end_datetime
+
+                 if self._include_auxiliary_token:
+                     # Reuse the age and date calculated for the last event in the patient timeline for the discharge
+                     # facility event
+                     discharge_facility = get_value(visit, "discharge_facility")
+                     if not discharge_facility:
+                         discharge_facility = "0"
+
+                     self._update_cehrgpt_record(
+                         cehrgpt_record,
+                         code=discharge_facility,
+                     )
+
+             # Reuse the age and date calculated for the last event in the patient timeline
+             self._update_cehrgpt_record(
+                 cehrgpt_record,
+                 code="[VE]",
+             )
+
+         # Generate the orders of the concepts that the cehrbert dataset mapping function expects
+         cehrgpt_record["orders"] = list(
+             range(1, len(cehrgpt_record["concept_ids"]) + 1)
+         )
+
+         # Add some count information for this sequence
+         cehrgpt_record["num_of_concepts"] = len(cehrgpt_record["concept_ids"])
+         cehrgpt_record["num_of_visits"] = len(visits)
+
+         if record.get("index_date", None) is not None:
+             cehrgpt_record["index_date"] = record["index_date"]
+         if record.get("label", None) is not None:
+             cehrgpt_record["label"] = record["label"]
+         if record.get("age_at_index", None) is not None:
+             cehrgpt_record["age_at_index"] = record["age_at_index"]
+
+         return cehrgpt_record
+
+
  class HFCehrGptTokenizationMapping(DatasetMapping):
      def __init__(
          self,
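
The bulk of this file is the new MedToCehrGPTDatasetMapping, which flattens a MEDS-style patient record into a token sequence: a four-token demographic prompt, then per visit an inter-visit time token, [VS], the visit type, optional hour tokens, the deduplicated events (with value masks for concept/value pairs), an optional discharge facility, and [VE]. A minimal usage sketch follows; the DataTrainingArguments fields shown and the dict-shaped visit/event records are assumptions about what get_value and transform() accept, not confirmed API:

```python
import datetime

from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
from cehrgpt.data.hf_cehrgpt_dataset_mapping import MedToCehrGPTDatasetMapping

# Hypothetical arguments: only the fields the constructor reads are shown,
# and the path values and ATT function names are placeholders.
data_args = DataTrainingArguments(
    data_folder="meds/",
    dataset_prepared_path="prepared/",
    att_function_type="cehr_bert",
    inpatient_att_function_type="mix",
)
mapping = MedToCehrGPTDatasetMapping(data_args, include_inpatient_hour_token=True)

record = {
    "patient_id": 1,
    "birth_datetime": datetime.datetime(1960, 1, 1),
    "gender": "Gender/F",
    "race": "Race/White",
    "visits": [
        {
            "visit_start_datetime": datetime.datetime(2014, 5, 1, 9, 0),
            "visit_end_datetime": datetime.datetime(2014, 5, 1, 10, 0),
            "visit_type": "9202",  # OMOP outpatient visit concept id
            "discharge_facility": None,
            "events": [
                {"time": datetime.datetime(2014, 5, 1, 9, 30), "code": "320128"},
            ],
        }
    ],
}
cehrgpt_record = mapping.transform(record)
# Expected prefix: ['year:2014', 'age:54', 'Gender/F', 'Race/White', '[VS]', '9202', ...]
print(cehrgpt_record["concept_ids"])
```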
cehrgpt/data/sample_packing_sampler.py ADDED
@@ -0,0 +1,151 @@
+ from typing import Iterator, List, Optional
+
+ import torch
+ import torch.distributed as dist
+ from torch.utils.data import Sampler
+ from transformers import logging
+
+ LOG = logging.get_logger("transformers")
+
+
+ class SamplePlacerHolder:
+     def __init__(self):
+         self.epoch = 0
+
+     def set_epoch(self, epoch):
+         self.epoch = epoch
+
+
+ class SamplePackingBatchSampler(Sampler[List[int]]):
+     """
+     A batch sampler that creates batches by packing samples together
+     to maximize GPU utilization, ensuring the total number of tokens
+     per batch doesn't exceed max_tokens_per_batch.
+     """
+
+     def __init__(
+         self,
+         lengths: List[int],
+         max_tokens_per_batch: int,
+         max_position_embeddings: int,
+         num_replicas: Optional[int] = None,
+         rank: Optional[int] = None,
+         seed: int = 0,
+         drop_last: bool = False,
+     ):
+         """
+         Args:
+             lengths: List of sequence lengths for each sample
+             max_tokens_per_batch: Maximum number of tokens in a batch
+             max_position_embeddings: Context window of the model; longer sequences are truncated to this length
+             num_replicas: Number of distributed replicas (defaults to the world size when torch.distributed is initialized)
+             rank: Rank of the current process (defaults to the distributed rank when torch.distributed is initialized)
+             seed: Seed for the deterministic per-epoch shuffle
+             drop_last: Whether to drop the last incomplete batch
+         """
+         super().__init__()
+
+         if num_replicas is None:
+             if dist.is_available() and dist.is_initialized():
+                 num_replicas = dist.get_world_size()
+                 LOG.info(
+                     "torch.distributed is initialized and there are %s replicas",
+                     num_replicas,
+                 )
+             else:
+                 num_replicas = 1
+                 LOG.info(
+                     "torch.distributed is not initialized, therefore num_replicas defaults to 1"
+                 )
+
+         if rank is None:
+             if dist.is_available() and dist.is_initialized():
+                 rank = dist.get_rank()
+                 LOG.info(
+                     "torch.distributed is initialized and the current rank is %s", rank
+                 )
+             else:
+                 rank = 0
+                 LOG.info(
+                     "torch.distributed is not initialized, therefore rank defaults to 0"
+                 )
+
+         if not (0 <= rank < num_replicas):
+             raise ValueError(
+                 f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
+             )
+
+         self.lengths = lengths
+         self.max_tokens_per_batch = max_tokens_per_batch
+         self.max_position_embeddings = max_position_embeddings
+         self.num_replicas = num_replicas
+         self.rank = rank
+         self.seed = seed
+         self.drop_last = drop_last
+         # Trainer https://github.com/huggingface/transformers/blame/main/src/transformers/trainer.py#L2470
+         # http://github.com/huggingface/accelerate/blob/v0.31.0/src/accelerate/data_loader.py#L482
+         # the huggingface trainer will call accelerate.data_loader.DataLoaderShard.set_epoch,
+         # which will call batch_sampler.sampler.set_epoch
+         self.sampler = SamplePlacerHolder()
+
+     def __iter__(self) -> Iterator[List[int]]:
+         # deterministically shuffle based on epoch and seed
+         g = torch.Generator()
+         g.manual_seed(self.seed + self.sampler.epoch)
+         indices = torch.randperm(len(self.lengths), generator=g).tolist()
+
+         # Partition indices for this rank
+         indices = indices[self.rank :: self.num_replicas]
+
+         batch = []
+         current_batch_tokens = 0
+
+         for idx in indices:
+             # We take the minimum of the two because each sequence will be truncated to fit
+             # the context window of the model
+             sample_length = min(self.lengths[idx], self.max_position_embeddings)
+             # If adding this sample would exceed max_tokens_per_batch, yield the current batch
+             if (
+                 current_batch_tokens + sample_length + 2 > self.max_tokens_per_batch
+                 and batch
+             ):
+                 yield batch
+                 batch = []
+                 current_batch_tokens = 0
+
+             # Add the sample to the current batch
+             batch.append(idx)
+             # plus two extra tokens for the [END] and [PAD] tokens that separate samples
+             current_batch_tokens += sample_length + 2
+
+         # Yield the last batch if it's not empty and we're not dropping it
+         if batch and not self.drop_last:
+             yield batch
+
+     def __len__(self) -> int:
+         """
+         Estimates the number of batches that will be generated.
+
+         This is an approximation since the exact number depends on the specific
+         sequence lengths and their order.
+         """
+         if len(self.lengths) == 0:
+             return 0
+
+         # We need to truncate the lengths due to the context window limit imposed by the model
+         truncated_lengths = [
+             min(self.max_position_embeddings, length + 2) for length in self.lengths
+         ]
+
+         # Calculate the average sequence length
+         avg_seq_length = sum(truncated_lengths) // len(truncated_lengths)
+
+         # Estimate the average number of sequences per batch
+         seqs_per_batch = self.max_tokens_per_batch // avg_seq_length
+
+         # Estimate the total number of batches
+         if self.drop_last:
+             # If dropping the last incomplete batch
+             return len(truncated_lengths) // seqs_per_batch * self.num_replicas
+         else:
+             # If keeping the last incomplete batch, ensure at least 1 batch
+             return max(1, len(truncated_lengths) // seqs_per_batch) * self.num_replicas
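
The packing strategy is greedy: shuffle deterministically per epoch, hand each rank a strided slice of the indices, then accumulate samples until the token budget (each sample counted as min(length, max_position_embeddings) plus two separator tokens) would overflow. A self-contained sketch of driving a PyTorch DataLoader with the sampler; the toy lengths and the identity collate function are illustrative only:

```python
import torch
from torch.utils.data import DataLoader, Dataset

from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler

class ToySequences(Dataset):
    """Variable-length token sequences standing in for patient timelines."""
    def __init__(self, lengths):
        self.lengths = lengths

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, idx):
        return torch.arange(self.lengths[idx])

lengths = [128, 512, 2048, 64, 300, 990]
dataset = ToySequences(lengths)
batch_sampler = SamplePackingBatchSampler(
    lengths=lengths,
    max_tokens_per_batch=2048,     # token budget per packed batch
    max_position_embeddings=1024,  # sequences are capped at the context window
    seed=42,
)
# Each yielded batch is a list of samples for a packing collator to concatenate
loader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=lambda xs: xs)
for epoch in range(2):
    batch_sampler.sampler.set_epoch(epoch)  # deterministic reshuffle per epoch
    for batch in loader:
        packed = sum(min(len(x), 1024) + 2 for x in batch)
        assert packed <= 2048  # holds here because every capped sample fits the budget
```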
cehrgpt/generation/generate_batch_hf_gpt_sequence.py CHANGED
@@ -93,9 +93,9 @@ def generate_single_batch(
          temperature=temperature,
          top_p=top_p,
          top_k=top_k,
-         bos_token_id=tokenizer.end_token_id,
-         eos_token_id=tokenizer.end_token_id,
-         pad_token_id=tokenizer.pad_token_id,
+         bos_token_id=model.generation_config.bos_token_id,
+         eos_token_id=model.generation_config.eos_token_id,
+         pad_token_id=model.generation_config.pad_token_id,
          do_sample=True,
          use_cache=True,
          return_dict_in_generate=True,
@@ -150,15 +150,11 @@ def main(args):
              attn_implementation=(
                  "flash_attention_2" if is_flash_attn_2_available() else "eager"
              ),
-             torch_dtype=(
-                 torch.bfloat16
-                 if is_flash_attn_2_available() and args.use_bfloat16
-                 else torch.float32
-             ),
          )
          .eval()
          .to(device)
      )
+
      cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
      cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
      cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id
@@ -192,6 +188,7 @@ def main(args):
      LOG.info(f"Top P {args.top_p}")
      LOG.info(f"Top K {args.top_k}")
      LOG.info(f"Loading demographic_info at {args.demographic_data_path}")
+     LOG.info(f"MEDS format: {args.meds_format}")

      dataset = load_parquet_as_dataset(args.demographic_data_path)
      total_rows = len(dataset)
@@ -199,6 +196,7 @@ def main(args):
      num_of_batches = args.num_of_patients // args.batch_size + 1
      sequence_to_flush = []
      current_person_id = 1
+     prompt_size = 2 if args.meds_format else START_TOKEN_SIZE
      for i in range(num_of_batches):
          LOG.info(f"{datetime.datetime.now()}: Batch {i} started")

@@ -215,7 +213,7 @@ def main(args):
                  <= max_seq_allowed
              ):
                  random_prompts.append(
-                     cehrgpt_tokenizer.encode(row["concept_ids"][:START_TOKEN_SIZE])
+                     cehrgpt_tokenizer.encode(row["concept_ids"][:prompt_size])
                  )
                  iter += 1
              if not random_prompts and iter > 10:
@@ -326,6 +324,11 @@ def create_arg_parser():
          dest="drop_long_sequences",
          action="store_true",
      )
+     base_arg_parser.add_argument(
+         "--meds_format",
+         dest="meds_format",
+         action="store_true",
+     )
      return base_arg_parser

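Two things change here: generate_single_batch now reads the bos/eos/pad ids from model.generation_config (which main() populates from the tokenizer right after loading, so the two can no longer drift apart), and a new --meds_format flag shortens the sampling prompt from START_TOKEN_SIZE (4) demographic tokens to 2. A small sketch of the prompt-size logic; the concept ids are illustrative:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--meds_format", dest="meds_format", action="store_true")
args = parser.parse_args(["--meds_format"])

START_TOKEN_SIZE = 4  # year/age/gender/race demographic prompt
prompt_size = 2 if args.meds_format else START_TOKEN_SIZE

# The prompt is simply the first prompt_size concept ids of a real sequence,
# later passed through cehrgpt_tokenizer.encode before generation.
concept_ids = ["year:2014", "age:54", "Gender/F", "Race/White", "[VS]"]
assert concept_ids[:prompt_size] == ["year:2014", "age:54"]
```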
cehrgpt/generation/omop_converter_batch.py CHANGED
@@ -35,6 +35,7 @@ from cehrgpt.models.tokenization_hf_cehrgpt import END_TOKEN
  # TODO: move these to cehrbert_data
  STOP_TOKENS = ["VE", "[VE]", END_TOKEN]

+ OOV = "[OOV]"
  CURRENT_PATH = Path(__file__).parent
  START_TOKEN_SIZE = 4
  ATT_TIME_TOKENS = generate_artificial_time_tokens()
@@ -297,6 +298,8 @@ def gpt_to_omop_converter_batch(
      inpatient_visit_indicator = False

      for event_idx, event in enumerate(clinical_events, 0):
+         if event == OOV:
+             continue
          # For bad sequences, we don't proceed further and break from the for loop
          if bad_sequence:
              break
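
The converter change is a guard: generated sequences may now contain an out-of-vocabulary marker, and the conversion loop skips it instead of attempting to map it to an OMOP concept. The same filter in isolation:

```python
OOV = "[OOV]"

clinical_events = ["year:2014", "age:54", "[VS]", "9202", "[OOV]", "320128", "[VE]"]
for event_idx, event in enumerate(clinical_events, 0):
    if event == OOV:
        continue  # skip out-of-vocabulary tokens rather than converting them
    ...  # remaining conversion logic (bad-sequence checks, visit handling, etc.)
```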
cehrgpt/models/config.py CHANGED
@@ -133,14 +133,17 @@ class CEHRGPTConfig(PretrainedConfig):
          n_pretrained_embeddings_layers=2,
          pretrained_embedding_dim=768,
          pretrained_token_ids: List[int] = None,
+         next_token_prediction_loss_weight=1.0,
          time_token_loss_weight=1.0,
          time_to_visit_loss_weight=1.0,
          causal_sfm=False,
          demographics_size=4,
          lab_token_penalty=False,
          lab_token_loss_weight=0.9,
+         value_prediction_loss_weight=1.0,
          entropy_penalty=False,
          entropy_penalty_alpha=0.01,
+         sample_packing_max_positions=None,
          **kwargs,
      ):
          if token_to_time_token_mapping is None:
@@ -150,6 +153,11 @@ class CEHRGPTConfig(PretrainedConfig):
          self.vocab_size = vocab_size
          self.time_token_vocab_size = time_token_vocab_size
          self.n_positions = n_positions
+         self.sample_packing_max_positions = (
+             sample_packing_max_positions
+             if sample_packing_max_positions
+             else n_positions
+         )
          self.n_embd = n_embd
          self.n_layer = n_layer
          self.n_head = n_head
@@ -178,6 +186,7 @@ class CEHRGPTConfig(PretrainedConfig):
          self.include_values = include_values
          self.value_vocab_size = value_vocab_size

+         self.next_token_prediction_loss_weight = next_token_prediction_loss_weight
          self.include_ttv_prediction = include_ttv_prediction
          self.use_sub_time_tokenization = use_sub_time_tokenization
          self._token_to_time_token_mapping = token_to_time_token_mapping
@@ -195,6 +204,7 @@ class CEHRGPTConfig(PretrainedConfig):
          self.lab_token_loss_weight = lab_token_loss_weight
          self.entropy_penalty = entropy_penalty
          self.entropy_penalty_alpha = entropy_penalty_alpha
+         self.value_prediction_loss_weight = value_prediction_loss_weight

          kwargs["tie_word_embeddings"] = not use_pretrained_embeddings
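
The config gains explicit weights for the next-token and value-prediction objectives (alongside the existing time-token and time-to-visit weights) and a sample_packing_max_positions that falls back to n_positions when unset. A minimal sketch; all other constructor arguments are left at their defaults, which is an assumption about the full signature:

```python
from cehrgpt.models.config import CEHRGPTConfig

config = CEHRGPTConfig(
    n_positions=1024,
    sample_packing_max_positions=8192,      # position budget when packing samples
    next_token_prediction_loss_weight=1.0,  # new in 0.1.0
    value_prediction_loss_weight=1.0,       # new in 0.1.0
    time_token_loss_weight=1.0,
    time_to_visit_loss_weight=1.0,
)
assert config.sample_packing_max_positions == 8192

# Left unset, sample_packing_max_positions mirrors n_positions
assert CEHRGPTConfig(n_positions=1024).sample_packing_max_positions == 1024
```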