cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. cehrgpt/analysis/irregularity.py +36 -0
  2. cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
  3. cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
  4. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
  5. cehrgpt/data/sample_packing_sampler.py +181 -0
  6. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  7. cehrgpt/generation/omop_converter_batch.py +32 -2
  8. cehrgpt/gpt_utils.py +20 -2
  9. cehrgpt/models/config.py +35 -0
  10. cehrgpt/models/hf_cehrgpt.py +470 -106
  11. cehrgpt/models/hf_modeling_outputs.py +1 -0
  12. cehrgpt/models/special_tokens.py +1 -0
  13. cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
  14. cehrgpt/runners/data_utils.py +358 -0
  15. cehrgpt/runners/gpt_runner_util.py +0 -10
  16. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
  17. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
  18. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
  19. cehrgpt/runners/hyperparameter_search_util.py +10 -8
  20. cehrgpt/runners/sample_packing_trainer.py +185 -0
  21. cehrgpt/simulations/generate_plots.py +95 -0
  22. cehrgpt/simulations/run_simulation.sh +24 -0
  23. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  24. cehrgpt/simulations/time_token_simulation.py +177 -0
  25. cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  26. cehrgpt/time_to_event/time_to_event_model.py +2 -13
  27. cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  28. cehrgpt/tools/linear_prob/__init__.py +0 -0
  29. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
  30. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  31. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
  32. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
  33. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
  34. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  35. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  36. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  37. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  38. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  39. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  40. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  41. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  42. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  43. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
  44. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
cehrgpt/data/sample_packing_sampler.py ADDED
@@ -0,0 +1,181 @@
+from typing import Iterator, List, Optional
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.utils.data import Sampler
+from transformers import logging
+
+LOG = logging.get_logger("transformers")
+
+
+class SamplePlacerHolder:
+    def __init__(self):
+        self.epoch = 0
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
+
+class SamplePackingBatchSampler(Sampler[List[int]]):
+    """
+    A batch sampler that creates batches by packing samples together
+    to maximize GPU utilization, ensuring the total tokens per batch
+    doesn't exceed max_tokens_per_batch.
+    """
+
+    def __init__(
+        self,
+        lengths: List[int],
+        max_tokens_per_batch: int,
+        max_position_embeddings: int,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        seed: int = 0,
+        drop_last: bool = False,
+        negative_sampling_probability: Optional[float] = None,
+        labels: Optional[List[int]] = None,
+    ):
+        """
+        Args:
+            lengths: List of sequence lengths for each sample
+            max_tokens_per_batch: Maximum number of tokens in a batch
+            drop_last: Whether to drop the last incomplete batch
+        """
+        super().__init__()
+
+        if num_replicas is None:
+            if dist.is_available() and dist.is_initialized():
+                num_replicas = dist.get_world_size()
+                LOG.info(
+                    "torch.distributed is initialized and there are %s replicas",
+                    num_replicas,
+                )
+            else:
+                num_replicas = 1
+                LOG.info(
+                    "torch.distributed is not initialized, defaulting num_replicas to 1"
+                )
+
+        if rank is None:
+            if dist.is_available() and dist.is_initialized():
+                rank = dist.get_rank()
+                LOG.info(
+                    "torch.distributed is initialized and the current rank is %s", rank
+                )
+            else:
+                rank = 0
+                LOG.info(
+                    "torch.distributed is not initialized, defaulting rank to 0"
+                )
+
+        if not (0 <= rank < num_replicas):
+            raise ValueError(
+                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
+            )
+
+        if negative_sampling_probability is not None and labels is None:
+            raise ValueError(
+                "When the negative sampling probability is provided, the labels must be provided as well"
+            )
+
+        self.lengths = lengths
+        self.max_tokens_per_batch = max_tokens_per_batch
+        self.max_position_embeddings = max_position_embeddings
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.seed = seed
+        self.drop_last = drop_last
+        self.negative_sampling_probability = negative_sampling_probability
+        self.labels = labels
+        # Trainer https://github.com/huggingface/transformers/blame/main/src/transformers/trainer.py#L2470
+        # http://github.com/huggingface/accelerate/blob/v0.31.0/src/accelerate/data_loader.py#L482
+        # The huggingface trainer will call accelerate.data_loader.DataLoaderShard.set_epoch,
+        # which will call batch_sampler.sampler.set_epoch
+        self.sampler = SamplePlacerHolder()
+
+    def __iter__(self) -> Iterator[List[int]]:
+        # deterministically shuffle based on epoch and seed
+        g = torch.Generator()
+        g.manual_seed(self.seed + self.sampler.epoch)
+        indices = torch.randperm(len(self.lengths), generator=g).tolist()
+
+        # Partition indices for this rank
+        indices = indices[self.rank :: self.num_replicas]
+
+        batch = []
+        current_batch_tokens = 0
+
+        for idx in indices:
+            # There is a chance to skip the negative samples to account for the class imbalance
+            # in the fine-tuning dataset
+            if self.negative_sampling_probability:
+                if (
+                    np.random.random() > self.negative_sampling_probability
+                    and self.labels[idx] == 0
+                ):
+                    continue
+            # We take the minimum of the two because each sequence will be truncated to fit
+            # the context window of the model
+            sample_length = min(self.lengths[idx], self.max_position_embeddings)
+            # If adding this sample would exceed max_tokens_per_batch, yield the current batch
+            if (
+                current_batch_tokens + sample_length + 2 > self.max_tokens_per_batch
+                and batch
+            ):
+                yield batch
+                batch = []
+                current_batch_tokens = 0
+
+            # Add the sample to the current batch
+            batch.append(idx)
+            # plus an extra two for the [END] and [PAD] tokens that separate samples
+            current_batch_tokens += sample_length + 2
+
+        # Yield the last batch if it's not empty and we're not dropping it
+        if batch and not self.drop_last:
+            yield batch
+
+    def __len__(self) -> int:
+        """
+        Estimates the number of batches that will be generated.
+
+        This is an approximation since the exact number depends on the specific
+        sequence lengths and their order.
+        """
+        if len(self.lengths) == 0:
+            return 0
+
+        # There is a chance to skip the negative samples to account for the class imbalance
+        # in the fine-tuning dataset
+        if self.negative_sampling_probability:
+            truncated_lengths = []
+            for length, label in zip(self.lengths, self.labels):
+                if (
+                    np.random.random() > self.negative_sampling_probability
+                    and label == 0
+                ):
+                    continue
+                truncated_lengths.append(length)
+        else:
+            # We need to truncate the lengths due to the context window limit imposed by the model
+            truncated_lengths = [
+                min(self.max_position_embeddings, length + 2) for length in self.lengths
+            ]
+
+        # Calculate average sequence length
+        avg_seq_length = sum(truncated_lengths) // len(truncated_lengths)
+
+        # Estimate average number of sequences per batch
+        seqs_per_batch = self.max_tokens_per_batch // avg_seq_length
+
+        # Estimate total number of batches
+        if self.drop_last:
+            # If dropping the last incomplete batch
+            return len(truncated_lengths) // seqs_per_batch
+        else:
+            # If keeping the last incomplete batch, ensure at least 1 batch
+            return max(1, len(truncated_lengths) // seqs_per_batch)
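
For context, here is a minimal usage sketch (not part of this diff) showing how the new batch sampler plugs into a standard PyTorch `DataLoader` via the `batch_sampler` argument. The toy dataset, token budget, and collate function are illustrative assumptions:

```python
from torch.utils.data import DataLoader, Dataset

from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler


class _ToyDataset(Dataset):  # stand-in for a tokenized EHR dataset
    def __init__(self, seqs):
        self.seqs = seqs

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, i):
        return self.seqs[i]


seqs = [list(range(n)) for n in (120, 80, 400, 950, 60)]
dataset = _ToyDataset(seqs)
lengths = [len(s) for s in seqs]  # per-sample lengths known up front

batch_sampler = SamplePackingBatchSampler(
    lengths=lengths,
    max_tokens_per_batch=1024,    # packing budget per batch (illustrative)
    max_position_embeddings=512,  # context window; longer samples count as truncated
    seed=42,
)
# Each yielded batch is a list of dataset indices whose (truncated) lengths,
# plus two separator tokens each, fit within max_tokens_per_batch.
loader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=lambda xs: xs)
for batch in loader:
    print([len(x) for x in batch])
```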
cehrgpt/generation/generate_batch_hf_gpt_sequence.py CHANGED
@@ -93,9 +93,9 @@ def generate_single_batch(
         temperature=temperature,
         top_p=top_p,
         top_k=top_k,
-        bos_token_id=tokenizer.end_token_id,
-        eos_token_id=tokenizer.end_token_id,
-        pad_token_id=tokenizer.pad_token_id,
+        bos_token_id=model.generation_config.bos_token_id,
+        eos_token_id=model.generation_config.eos_token_id,
+        pad_token_id=model.generation_config.pad_token_id,
         do_sample=True,
         use_cache=True,
         return_dict_in_generate=True,
@@ -150,15 +150,11 @@ def main(args):
             attn_implementation=(
                 "flash_attention_2" if is_flash_attn_2_available() else "eager"
             ),
-            torch_dtype=(
-                torch.bfloat16
-                if is_flash_attn_2_available() and args.use_bfloat16
-                else torch.float32
-            ),
         )
         .eval()
         .to(device)
     )
+
     cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
     cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
     cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id
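
Together, these two hunks make the model's `GenerationConfig` the single source of truth for special token ids: `main` sets them once from the tokenizer, and `generate_single_batch` reads them back instead of re-deriving them per call. The removed `torch_dtype` block also means the model now loads in its default dtype regardless of `args.use_bfloat16`. A hedged sketch of how a caller could still opt into bfloat16 at load time, assuming the `CEHRGPT2LMHeadModel` class exported by `cehrgpt.models.hf_cehrgpt` (the class name is an assumption, not confirmed by this diff):

```python
import torch
from transformers.utils import is_flash_attn_2_available

from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel  # assumed class name

# Request bfloat16 explicitly when flash attention is available,
# since the script no longer does this on the caller's behalf.
model = CEHRGPT2LMHeadModel.from_pretrained(
    "path/to/checkpoint",  # placeholder path
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "eager",
    torch_dtype=torch.bfloat16 if is_flash_attn_2_available() else torch.float32,
)
```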
@@ -192,6 +188,7 @@ def main(args):
     LOG.info(f"Top P {args.top_p}")
     LOG.info(f"Top K {args.top_k}")
     LOG.info(f"Loading demographic_info at {args.demographic_data_path}")
+    LOG.info(f"MEDS format: {args.meds_format}")

     dataset = load_parquet_as_dataset(args.demographic_data_path)
     total_rows = len(dataset)
@@ -199,6 +196,7 @@ def main(args):
     num_of_batches = args.num_of_patients // args.batch_size + 1
     sequence_to_flush = []
     current_person_id = 1
+    prompt_size = 2 if args.meds_format else START_TOKEN_SIZE
     for i in range(num_of_batches):
         LOG.info(f"{datetime.datetime.now()}: Batch {i} started")

@@ -215,7 +213,7 @@ def main(args):
                 <= max_seq_allowed
             ):
                 random_prompts.append(
-                    cehrgpt_tokenizer.encode(row["concept_ids"][:START_TOKEN_SIZE])
+                    cehrgpt_tokenizer.encode(row["concept_ids"][:prompt_size])
                 )
             iter += 1
             if not random_prompts and iter > 10:
@@ -326,6 +324,11 @@ def create_arg_parser():
         dest="drop_long_sequences",
         action="store_true",
     )
+    base_arg_parser.add_argument(
+        "--meds_format",
+        dest="meds_format",
+        action="store_true",
+    )
     return base_arg_parser

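
Downstream, `--meds_format` shrinks the generation prompt from the first `START_TOKEN_SIZE` tokens to the first two. A small illustration (the token strings and the value of `START_TOKEN_SIZE` are assumptions for display, not taken from this diff):

```python
START_TOKEN_SIZE = 4  # assumed value of the constant imported by this script

# Hypothetical demographic-prefixed sequence: year, age, gender, race, then events
concept_ids = ["year:1980", "age:30-40", "Gender/8507", "Race/8527", "Visit/9202"]

for meds_format in (False, True):
    prompt_size = 2 if meds_format else START_TOKEN_SIZE
    print(meds_format, concept_ids[:prompt_size])
# False -> the full four-token demographic prompt
# True  -> only the first two tokens
```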
cehrgpt/generation/omop_converter_batch.py CHANGED
@@ -60,6 +60,24 @@ OOV_CONCEPT_MAP = {
 }


+def extract_gender_concept_id(gender_token: str) -> int:
+    if gender_token.startswith("Gender/"):
+        return int(gender_token[len("Gender/") :])
+    elif gender_token.isnumeric():
+        return int(gender_token)
+    else:
+        return 0
+
+
+def extract_race_concept_id(race_token: str) -> int:
+    if race_token.startswith("Race/"):
+        return int(race_token[len("Race/") :])
+    elif race_token.isnumeric():
+        return int(race_token)
+    else:
+        return 0
+
+
 def create_folder_if_not_exists(output_folder, table_name):
     if not os.path.isdir(Path(output_folder) / table_name):
         os.mkdir(Path(output_folder) / table_name)
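
The two new helpers accept either MEDS-style prefixed tokens or bare OMOP concept ids, and fall back to 0 for anything unparseable. A few illustrative calls (8507, 8532, and 8527 are standard OMOP gender and race concept ids):

```python
from cehrgpt.generation.omop_converter_batch import (
    extract_gender_concept_id,
    extract_race_concept_id,
)

assert extract_gender_concept_id("Gender/8507") == 8507  # MEDS-style MALE token
assert extract_gender_concept_id("8532") == 8532         # bare FEMALE concept id
assert extract_gender_concept_id("MALE") == 0            # unparseable -> 0
assert extract_race_concept_id("Race/8527") == 8527      # MEDS-style White token
assert extract_race_concept_id("UNKNOWN") == 0           # unparseable -> 0
```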
@@ -288,7 +306,13 @@ def gpt_to_omop_converter_batch(
         if int(birth_year) < 1900 or int(birth_year) > datetime.date.today().year:
             continue

-        p = Person(person_id, start_gender, birth_year, start_race)
+        p = Person(
+            person_id=person_id,
+            gender_concept_id=extract_gender_concept_id(start_gender),
+            year_of_birth=birth_year,
+            race_concept_id=extract_race_concept_id(start_race),
+        )
+
         append_to_dict(omop_export_dict, p, person_id)
         id_mappings_dict["person"][person_id] = person_id
         pt_seq_dict[person_id] = " ".join(concept_ids)
@@ -316,7 +340,12 @@ def gpt_to_omop_converter_batch(
                 id_mappings_dict["death"][person_id] = person_id
             else:
                 try:
-                    visit_concept_id = int(clinical_events[event_idx + 1])
+                    if clinical_events[event_idx + 1].startswith("Visit/"):
+                        visit_concept_id = int(
+                            clinical_events[event_idx + 1][len("Visit/") :]
+                        )
+                    else:
+                        visit_concept_id = int(clinical_events[event_idx + 1])
                     inpatient_visit_indicator = visit_concept_id in [
                         9201,
                         262,
@@ -349,6 +378,7 @@ def gpt_to_omop_converter_batch(
                         visit_occurrence_id
                     ] = person_id
                     visit_occurrence_id += 1
+
             elif event in ATT_TIME_TOKENS:
                 if event[0] == "D":
                     att_date_delta = int(event[1:])
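
The visit-parsing change means a MEDS-style `Visit/9201` token and a bare `9201` now resolve to the same concept id before the inpatient check. A small sketch restating that logic outside the converter (`parse_visit_concept_id` is a hypothetical helper for illustration, not part of this release):

```python
def parse_visit_concept_id(token: str) -> int:
    # Strip the optional MEDS-style "Visit/" prefix, then parse the concept id
    if token.startswith("Visit/"):
        return int(token[len("Visit/") :])
    return int(token)


for token in ("Visit/9201", "9201", "Visit/262"):
    visit_concept_id = parse_visit_concept_id(token)
    # 9201 (Inpatient Visit) and 262 (ER and Inpatient Visit) mark inpatient stays
    print(token, visit_concept_id, visit_concept_id in [9201, 262])
```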
cehrgpt/gpt_utils.py CHANGED
@@ -11,6 +11,7 @@ from cehrgpt.models.special_tokens import (
 )

 # Regular expression pattern to match inpatient attendance tokens
+MEDS_CODE_PATTERN = re.compile(r".*/.*")
 INPATIENT_ATT_PATTERN = re.compile(r"(?:VS-|i-)D(\d+)(?:-VE)?")
 DEMOGRAPHIC_PROMPT_SIZE = 4

@@ -194,8 +195,12 @@ def get_cehrgpt_output_folder(args, cehrgpt_tokenizer) -> str:
     return folder_name


-def is_clinical_event(token: str) -> bool:
-    return token.isnumeric()
+def is_clinical_event(token: str, meds: bool = False) -> bool:
+    if token.isnumeric():
+        return True
+    if meds:
+        return bool(MEDS_CODE_PATTERN.match(token))
+    return False


 def is_visit_start(token: str):
@@ -212,6 +217,18 @@ def is_visit_end(token: str) -> bool:
     return token in ["VE", "[VE]"]


+def is_inpatient_hour_token(token: str) -> bool:
+    return token.startswith("i-H")
+
+
+def extract_time_interval_in_hours(token: str) -> int:
+    try:
+        hour = int(token[3:])
+        return hour
+    except ValueError:
+        return 0
+
+
 def is_att_token(token: str):
     """
     Check if the token is an attention token.
@@ -251,6 +268,7 @@ def is_artificial_token(token: str) -> bool:
         return True
     if token == END_TOKEN:
         return True
+
     return False

cehrgpt/models/config.py CHANGED
@@ -121,6 +121,7 @@ class CEHRGPTConfig(PretrainedConfig):
         bos_token_id=50256,
         eos_token_id=50256,
         lab_token_ids=None,
+        ve_token_id=None,
         scale_attn_by_inverse_layer_idx=False,
         reorder_and_upcast_attn=False,
         exclude_position_ids=False,
@@ -128,19 +129,27 @@ class CEHRGPTConfig(PretrainedConfig):
         value_vocab_size=None,
         include_ttv_prediction=False,
         use_sub_time_tokenization=True,
+        include_motor_time_to_event=True,
+        motor_tte_vocab_size=None,
+        motor_time_to_event_weight=1.0,
+        motor_num_time_pieces=16,
         token_to_time_token_mapping: Dict[int, List] = None,
         use_pretrained_embeddings=False,
         n_pretrained_embeddings_layers=2,
         pretrained_embedding_dim=768,
         pretrained_token_ids: List[int] = None,
+        next_token_prediction_loss_weight=1.0,
         time_token_loss_weight=1.0,
         time_to_visit_loss_weight=1.0,
         causal_sfm=False,
         demographics_size=4,
         lab_token_penalty=False,
         lab_token_loss_weight=0.9,
+        value_prediction_loss_weight=1.0,
         entropy_penalty=False,
         entropy_penalty_alpha=0.01,
+        sample_packing_max_positions=None,
+        class_weights=None,
         **kwargs,
     ):
         if token_to_time_token_mapping is None:
@@ -150,6 +159,11 @@ class CEHRGPTConfig(PretrainedConfig):
         self.vocab_size = vocab_size
         self.time_token_vocab_size = time_token_vocab_size
         self.n_positions = n_positions
+        self.sample_packing_max_positions = (
+            sample_packing_max_positions
+            if sample_packing_max_positions
+            else n_positions
+        )
         self.n_embd = n_embd
         self.n_layer = n_layer
         self.n_head = n_head
@@ -178,11 +192,28 @@ class CEHRGPTConfig(PretrainedConfig):
         self.include_values = include_values
         self.value_vocab_size = value_vocab_size

+        self.next_token_prediction_loss_weight = next_token_prediction_loss_weight
         self.include_ttv_prediction = include_ttv_prediction
         self.use_sub_time_tokenization = use_sub_time_tokenization
         self._token_to_time_token_mapping = token_to_time_token_mapping
         self.time_token_loss_weight = time_token_loss_weight
         self.time_to_visit_loss_weight = time_to_visit_loss_weight
+
+        # MOTOR TTE configuration
+        self.motor_tte_vocab_size = motor_tte_vocab_size
+        self.include_motor_time_to_event = (
+            include_motor_time_to_event
+            and self.motor_tte_vocab_size
+            and self.motor_tte_vocab_size > 0
+        )
+        if self.include_motor_time_to_event and not ve_token_id:
+            raise RuntimeError(
+                f"ve_token_id must be provided when include_motor_time_to_event is True"
+            )
+        self.ve_token_id = ve_token_id
+        self.motor_time_to_event_weight = motor_time_to_event_weight
+        self.motor_num_time_pieces = motor_num_time_pieces
+
         self.causal_sfm = causal_sfm
         self.demographics_size = demographics_size
         self.use_pretrained_embeddings = use_pretrained_embeddings
@@ -195,6 +226,10 @@ class CEHRGPTConfig(PretrainedConfig):
         self.lab_token_loss_weight = lab_token_loss_weight
         self.entropy_penalty = entropy_penalty
         self.entropy_penalty_alpha = entropy_penalty_alpha
+        self.value_prediction_loss_weight = value_prediction_loss_weight
+
+        # Class weights for fine-tuning
+        self.class_weights = class_weights

         kwargs["tie_word_embeddings"] = not use_pretrained_embeddings
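
A hedged construction sketch for the new MOTOR time-to-event options: the flag only takes effect when a positive `motor_tte_vocab_size` is supplied, and `ve_token_id` must accompany it. All argument values below are illustrative, and the remaining constructor defaults are assumed to suffice:

```python
from cehrgpt.models.config import CEHRGPTConfig

config = CEHRGPTConfig(
    motor_tte_vocab_size=512,           # illustrative TTE vocabulary size
    ve_token_id=5,                      # illustrative visit-end token id; required here
    motor_time_to_event_weight=1.0,     # loss weight for the TTE objective
    motor_num_time_pieces=16,           # number of time pieces for the TTE head
    sample_packing_max_positions=8192,  # falls back to n_positions when unset
)
assert config.include_motor_time_to_event

# Omitting ve_token_id while MOTOR TTE is enabled raises RuntimeError:
try:
    CEHRGPTConfig(motor_tte_vocab_size=512)
except RuntimeError as e:
    print(e)
```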