cehrgpt 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
  2. cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
  3. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +99 -88
  4. cehrgpt/data/sample_packing_sampler.py +151 -0
  5. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  6. cehrgpt/models/config.py +10 -0
  7. cehrgpt/models/hf_cehrgpt.py +243 -73
  8. cehrgpt/models/tokenization_hf_cehrgpt.py +4 -0
  9. cehrgpt/runners/data_utils.py +243 -0
  10. cehrgpt/runners/gpt_runner_util.py +0 -10
  11. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +152 -279
  12. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +229 -105
  13. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +42 -0
  14. cehrgpt/runners/hyperparameter_search_util.py +4 -1
  15. cehrgpt/runners/sample_packing_trainer.py +168 -0
  16. cehrgpt/simulations/generate_plots.py +95 -0
  17. cehrgpt/simulations/run_simulation.sh +24 -0
  18. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  19. cehrgpt/simulations/time_token_simulation.py +177 -0
  20. cehrgpt/tools/linear_prob/__init__.py +0 -0
  21. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
  22. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  23. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +7 -5
  24. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +28 -26
  25. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
  26. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  27. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  28. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  29. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  30. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  31. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  32. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  33. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  34. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  35. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
  36. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
cehrgpt/data/hf_cehrgpt_dataset_mapping.py CHANGED
@@ -1,5 +1,5 @@
  import datetime
- from typing import Any, Dict
+ from typing import Any, Dict, Generator, Optional

  import numpy as np
  import pandas as pd
@@ -8,8 +8,12 @@ from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
      INPATIENT_VISIT_TYPE_CODES,
      INPATIENT_VISIT_TYPES,
      DatasetMapping,
+     VisitObject,
+     get_value,
+     has_events_and_get_events,
      replace_escape_chars,
  )
+ from cehrbert.med_extension.schema_extension import Event
  from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
  from cehrbert_data.const.common import NA
  from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
@@ -32,7 +36,6 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
      def __init__(
          self,
          data_args: DataTrainingArguments,
-         is_pretraining: bool = True,
          include_inpatient_hour_token: bool = True,
      ):
          self._time_token_function = get_att_function(data_args.att_function_type)
@@ -41,7 +44,6 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
              data_args.inpatient_att_function_type
          )
          self._include_demographic_prompt = data_args.include_demographic_prompt
-         self._is_pretraining = is_pretraining
          self._include_inpatient_hour_token = include_inpatient_hour_token

      """
@@ -57,14 +59,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
      """

      def remove_columns(self):
-         if self._is_pretraining:
-             return ["visits", "birth_datetime", "index_date"]
-         else:
-             return [
-                 "visits",
-                 "birth_datetime",
-                 "visit_concept_ids",
-             ]
+         return ["patient_id", "visits", "birth_datetime"]

      @staticmethod
      def _update_cehrgpt_record(
@@ -99,38 +94,45 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
              birth_datetime = birth_datetime.to_pydatetime()
          gender = record["gender"]
          race = record["race"]
+         visits = record["visits"]
+         # A dict indicates the visits are in columnar format
+         if isinstance(visits, dict):
+             visits = sorted(self.convert_visit_columnar_to_python(visits))
+         else:
+             visits = sorted(visits, key=lambda _: get_value(_, "visit_start_datetime"))

          # Add the demographic tokens
-         first_visit = record["visits"][0]
-         year_str = f'year:{str(first_visit["visit_start_datetime"].year)}'
-         age_str = f'age:{str(relativedelta(first_visit["visit_start_datetime"], birth_datetime).years)}'
+         first_visit = visits[0]
+         first_visit_start_datetime: datetime.datetime = get_value(
+             first_visit, "visit_start_datetime"
+         )
+         year_str = f"year:{str(first_visit_start_datetime.year)}"
+         age_str = f"age:{str(relativedelta(first_visit_start_datetime, birth_datetime).years)}"
          self._update_cehrgpt_record(cehrgpt_record, year_str)
          self._update_cehrgpt_record(cehrgpt_record, age_str)
          self._update_cehrgpt_record(cehrgpt_record, gender)
          self._update_cehrgpt_record(cehrgpt_record, race)

          # Use a datetime cursor to keep track of time
-         date_cursor = None
-
-         # Loop through all the visits excluding the first event containing the demographics
-         for i, visit in enumerate(
-             sorted(record["visits"], key=lambda e: e["visit_start_datetime"])
-         ):
-
-             events = visit["events"]
-
-             # Skip this visit if the number of measurements in the event is zero
-             if events is None or len(events) == 0:
+         datetime_cursor: Optional[datetime.datetime] = None
+         visit: VisitObject
+         # Loop through all the visits
+         for i, visit in enumerate(visits):
+             events: Generator[Event, None, None] = get_value(visit, "events")
+             has_events, events = has_events_and_get_events(events)
+             if not has_events:
                  continue

-             visit_start_datetime = visit["visit_start_datetime"]
-             time_delta = (
-                 (visit_start_datetime - date_cursor).days if date_cursor else None
+             visit_start_datetime: datetime.datetime = get_value(
+                 visit, "visit_start_datetime"
+             )
+             # If visit_end_datetime is populated for the inpatient visit, we update the datetime_cursor
+             visit_end_datetime: Optional[datetime.datetime] = get_value(
+                 visit, "visit_end_datetime"
              )
-             date_cursor = visit_start_datetime

              # We assume the first measurement to be the visit type of the current visit
-             visit_type = visit["visit_type"]
+             visit_type = get_value(visit, "visit_type")
              is_er_or_inpatient = (
                  visit_type in INPATIENT_VISIT_TYPES
                  or visit_type in INPATIENT_VISIT_TYPE_CODES
@@ -138,21 +140,15 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
              )

              # Add artificial time tokens to the patient timeline if a time delta exists
-             if time_delta is not None:
+             if datetime_cursor is not None:
+                 time_delta = max((visit_start_datetime - datetime_cursor).days, 0)
                  # This generates an artificial time token depending on the choice of the time token functions
                  self._update_cehrgpt_record(
                      cehrgpt_record,
                      code=self._time_token_function(time_delta),
                  )

-             # Add the VS token to the patient timeline to mark the start of a visit
-             relativedelta(visit["visit_start_datetime"], birth_datetime).years
-             # Calculate the week number since the epoch time
-             date = (
-                 visit["visit_start_datetime"]
-                 - datetime.datetime(year=1970, month=1, day=1)
-             ).days // 7
-
+             datetime_cursor = visit_start_datetime
              # Add a [VS] token
              self._update_cehrgpt_record(
                  cehrgpt_record,
@@ -163,11 +159,22 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
                  cehrgpt_record,
                  code=visit_type,
              )
+             # We need to insert an inpatient hour token right after the visit type; we calculate the hour interval
+             # with respect to the midnight of the day
+             if is_er_or_inpatient and self._include_inpatient_hour_token:
+                 if datetime_cursor.hour > 0:
+                     # This generates an artificial time token depending on the choice of the time token functions
+                     self._update_cehrgpt_record(
+                         cehrgpt_record,
+                         code=f"i-H{datetime_cursor.hour}",
+                     )
+
              # Keep track of the existing outpatient events; we don't want to add them again
-             existing_outpatient_events = list()
+             existing_duplicate_events = list()
              for e in events:
                  # If the event doesn't have a time stamp, we skip it
-                 if not e["time"]:
+                 event_time: datetime.datetime = e["time"]
+                 if not event_time:
                      continue

                  # If numeric_value exists, this is a concept/value tuple, we indicate this using a concept_value_mask
@@ -181,40 +188,48 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
                  is_numeric_type = int(numeric_value is not None)
                  code = replace_escape_chars(e["code"])

+                 # Create the event identity
+                 event_identity = (
+                     (event_time, code, text_value, unit)
+                     if is_er_or_inpatient
+                     else (event_time.date(), code, text_value, unit)
+                 )
+
                  # Add a medical token to the patient timeline
                  # If this is an inpatient visit, we use the event time stamps to calculate age and date
                  # because the patient can stay in the hospital for a period of time.
                  if is_er_or_inpatient:
-                     # Calculate the week number since the epoch time
-                     date = (
-                         e["time"] - datetime.datetime(year=1970, month=1, day=1)
-                     ).days // 7
                      # Calculate the time diff in days w.r.t. the previous measurement
-                     meas_time_diff = (e["time"] - date_cursor).days
-                     # Update the date_cursor if the time diff between two neighboring measurements is greater than and
+                     time_diff_days = (event_time - datetime_cursor).days
+                     # Update the datetime_cursor if the time diff between two neighboring measurements is greater than or
                      # equal to 1 day
-                     if meas_time_diff > 0:
-                         date_cursor = e["time"]
-                         if self._inpatient_time_token_function:
+                     if self._inpatient_time_token_function and time_diff_days > 0:
+                         # This generates an artificial time token depending on the choice of the time token functions
+                         self._update_cehrgpt_record(
+                             cehrgpt_record,
+                             code=f"i-{self._inpatient_time_token_function(time_diff_days)}",
+                         )
+
+                     if self._include_inpatient_hour_token:
+                         # If the time difference in days is greater than 0, we calculate the hour interval
+                         # with respect to the midnight of the day
+                         time_diff_hours = (
+                             event_time.hour
+                             if time_diff_days > 0
+                             else int(
+                                 (event_time - datetime_cursor).total_seconds() // 3600
+                             )
+                         )
+
+                         if time_diff_hours > 0:
                              # This generates an artificial time token depending on the choice of the time token functions
                              self._update_cehrgpt_record(
                                  cehrgpt_record,
-                                 code=f"i-{self._inpatient_time_token_function(meas_time_diff)}",
+                                 code=f"i-H{time_diff_hours}",
                              )
-                 else:
-                     # For outpatient visits, we use the visit time stamp to calculate age and time because we assume
-                     # the outpatient visits start and end on the same day.
-                     # We check whether the date/code/value combination already exists in the existing events
-                     # If they exist, we do not add them to the patient timeline for outpatient visits.
-                     if (
-                         date,
-                         code,
-                         numeric_value,
-                         text_value,
-                         concept_value_mask,
-                         numeric_value,
-                     ) in existing_outpatient_events:
-                         continue
+
+                 if event_identity in existing_duplicate_events:
+                     continue

                  self._update_cehrgpt_record(
                      cehrgpt_record,
@@ -227,33 +242,27 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
                      ),
                      is_numeric_type=is_numeric_type,
                  )
-                 existing_outpatient_events.append(
-                     (
-                         date,
-                         code,
-                         numeric_value,
-                         text_value,
-                         concept_value_mask,
-                         numeric_value,
-                     )
-                 )
+                 existing_duplicate_events.append(event_identity)
+                 # We only want to update the time stamp when datetime_cursor is less than the event time
+                 if datetime_cursor is None or datetime_cursor < event_time:
+                     datetime_cursor = event_time
+                     # We need to bound the datetime_cursor if the current visit is an admission type of visit,
+                     # as the associated events could be generated after the visit is complete
+                     if is_er_or_inpatient and visit_end_datetime is not None:
+                         datetime_cursor = min(datetime_cursor, visit_end_datetime)

              # For inpatient or ER visits, we want to add the discharge_facility token at the end of the visit
              if is_er_or_inpatient:
-                 # If visit_end_datetime is populated for the inpatient visit, we update the date_cursor
-                 visit_end_datetime = visit.get("visit_end_datetime", None)
-                 if visit_end_datetime:
-                     date_cursor = visit_end_datetime
+                 # If visit_end_datetime is populated for the inpatient visit, we update the datetime_cursor
+                 if visit_end_datetime is not None:
+                     datetime_cursor = visit_end_datetime

              if self._include_auxiliary_token:
                  # Reuse the age and date calculated for the last event in the patient timeline for the discharge
                  # facility event
-                 discharge_facility = (
-                     visit["discharge_facility"]
-                     if ("discharge_facility" in visit)
-                     and visit["discharge_facility"]
-                     else "0"
-                 )
+                 discharge_facility = get_value(visit, "discharge_facility")
+                 if not discharge_facility:
+                     discharge_facility = "0"

                  self._update_cehrgpt_record(
                      cehrgpt_record,
@@ -273,11 +282,13 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):

          # Add some count information for this sequence
          cehrgpt_record["num_of_concepts"] = len(cehrgpt_record["concept_ids"])
-         cehrgpt_record["num_of_visits"] = len(record["visits"])
+         cehrgpt_record["num_of_visits"] = len(visits)

-         if "label" in record:
+         if record.get("index_date", None):
+             cehrgpt_record["index_date"] = record["index_date"]
+         if record.get("label", None):
              cehrgpt_record["label"] = record["label"]
-         if "age_at_index" in record:
+         if record.get("age_at_index", None):
              cehrgpt_record["age_at_index"] = record["age_at_index"]

          return cehrgpt_record
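
The datetime-cursor and deduplication changes above are self-contained arithmetic, so a short sketch can show what the new code computes; the timestamps, concept code, and unit below are hypothetical values, not data from this diff:

    import datetime

    # Last recorded event: Jan 1 at 08:00; next inpatient event: Jan 3 at 14:00.
    datetime_cursor = datetime.datetime(2021, 1, 1, 8)
    event_time = datetime.datetime(2021, 1, 3, 14)

    # The day-level gap drives the "i-<att>" token.
    time_diff_days = (event_time - datetime_cursor).days  # 2

    # When a day boundary was crossed, the hour interval is measured from midnight;
    # otherwise it is the whole-hour gap between the two timestamps.
    time_diff_hours = (
        event_time.hour
        if time_diff_days > 0
        else int((event_time - datetime_cursor).total_seconds() // 3600)
    )
    print(time_diff_days, time_diff_hours)  # 2 14 -> also emits an "i-H14" token

    # Event identity used for deduplication: the exact timestamp for ER/inpatient
    # visits, the calendar date for outpatient visits.
    code, text_value, unit = "8867-4", None, "bpm"
    inpatient_key = (event_time, code, text_value, unit)
    outpatient_key = (event_time.date(), code, text_value, unit)

Keying outpatient events by calendar date collapses same-day duplicates, while inpatient events keep their full timestamps, where repeated measurements within a day are meaningful.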
cehrgpt/data/sample_packing_sampler.py ADDED
@@ -0,0 +1,151 @@
+ from typing import Iterator, List, Optional
+
+ import torch
+ import torch.distributed as dist
+ from torch.utils.data import Sampler
+ from transformers import logging
+
+ LOG = logging.get_logger("transformers")
+
+
+ class SamplePlacerHolder:
+     def __init__(self):
+         self.epoch = 0
+
+     def set_epoch(self, epoch):
+         self.epoch = epoch
+
+
+ class SamplePackingBatchSampler(Sampler[List[int]]):
+     """
+     A batch sampler that creates batches by packing samples together.
+
+     Packing maximizes GPU utilization while ensuring the total tokens per batch
+     doesn't exceed max_tokens_per_batch.
+     """
+
+     def __init__(
+         self,
+         lengths: List[int],
+         max_tokens_per_batch: int,
+         max_position_embeddings: int,
+         num_replicas: Optional[int] = None,
+         rank: Optional[int] = None,
+         seed: int = 0,
+         drop_last: bool = False,
+     ):
+         """
+         Args:
+             lengths: List of sequence lengths for each sample
+             max_tokens_per_batch: Maximum number of tokens in a batch
+             max_position_embeddings: Context window of the model; longer samples are truncated
+             drop_last: Whether to drop the last incomplete batch
+         """
+         super().__init__()
+
+         if num_replicas is None:
+             if dist.is_available() and dist.is_initialized():
+                 num_replicas = dist.get_world_size()
+                 LOG.info(
+                     "torch.distributed is initialized and there are %s replicas",
+                     num_replicas,
+                 )
+             else:
+                 num_replicas = 1
+                 LOG.info(
+                     "torch.distributed is not initialized, so num_replicas defaults to 1"
+                 )
+
+         if rank is None:
+             if dist.is_available() and dist.is_initialized():
+                 rank = dist.get_rank()
+                 LOG.info(
+                     "torch.distributed is initialized and the current rank is %s", rank
+                 )
+             else:
+                 rank = 0
+                 LOG.info(
+                     "torch.distributed is not initialized, so rank defaults to 0"
+                 )
+
+         if not (0 <= rank < num_replicas):
+             raise ValueError(
+                 f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
+             )
+
+         self.lengths = lengths
+         self.max_tokens_per_batch = max_tokens_per_batch
+         self.max_position_embeddings = max_position_embeddings
+         self.num_replicas = num_replicas
+         self.rank = rank
+         self.seed = seed
+         self.drop_last = drop_last
+         # Trainer https://github.com/huggingface/transformers/blame/main/src/transformers/trainer.py#L2470
+         # http://github.com/huggingface/accelerate/blob/v0.31.0/src/accelerate/data_loader.py#L482
+         # the huggingface trainer will call accelerate.data_loader.DataLoaderShard.set_epoch,
+         # which will call batch_sampler.sampler.set_epoch
+         self.sampler = SamplePlacerHolder()
+
+     def __iter__(self) -> Iterator[List[int]]:
+
+         # deterministically shuffle based on epoch and seed
+         g = torch.Generator()
+         g.manual_seed(self.seed + self.sampler.epoch)
+         indices = torch.randperm(len(self.lengths), generator=g).tolist()
+
+         # Partition indices for this rank
+         indices = indices[self.rank :: self.num_replicas]
+
+         batch = []
+         current_batch_tokens = 0
+
+         for idx in indices:
+             # We take the minimum of the two because each sequence will be truncated to fit
+             # the context window of the model
+             sample_length = min(self.lengths[idx], self.max_position_embeddings)
+             # If adding this sample would exceed max_tokens_per_batch, yield the current batch
+             if (
+                 current_batch_tokens + sample_length + 2 > self.max_tokens_per_batch
+                 and batch
+             ):
+                 yield batch
+                 batch = []
+                 current_batch_tokens = 0
+
+             # Add the sample to the current batch
+             batch.append(idx)
+             # plus an extra two for the [END] and [PAD] tokens that separate samples
+             current_batch_tokens += sample_length + 2
+
+         # Yield the last batch if it's not empty and we're not dropping it
+         if batch and not self.drop_last:
+             yield batch
+
+     def __len__(self) -> int:
+         """
+         Estimates the number of batches that will be generated.
+
+         This is an approximation since the exact number depends on the specific
+         sequence lengths and their order.
+         """
+         if len(self.lengths) == 0:
+             return 0
+
+         # We need to truncate the lengths due to the context window limit imposed by the model
+         truncated_lengths = [
+             min(self.max_position_embeddings, length + 2) for length in self.lengths
+         ]
+
+         # Calculate the average sequence length
+         avg_seq_length = sum(truncated_lengths) // len(truncated_lengths)
+
+         # Estimate the average number of sequences per batch
+         seqs_per_batch = self.max_tokens_per_batch // avg_seq_length
+
+         # Estimate the total number of batches
+         if self.drop_last:
+             # If dropping the last incomplete batch
+             return len(truncated_lengths) // seqs_per_batch * self.num_replicas
+         else:
+             # If keeping the last incomplete batch, ensure at least 1 batch
+             return max(1, len(truncated_lengths) // seqs_per_batch) * self.num_replicas
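
Since SamplePackingBatchSampler is new in this release, a minimal usage sketch may help; the toy dataset, lengths, and token budget are invented for illustration and are not taken from this diff:

    import torch
    from torch.utils.data import DataLoader, Dataset

    from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler

    class ToyDataset(Dataset):
        # Hypothetical stand-in for the tokenized EHR dataset.
        def __init__(self, lengths):
            self.lengths = lengths

        def __len__(self):
            return len(self.lengths)

        def __getitem__(self, idx):
            return {"input_ids": torch.arange(self.lengths[idx])}

    lengths = [512, 48, 1300, 700, 220]     # per-sample token counts
    batch_sampler = SamplePackingBatchSampler(
        lengths=lengths,
        max_tokens_per_batch=2048,          # token budget for each packed batch
        max_position_embeddings=1024,       # samples count only up to the context window
        seed=42,
    )
    loader = DataLoader(ToyDataset(lengths), batch_sampler=batch_sampler, collate_fn=list)
    for batch in loader:
        # Each sample contributes min(length, 1024) + 2 tokens ([END]/[PAD] separators).
        assert sum(min(len(x["input_ids"]), 1024) + 2 for x in batch) <= 2048

The SamplePlacerHolder shim exists for Trainer integration: per the comments in the file, accelerate calls set_epoch on batch_sampler.sampler, which reseeds the deterministic shuffle each epoch.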
cehrgpt/generation/generate_batch_hf_gpt_sequence.py CHANGED
@@ -93,9 +93,9 @@ def generate_single_batch(
          temperature=temperature,
          top_p=top_p,
          top_k=top_k,
-         bos_token_id=tokenizer.end_token_id,
-         eos_token_id=tokenizer.end_token_id,
-         pad_token_id=tokenizer.pad_token_id,
+         bos_token_id=model.generation_config.bos_token_id,
+         eos_token_id=model.generation_config.eos_token_id,
+         pad_token_id=model.generation_config.pad_token_id,
          do_sample=True,
          use_cache=True,
          return_dict_in_generate=True,
@@ -150,15 +150,11 @@ def main(args):
              attn_implementation=(
                  "flash_attention_2" if is_flash_attn_2_available() else "eager"
              ),
-             torch_dtype=(
-                 torch.bfloat16
-                 if is_flash_attn_2_available() and args.use_bfloat16
-                 else torch.float32
-             ),
          )
          .eval()
          .to(device)
      )
+
      cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
      cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
      cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id
@@ -192,6 +188,7 @@
      LOG.info(f"Top P {args.top_p}")
      LOG.info(f"Top K {args.top_k}")
      LOG.info(f"Loading demographic_info at {args.demographic_data_path}")
+     LOG.info(f"MEDS format: {args.meds_format}")

      dataset = load_parquet_as_dataset(args.demographic_data_path)
      total_rows = len(dataset)
@@ -199,6 +196,7 @@
      num_of_batches = args.num_of_patients // args.batch_size + 1
      sequence_to_flush = []
      current_person_id = 1
+     prompt_size = 2 if args.meds_format else START_TOKEN_SIZE
      for i in range(num_of_batches):
          LOG.info(f"{datetime.datetime.now()}: Batch {i} started")

@@ -215,7 +213,7 @@
                  <= max_seq_allowed
              ):
                  random_prompts.append(
-                     cehrgpt_tokenizer.encode(row["concept_ids"][:START_TOKEN_SIZE])
+                     cehrgpt_tokenizer.encode(row["concept_ids"][:prompt_size])
                  )
              iter += 1
              if not random_prompts and iter > 10:
@@ -326,6 +324,11 @@
          dest="drop_long_sequences",
          action="store_true",
      )
+     base_arg_parser.add_argument(
+         "--meds_format",
+         dest="meds_format",
+         action="store_true",
+     )
      return base_arg_parser

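The new --meds_format switch only changes how much of each sampled patient sequence seeds generation; a toy illustration of the prompt trimming (the demographic tokens and the START_TOKEN_SIZE value of 4 are assumptions for this sketch, not values confirmed by the diff):

    # Hypothetical: the standard prompt holds year/age/gender/race tokens.
    START_TOKEN_SIZE = 4
    meds_format = True

    concept_ids = ["year:1990", "age:30", "8507", "0", "[VS]", "9201"]
    # In the MEDS flavor only the first two tokens are used as the prompt.
    prompt_size = 2 if meds_format else START_TOKEN_SIZE
    prompt = concept_ids[:prompt_size]
    print(prompt)  # ['year:1990', 'age:30']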
cehrgpt/models/config.py CHANGED
@@ -133,14 +133,17 @@ class CEHRGPTConfig(PretrainedConfig):
          n_pretrained_embeddings_layers=2,
          pretrained_embedding_dim=768,
          pretrained_token_ids: List[int] = None,
+         next_token_prediction_loss_weight=1.0,
          time_token_loss_weight=1.0,
          time_to_visit_loss_weight=1.0,
          causal_sfm=False,
          demographics_size=4,
          lab_token_penalty=False,
          lab_token_loss_weight=0.9,
+         value_prediction_loss_weight=1.0,
          entropy_penalty=False,
          entropy_penalty_alpha=0.01,
+         sample_packing_max_positions=None,
          **kwargs,
      ):
          if token_to_time_token_mapping is None:
@@ -150,6 +153,11 @@
          self.vocab_size = vocab_size
          self.time_token_vocab_size = time_token_vocab_size
          self.n_positions = n_positions
+         self.sample_packing_max_positions = (
+             sample_packing_max_positions
+             if sample_packing_max_positions
+             else n_positions
+         )
          self.n_embd = n_embd
          self.n_layer = n_layer
          self.n_head = n_head
@@ -178,6 +186,7 @@
          self.include_values = include_values
          self.value_vocab_size = value_vocab_size

+         self.next_token_prediction_loss_weight = next_token_prediction_loss_weight
          self.include_ttv_prediction = include_ttv_prediction
          self.use_sub_time_tokenization = use_sub_time_tokenization
          self._token_to_time_token_mapping = token_to_time_token_mapping
@@ -195,6 +204,7 @@
          self.lab_token_loss_weight = lab_token_loss_weight
          self.entropy_penalty = entropy_penalty
          self.entropy_penalty_alpha = entropy_penalty_alpha
+         self.value_prediction_loss_weight = value_prediction_loss_weight

          kwargs["tie_word_embeddings"] = not use_pretrained_embeddings
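
A short sketch of the configuration knobs added in these hunks; the numeric values are illustrative only, and every other hyperparameter falls back to the defaults shown above:

    from cehrgpt.models.config import CEHRGPTConfig

    config = CEHRGPTConfig(
        next_token_prediction_loss_weight=1.0,  # new: weight on the next-token prediction loss
        value_prediction_loss_weight=1.0,       # new: weight on the value prediction loss
        sample_packing_max_positions=16384,     # new: position budget used with sample packing
    )
    assert config.sample_packing_max_positions == 16384

    # When sample_packing_max_positions is left as None, it falls back to n_positions.
    default_config = CEHRGPTConfig()
    assert default_config.sample_packing_max_positions == default_config.n_positions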