cehrgpt 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -162,6 +162,22 @@ class CehrGptDataCollator:
162
162
  f"batch['input_ids']: {batch['input_ids']} "
163
163
  )
164
164
 
165
+ if "epoch_times" in examples[0]:
166
+ batch_epoch_times = [
167
+ self._try_reverse_tensor(
168
+ self._convert_to_tensor(example["epoch_times"])
169
+ )
170
+ for example in examples
171
+ ]
172
+ # Pad sequences to the max length in the batch
173
+ batch["epoch_times"] = self._try_reverse_tensor(
174
+ pad_sequence(
175
+ batch_epoch_times,
176
+ batch_first=True,
177
+ padding_value=0,
178
+ ).to(torch.float32)
179
+ )
180
+
165
181
  if "position_ids" in examples[0]:
166
182
  batch_position_ids = [
167
183
  self._try_reverse_tensor(
@@ -663,7 +679,9 @@ class CehrGptDataCollator:
663
679
 
664
680
  # Subtract one for the [END] token only when pretraining without sample packing
665
681
  new_max_length = (
666
- max_length_allowed if sample_packing else max_length_allowed - 1
682
+ max_length_allowed - 1
683
+ if not sample_packing and self.pretraining
684
+ else max_length_allowed
667
685
  )
668
686
 
669
687
  if self.include_ttv_prediction:
@@ -685,13 +703,20 @@ class CehrGptDataCollator:
685
703
 
686
704
  # Return the record directly if the actual sequence length is less than the max sequence
687
705
  if seq_length <= new_max_length:
688
- if not sample_packing:
706
+ if not sample_packing and self.pretraining:
689
707
  record["input_ids"] = torch.concat(
690
708
  [
691
709
  self._convert_to_tensor(record["input_ids"]),
692
710
  self._convert_to_tensor([eos_token]),
693
711
  ]
694
712
  )
713
+ if "epoch_times" in record:
714
+ record["epoch_times"] = torch.concat(
715
+ [
716
+ self._convert_to_tensor(record["epoch_times"]),
717
+ self._convert_to_tensor([record["epoch_times"][-1]]),
718
+ ]
719
+ )
695
720
  if self.include_values:
696
721
  record["value_indicators"] = torch.concat(
697
722
  [
@@ -727,6 +752,10 @@ class CehrGptDataCollator:
727
752
  record["input_ids"] = self._convert_to_tensor(
728
753
  record["input_ids"][start_index : end_index + 1]
729
754
  )
755
+ if "epoch_times" in record:
756
+ record["epoch_times"] = self._convert_to_tensor(
757
+ record["epoch_times"][start_index : end_index + 1]
758
+ )
730
759
  if self.include_values:
731
760
  record["value_indicators"] = self._convert_to_tensor(
732
761
  record["value_indicators"][start_index : end_index + 1]
@@ -760,6 +789,11 @@ class CehrGptDataCollator:
760
789
  if sample_packing and "position_ids" in record:
761
790
  record["position_ids"] = record["position_ids"][0:end_index]
762
791
 
792
+ if "epoch_times" in record:
793
+ record["epoch_times"] = self._convert_to_tensor(
794
+ record["epoch_times"][0:end_index]
795
+ )
796
+
763
797
  if self.include_values:
764
798
  record["value_indicators"] = self._convert_to_tensor(
765
799
  record["value_indicators"][0:end_index]
@@ -792,6 +826,17 @@ class CehrGptDataCollator:
792
826
  ),
793
827
  ]
794
828
  )
829
+ if "epoch_times" in record:
830
+ record["epoch_times"] = torch.concat(
831
+ [
832
+ torch.zeros(
833
+ [record["epoch_times"][0]], dtype=torch.float32
834
+ ),
835
+ self._convert_to_tensor(
836
+ record["epoch_times"][token_index:seq_length]
837
+ ),
838
+ ]
839
+ )
795
840
  if self.include_values:
796
841
  record["value_indicators"] = torch.concat(
797
842
  [
@@ -830,7 +875,7 @@ class CehrGptDataCollator:
830
875
  )
831
876
  break
832
877
  else:
833
- start_index = seq_length - new_max_length
878
+ start_index = max(seq_length - new_max_length, 0)
834
879
  end_index = seq_length
835
880
  for i in range(start_index, end_index):
836
881
  current_token = record["input_ids"][i]
@@ -842,6 +887,11 @@ class CehrGptDataCollator:
842
887
  ]
843
888
  if sample_packing and "position_ids" in record:
844
889
  record["position_ids"] = record["position_ids"][i:end_index]
890
+
891
+ if "epoch_times" in record:
892
+ record["epoch_times"] = self._convert_to_tensor(
893
+ record["epoch_times"][i:end_index]
894
+ )
845
895
  if self.include_values:
846
896
  record["value_indicators"] = record["value_indicators"][
847
897
  i:end_index
@@ -863,6 +913,10 @@ class CehrGptDataCollator:
863
913
  ]
864
914
  if sample_packing and "position_ids" in record:
865
915
  record["position_ids"] = record["position_ids"][-new_max_length:]
916
+ if "epoch_times" in record:
917
+ record["epoch_times"] = self._convert_to_tensor(
918
+ record["epoch_times"][-new_max_length:]
919
+ )
866
920
  if self.include_values:
867
921
  record["value_indicators"] = record["value_indicators"][
868
922
  -new_max_length:
@@ -873,36 +927,6 @@ class CehrGptDataCollator:
873
927
  -new_max_length:
874
928
  ]
875
929
 
876
- if not sample_packing:
877
- # Finally we add the end token to the end of the sequence
878
- record["input_ids"] = torch.concat(
879
- [
880
- self._convert_to_tensor(record["input_ids"]),
881
- self._convert_to_tensor([eos_token]),
882
- ]
883
- )
884
- if self.include_values:
885
- record["value_indicators"] = torch.concat(
886
- [
887
- self._convert_to_tensor(record["value_indicators"]),
888
- self._convert_to_tensor([False]),
889
- ]
890
- ).to(torch.bool)
891
- record["values"] = torch.concat(
892
- [
893
- self._convert_to_tensor(record["values"]),
894
- self._convert_to_tensor(
895
- [self.tokenizer.pad_value_token_id]
896
- ),
897
- ]
898
- )
899
- if self.include_ttv_prediction:
900
- record["time_to_visits"] = torch.concat(
901
- [
902
- record["time_to_visits"],
903
- self._convert_to_tensor([-100.0]),
904
- ]
905
- )
906
930
  return record
907
931
 
908
932
 
@@ -21,7 +21,6 @@ from cehrbert_data.const.artificial_tokens import (
21
21
  DISCHARGE_UNKNOWN_TOKEN,
22
22
  GENDER_UNKNOWN_TOKEN,
23
23
  RACE_UNKNOWN_TOKEN,
24
- VISIT_UNKNOWN_TOKEN,
25
24
  )
26
25
  from cehrbert_data.const.common import NA
27
26
  from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
@@ -47,10 +46,16 @@ CEHRGPT_COLUMNS = [
47
46
  ]
48
47
 
49
48
 
50
- def convert_date_to_posix_time(index_date: datetime.date) -> float:
51
- return datetime.datetime.combine(
52
- index_date, datetime.datetime.min.time()
53
- ).timestamp()
49
def convert_date_to_posix_time(index_date: Union[datetime.date, int, float]) -> float:
    """Convert a date, datetime, or numeric value to POSIX seconds.

    Args:
        index_date: A ``datetime.date`` (interpreted as midnight UTC of that
            day), a ``datetime.datetime`` (its wall-clock fields are
            interpreted as UTC), or a number that is already a POSIX time.

    Returns:
        The POSIX timestamp in seconds. Numeric inputs are returned unchanged.
    """
    # datetime.datetime is a subclass of datetime.date, so it has to be tested
    # first; otherwise the date branch would catch it and datetime.combine
    # would discard the time of day.
    if isinstance(index_date, datetime.datetime):
        return index_date.replace(tzinfo=datetime.timezone.utc).timestamp()
    if isinstance(index_date, datetime.date):
        return (
            datetime.datetime.combine(index_date, datetime.datetime.min.time())
            .replace(tzinfo=datetime.timezone.utc)
            .timestamp()
        )
    return index_date
54
59
 
55
60
 
56
61
  class DatasetMappingDecorator(DatasetMapping):
@@ -128,7 +133,9 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
128
133
  cehrgpt_record["concept_as_values"].append(concept_as_value)
129
134
  cehrgpt_record["units"].append(unit)
130
135
  cehrgpt_record["is_numeric_types"].append(is_numeric_type)
131
- cehrgpt_record["epoch_times"].append(time.timestamp())
136
+ cehrgpt_record["epoch_times"].append(
137
+ time.replace(tzinfo=datetime.timezone.utc).timestamp()
138
+ )
132
139
 
133
140
  def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
134
141
  cehrgpt_record = {
@@ -360,7 +367,9 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
360
367
  cehrgpt_record["num_of_visits"] = len(visits)
361
368
 
362
369
  if record.get("index_date", None) is not None:
363
- cehrgpt_record["index_date"] = record["index_date"]
370
+ cehrgpt_record["index_date"] = (
371
+ record["index_date"].replace(tzinfo=datetime.timezone.utc).timestamp()
372
+ )
364
373
  if record.get("label", None) is not None:
365
374
  cehrgpt_record["label"] = record["label"]
366
375
  if record.get("age_at_index", None) is not None:
@@ -529,9 +538,13 @@ class ExtractTokenizedSequenceDataMapping:
529
538
  prediction_start_end_times = [
530
539
  (
531
540
  self._calculate_prediction_start_time(
532
- prediction_time_label_map["index_date"].timestamp()
541
+ prediction_time_label_map["index_date"]
542
+ .replace(tzinfo=datetime.timezone.utc)
543
+ .timestamp()
533
544
  ),
534
- prediction_time_label_map["index_date"].timestamp(),
545
+ prediction_time_label_map["index_date"]
546
+ .replace(tzinfo=datetime.timezone.utc)
547
+ .timestamp(),
535
548
  prediction_time_label_map["label"],
536
549
  )
537
550
  for prediction_time_label_map in prediction_times
@@ -0,0 +1,314 @@
1
+ import datetime
2
+ import os
3
+ import random
4
+ import shutil
5
+ from pathlib import Path
6
+ from typing import Any, Dict
7
+
8
+ import numpy as np
9
+ import polars as pl
10
+ import torch
11
+ import torch.distributed as dist
12
+ from cehrbert.runners.runner_util import generate_prepared_ds_path
13
+ from datasets import load_from_disk
14
+ from meds import held_out_split, train_split, tuning_split
15
+ from torch.utils.data import DataLoader
16
+ from tqdm import tqdm
17
+ from transformers.trainer_utils import is_main_process
18
+ from transformers.utils import is_flash_attn_2_available, logging
19
+
20
+ from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_finetuning_dataset
21
+ from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
22
+ from cehrgpt.generation.generate_batch_hf_gpt_sequence import (
23
+ generate_single_batch,
24
+ normalize_value,
25
+ )
26
+ from cehrgpt.gpt_utils import (
27
+ extract_time_interval_in_days,
28
+ extract_time_interval_in_hours,
29
+ is_att_token,
30
+ is_inpatient_hour_token,
31
+ is_visit_end,
32
+ is_visit_start,
33
+ )
34
+ from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
35
+ from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
36
+ from cehrgpt.runners.data_utils import (
37
+ extract_cohort_sequences,
38
+ prepare_finetune_dataset,
39
+ )
40
+ from cehrgpt.runners.gpt_runner_util import parse_runner_args
41
+ from cehrgpt.runners.hf_cehrgpt_pretrain_runner import tokenizer_exists
42
+
43
+ LOG = logging.get_logger("transformers")
44
+
45
+
46
def map_data_split_name(split: str) -> str:
    """Map a dataset split name to the matching ``meds`` split constant.

    Args:
        split: One of ``"train"``, ``"validation"`` or ``"test"``.

    Returns:
        ``train_split``, ``tuning_split`` or ``held_out_split`` respectively.

    Raises:
        ValueError: If ``split`` is not one of the three supported names.
    """
    # Reject unsupported names up front so the happy path stays flat.
    if split not in ("train", "validation", "test"):
        raise ValueError(f"Unknown split: {split}")
    if split == "train":
        return train_split
    if split == "validation":
        return tuning_split
    return held_out_split
54
+
55
+
56
def seed_all(seed: int = 42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) with ``seed``.

    Also exports ``PYTHONHASHSEED`` with the same value into ``os.environ``.

    Args:
        seed: The seed value applied to every generator.
    """
    seeders = (
        random.seed,  # Python's random module
        np.random.seed,  # NumPy global generator
        torch.manual_seed,  # PyTorch CPU generator
        torch.cuda.manual_seed,  # current CUDA device
        torch.cuda.manual_seed_all,  # every CUDA device
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Exported as a string because environment values must be text.
    os.environ["PYTHONHASHSEED"] = str(seed)
66
+
67
+
68
def _posix_to_naive_utc(posix_time: float) -> datetime.datetime:
    """Return the naive datetime holding the UTC wall-clock time of ``posix_time``."""
    # Equivalent to the deprecated datetime.utcfromtimestamp(): build an aware
    # UTC datetime, then strip the tzinfo.
    return datetime.datetime.fromtimestamp(
        posix_time, tz=datetime.timezone.utc
    ).replace(tzinfo=None)


def generate_trajectories_per_batch(
    batch: Dict[str, Any],
    cehrgpt_tokenizer: CehrGptTokenizer,
    cehrgpt_model: CEHRGPT2LMHeadModel,
    device,
    data_output_path: Path,
    max_length: int,
):
    """Generate one continuation per prompt in ``batch`` and write them to parquet.

    The prompts are passed to ``generate_single_batch`` with ``top_p=1.0`` and
    ``top_k`` equal to the tokenizer's vocabulary size. Each token past the
    prompt length is then walked to rebuild a timestamp: time-interval tokens
    advance a cursor that starts at the prompt's last epoch time; pad, end and
    visit start/end tokens are dropped; every other token becomes one output
    row stamped with the current cursor.

    Args:
        batch: Collated batch. The keys ``person_id``, ``index_date``,
            ``epoch_times``, ``input_ids``, ``value_indicators`` and ``values``
            are read.
        cehrgpt_tokenizer: Tokenizer used for generation and value decoding.
        cehrgpt_model: Model used for generation.
        device: Device forwarded to ``generate_single_batch``.
        data_output_path: Destination parquet file.
        max_length: Forwarded to ``generate_single_batch`` as ``max_length``.
    """
    # reshape(-1) instead of squeeze(): squeeze() turns a batch of one into a
    # 0-d tensor whose tolist() is a bare scalar, which cannot be indexed per
    # sample further down.
    subject_ids = batch["person_id"].reshape(-1).detach().cpu().tolist()
    prediction_times = batch["index_date"].reshape(-1).detach().cpu().tolist()
    batched_epoch_times = batch["epoch_times"].detach().cpu().tolist()
    batched_input_ids = batch["input_ids"]
    batched_value_indicators = batch["value_indicators"]
    batched_values = batch["values"]
    # Make sure the batch does not exceed batch_size
    batch_sequences = generate_single_batch(
        cehrgpt_model,
        cehrgpt_tokenizer,
        batched_input_ids,
        values=batched_values,
        value_indicators=batched_value_indicators,
        max_length=max_length,
        top_p=1.0,
        top_k=cehrgpt_tokenizer.vocab_size,
        device=device,
    )
    # Clear the cache
    torch.cuda.empty_cache()

    skipped_tokens = [cehrgpt_tokenizer.pad_token, cehrgpt_tokenizer.end_token]
    trajectories = []
    for sample_i, (concept_ids, _value_indicators, values) in enumerate(
        zip(
            batch_sequences["sequences"],
            batch_sequences["value_indicators"],
            batch_sequences["values"],
        )
    ):
        (
            concept_ids,
            is_numeric_types,
            number_as_values,
            concept_as_values,
            units,
        ) = normalize_value(concept_ids, values, cehrgpt_tokenizer)

        epoch_times = batched_epoch_times[sample_i]
        input_length = len(epoch_times)
        # Getting the last observed event time from the token before the prediction time
        window_last_observed = epoch_times[-1]
        # The generated timeline continues from the last observed time.
        current_cursor = window_last_observed
        generated_epoch_times = []
        valid_indices = []

        # NOTE(review): assumes the returned sequence starts with the prompt, so
        # generated tokens begin at index input_length -- confirm against
        # generate_single_batch.
        for i in range(input_length, len(concept_ids)):
            concept_id = concept_ids[i]
            # We use the left padding strategy in the data collator
            if concept_id in skipped_tokens:
                continue
            # We need to construct the time stamp
            if is_att_token(concept_id):
                current_cursor += extract_time_interval_in_days(concept_id) * 24 * 3600
            elif is_inpatient_hour_token(concept_id):
                current_cursor += extract_time_interval_in_hours(concept_id) * 3600
            elif is_visit_start(concept_id) or is_visit_end(concept_id):
                continue
            else:
                # valid_indices and generated_epoch_times grow together so the
                # per-sample lists stay aligned for the explode() below.
                valid_indices.append(i)
                generated_epoch_times.append(_posix_to_naive_utc(current_cursor))

        trajectories.append(
            {
                "subject_id": subject_ids[sample_i],
                "prediction_time": _posix_to_naive_utc(prediction_times[sample_i]),
                "window_last_observed_time": _posix_to_naive_utc(
                    window_last_observed
                ),
                "times": generated_epoch_times,
                "concept_ids": np.asarray(concept_ids)[valid_indices].tolist(),
                "numeric_values": np.asarray(number_as_values)[valid_indices].tolist(),
                "text_value": np.asarray(concept_as_values)[valid_indices].tolist(),
                "units": np.asarray(units)[valid_indices].tolist(),
            }
        )

    # One row per generated event, with the list columns renamed to singular names.
    trajectories = (
        pl.DataFrame(trajectories)
        .explode(["times", "concept_ids", "numeric_values", "text_value", "units"])
        .rename(
            {
                "times": "time",
                "concept_ids": "code",
                "numeric_values": "numeric_value",
                "units": "unit",
            }
        )
        .select(
            "subject_id",
            "prediction_time",
            "window_last_observed_time",
            "time",
            "code",
            "numeric_value",
            "text_value",
            "unit",
        )
    )
    trajectories.write_parquet(data_output_path)
181
+
182
+
183
def main():
    """Generate CEHR-GPT trajectories for every split of a cohort dataset.

    Loads the tokenizer and model, then obtains the tokenized cohort dataset:
    from the prepared-dataset directory when it already has content, otherwise
    by building it on the main process. For each trajectory index and each
    split, generation runs batch by batch and writes one parquet file per
    batch under ``<output_dir>/<meds split>/<trajectory index>/``. Batches
    whose parquet file already exists are skipped.
    """
    cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
        model_args.tokenizer_name_or_path
    )
    cehrgpt_model = (
        CEHRGPT2LMHeadModel.from_pretrained(
            model_args.model_name_or_path,
            attn_implementation=(
                "flash_attention_2" if is_flash_attn_2_available() else "eager"
            ),
        )
        .eval()
        .to(device)
    )
    cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
    cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
    cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id

    # exist_ok avoids a check-then-create race when several processes start together.
    os.makedirs(training_args.output_dir, exist_ok=True)

    prepared_ds_path = generate_prepared_ds_path(
        data_args, model_args, data_folder=data_args.cohort_folder
    )

    processed_dataset = None
    if any(prepared_ds_path.glob("*")):
        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
        processed_dataset = load_from_disk(str(prepared_ds_path))
        LOG.info("Prepared dataset loaded from disk...")
        if cehrgpt_args.expand_tokenizer:
            if tokenizer_exists(training_args.output_dir):
                cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
                    training_args.output_dir
                )
            else:
                LOG.warning(
                    f"CehrGptTokenizer must exist in {training_args.output_dir} "
                    f"when the dataset has been processed and expand_tokenizer is set to True. "
                    f"Please delete the processed dataset at {prepared_ds_path}."
                )
                # The cached dataset cannot be used without its tokenizer:
                # discard it so it is rebuilt below.
                processed_dataset = None
                shutil.rmtree(prepared_ds_path)

    if processed_dataset is None and is_main_process(training_args.local_rank):
        # If the full dataset has been tokenized, we don't want to tokenize the cohort containing
        # the subset of the data. We should slice out the portion of the tokenized sequences for each sample
        if cehrgpt_args.tokenized_full_dataset_path is not None:
            processed_dataset = extract_cohort_sequences(data_args, cehrgpt_args)
        else:
            # Organize them into a single DatasetDict
            final_splits = prepare_finetune_dataset(
                data_args, training_args, cehrgpt_args
            )
            # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
            if not data_args.streaming:
                all_columns = final_splits["train"].column_names
                if "visit_concept_ids" in all_columns:
                    final_splits = final_splits.remove_columns(["visit_concept_ids"])

            processed_dataset = create_cehrgpt_finetuning_dataset(
                dataset=final_splits,
                cehrgpt_tokenizer=cehrgpt_tokenizer,
                data_args=data_args,
            )
        if not data_args.streaming:
            processed_dataset.save_to_disk(prepared_ds_path)
            processed_dataset.cleanup_cache_files()

    # After main-process-only operations, synchronize all processes to ensure consistency
    if dist.is_available() and dist.is_initialized():
        dist.barrier()

    # Ranks other than the main process skipped the build above and would hit
    # None below; pick up what the main process saved to disk.
    if processed_dataset is None:
        processed_dataset = load_from_disk(str(prepared_ds_path))

    # We suppress the additional learning objectives in fine-tuning
    data_collator = CehrGptDataCollator(
        tokenizer=cehrgpt_tokenizer,
        max_length=cehrgpt_args.generation_input_length,
        include_values=cehrgpt_model.config.include_values,
        pretraining=False,
        include_ttv_prediction=False,
        use_sub_time_tokenization=False,
        include_demographics=False,
        add_linear_prob_token=False,
    )

    LOG.info(
        "Generating %s trajectories per sample",
        cehrgpt_args.num_of_trajectories_per_sample,
    )
    # Total sequence budget handed to generation: prompt length plus new tokens.
    generation_max_length = (
        cehrgpt_args.generation_max_new_tokens + cehrgpt_args.generation_input_length
    )
    for sample_i in range(cehrgpt_args.num_of_trajectories_per_sample):
        for split, dataset in processed_dataset.items():
            meds_split = map_data_split_name(split)
            dataloader = DataLoader(
                dataset=dataset,
                batch_size=training_args.per_device_eval_batch_size,
                num_workers=training_args.dataloader_num_workers,
                collate_fn=data_collator,
                pin_memory=training_args.dataloader_pin_memory,
            )
            sample_output_dir = (
                Path(training_args.output_dir) / meds_split / f"{sample_i}"
            )
            sample_output_dir.mkdir(exist_ok=True, parents=True)
            for batch_i, batch in tqdm(
                enumerate(dataloader),
                desc=f"Generating Trajectories for split {meds_split} with trajectory {sample_i + 1}",
            ):
                output_parquet_file = sample_output_dir / f"{batch_i}.parquet"
                if output_parquet_file.exists():
                    LOG.info("%s already exists, skip...", output_parquet_file)
                    continue

                generate_trajectories_per_batch(
                    batch,
                    cehrgpt_tokenizer,
                    cehrgpt_model,
                    device,
                    output_parquet_file,
                    generation_max_length,
                )
309
+
310
+
311
if __name__ == "__main__":
    # Seed every random number generator before main() runs
    seed_all(42)
    main()
@@ -74,7 +74,10 @@ def generate_single_batch(
74
74
  model: CEHRGPT2LMHeadModel,
75
75
  tokenizer: CehrGptTokenizer,
76
76
  prompts: List[List[int]],
77
- max_new_tokens=512,
77
+ max_length: int,
78
+ values: Optional[torch.Tensor] = None,
79
+ value_indicators: Optional[torch.Tensor] = None,
80
+ max_new_tokens: Optional[int] = None,
78
81
  mini_num_of_concepts=1,
79
82
  top_p=0.95,
80
83
  top_k=50,
@@ -88,7 +91,8 @@ def generate_single_batch(
88
91
  with torch.no_grad():
89
92
  generation_config = GenerationConfig(
90
93
  repetition_penalty=repetition_penalty,
91
- max_length=max_new_tokens,
94
+ max_new_tokens=max_new_tokens,
95
+ max_length=max_length,
92
96
  min_length=mini_num_of_concepts,
93
97
  temperature=temperature,
94
98
  top_p=top_p,
@@ -107,9 +111,17 @@ def generate_single_batch(
107
111
  num_beam_groups=num_beam_groups,
108
112
  epsilon_cutoff=epsilon_cutoff,
109
113
  )
114
+
110
115
  batched_prompts = torch.tensor(prompts).to(device)
116
+ if values is not None:
117
+ values = values.to(device)
118
+ if value_indicators is not None:
119
+ value_indicators = value_indicators.to(device)
120
+
111
121
  results = model.generate(
112
122
  inputs=batched_prompts,
123
+ values=values,
124
+ value_indicators=value_indicators,
113
125
  generation_config=generation_config,
114
126
  lab_token_ids=tokenizer.lab_token_ids,
115
127
  )
@@ -226,7 +238,7 @@ def main(args):
226
238
  cehrgpt_model,
227
239
  cehrgpt_tokenizer,
228
240
  random_prompts[: args.batch_size],
229
- max_new_tokens=args.context_window,
241
+ max_length=args.context_window,
230
242
  mini_num_of_concepts=args.min_num_of_concepts,
231
243
  top_p=args.top_p,
232
244
  top_k=args.top_k,
@@ -102,7 +102,9 @@ def is_sample_pack(attention_mask: torch.Tensor) -> bool:
102
102
  attention_mask = attention_mask.flip(dims=[1])
103
103
 
104
104
  nonzero_counts = attention_mask.sum(dim=1)
105
- max_token_positions = torch.argmax(attention_mask.flip(dims=[1]), dim=1)
105
+ max_token_positions = torch.argmax(
106
+ attention_mask.to(torch.int32).flip(dims=[1]), dim=1
107
+ )
106
108
  max_indices = attention_mask.shape[1] - 1 - max_token_positions
107
109
  return torch.any(nonzero_counts < (max_indices + 1)).item()
108
110
 
@@ -1848,6 +1850,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
1848
1850
 
1849
1851
  # keep track of which sequences are already finished
1850
1852
  batch_size, cur_len = input_ids.shape
1853
+ model_kwargs["attention_mask"] = input_ids != pad_token_id
1851
1854
  if "inputs_embeds" in model_kwargs:
1852
1855
  cur_len = model_kwargs["inputs_embeds"].shape[1]
1853
1856
  this_peer_finished = False
@@ -1866,11 +1869,19 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
1866
1869
  [] if self.config.lab_token_ids is None else self.config.lab_token_ids,
1867
1870
  dtype=torch.int32,
1868
1871
  )
1869
- value_indicators = torch.zeros_like(input_ids).to(torch.bool)
1870
- values = torch.zeros_like(
1871
- input_ids,
1872
- dtype=torch.int32,
1873
- )
1872
+
1873
+ if model_kwargs.get("value_indicators", None) is not None:
1874
+ value_indicators = model_kwargs.get("value_indicators")
1875
+ else:
1876
+ value_indicators = torch.zeros_like(input_ids).to(torch.bool)
1877
+
1878
+ if model_kwargs.get("values", None) is not None:
1879
+ values = model_kwargs.get("values")
1880
+ else:
1881
+ values = torch.zeros_like(
1882
+ input_ids,
1883
+ dtype=torch.int32,
1884
+ )
1874
1885
  # Generate initial random_vectors
1875
1886
  if self.cehrgpt.config.causal_sfm:
1876
1887
  model_kwargs["random_vectors"] = torch.rand(
@@ -47,7 +47,7 @@ def prepare_finetune_dataset(
47
47
  data_args: DataTrainingArguments,
48
48
  training_args: TrainingArguments,
49
49
  cehrgpt_args: CehrGPTArguments,
50
- cache_file_collector: CacheFileCollector,
50
+ cache_file_collector: Optional[CacheFileCollector] = None,
51
51
  ) -> DatasetDict:
52
52
  # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
53
53
  if data_args.is_data_in_meds:
@@ -91,8 +91,9 @@ def prepare_finetune_dataset(
91
91
  "Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
92
92
  stats,
93
93
  )
94
- # Clean up the files created from the data generator
95
- cache_file_collector.remove_cache_files()
94
+ if cache_file_collector:
95
+ # Clean up the files created from the data generator
96
+ cache_file_collector.remove_cache_files()
96
97
  dataset = load_from_disk(str(meds_extension_path))
97
98
 
98
99
  train_set = dataset["train"]
@@ -271,7 +272,7 @@ def create_dataset_splits(data_args: DataTrainingArguments, seed: int):
271
272
  def extract_cohort_sequences(
272
273
  data_args: DataTrainingArguments,
273
274
  cehrgpt_args: CehrGPTArguments,
274
- cache_file_collector: CacheFileCollector,
275
+ cache_file_collector: Optional[CacheFileCollector] = None,
275
276
  ) -> DatasetDict:
276
277
  """
277
278
  Extracts and processes cohort-specific tokenized sequences from a pre-tokenized dataset.
@@ -309,9 +310,18 @@ def extract_cohort_sequences(
309
310
  mapping={
310
311
  "prediction_time": "index_date",
311
312
  "subject_id": "person_id",
313
+ "boolean_value": "label",
312
314
  }
313
315
  )
314
316
  all_person_ids = cohort["person_id"].unique().to_list()
317
+ # In case the label column does not exist, we add a fake column to the dataframe so subsequent process can work
318
+ if "label" not in cohort.columns:
319
+ cohort = cohort.with_columns(
320
+ pl.Series(
321
+ name="label", values=np.zeros_like(cohort["person_id"].to_numpy())
322
+ )
323
+ )
324
+
315
325
  # data_args.observation_window
316
326
  tokenized_dataset = load_from_disk(cehrgpt_args.tokenized_full_dataset_path)
317
327
  filtered_tokenized_dataset = tokenized_dataset.filter(
@@ -353,6 +363,7 @@ def extract_cohort_sequences(
353
363
  num_proc=data_args.preprocessing_num_workers,
354
364
  remove_columns=filtered_tokenized_dataset["train"].column_names,
355
365
  )
356
- cache_file_collector.add_cache_files(filtered_tokenized_dataset)
357
- cache_file_collector.add_cache_files(processed_dataset)
366
+ if cache_file_collector:
367
+ cache_file_collector.add_cache_files(filtered_tokenized_dataset)
368
+ cache_file_collector.add_cache_files(processed_dataset)
358
369
  return processed_dataset
@@ -580,7 +580,15 @@ def do_predict(
580
580
  index_dates = batch.pop("index_date").numpy().squeeze()
581
581
  if index_dates.ndim == 0:
582
582
  index_dates = np.asarray([index_dates])
583
- index_dates = list(map(datetime.fromtimestamp, index_dates.tolist()))
583
+
584
+ index_dates = list(
585
+ map(
586
+ lambda posix_time: datetime.utcfromtimestamp(posix_time).replace(
587
+ tzinfo=None
588
+ ),
589
+ index_dates.tolist(),
590
+ )
591
+ )
584
592
 
585
593
  batch = {k: v.to(device) for k, v in batch.items()}
586
594
  # Forward pass
@@ -229,3 +229,15 @@ class CehrGPTArguments:
229
229
  "help": "The probability of negative samples will be included in the training data"
230
230
  },
231
231
  )
232
+ num_of_trajectories_per_sample: Optional[int] = dataclasses.field(
233
+ default=1,
234
+ metadata={"help": "The number of trajectories per sample"},
235
+ )
236
+ generation_input_length: Optional[int] = dataclasses.field(
237
+ default=1024,
238
+ metadata={"help": "The length of the input sequence"},
239
+ )
240
+ generation_max_new_tokens: Optional[int] = dataclasses.field(
241
+ default=1024,
242
+ metadata={"help": "The maximum number of tokens in the generation sequence"},
243
+ )
@@ -1,15 +1,14 @@
1
+ import datetime
1
2
  import glob
2
3
  import os
3
4
  import shutil
4
5
  import uuid
5
- from datetime import datetime
6
6
  from functools import partial
7
7
  from pathlib import Path
8
8
  from typing import Optional, Union
9
9
 
10
10
  import numpy as np
11
11
  import pandas as pd
12
- import polars as pl
13
12
  import torch
14
13
  import torch.distributed as dist
15
14
  from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
@@ -25,7 +24,6 @@ from cehrgpt.data.hf_cehrgpt_dataset_collator import (
25
24
  CehrGptDataCollator,
26
25
  SamplePackingCehrGptDataCollator,
27
26
  )
28
- from cehrgpt.data.hf_cehrgpt_dataset_mapping import ExtractTokenizedSequenceDataMapping
29
27
  from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
30
28
  from cehrgpt.models.hf_cehrgpt import (
31
29
  CEHRGPT2Model,
@@ -159,24 +157,7 @@ def main():
159
157
  final_splits = prepare_finetune_dataset(
160
158
  data_args, training_args, cehrgpt_args, cache_file_collector
161
159
  )
162
- if cehrgpt_args.expand_tokenizer:
163
- new_tokenizer_path = os.path.expanduser(training_args.output_dir)
164
- if tokenizer_exists(new_tokenizer_path):
165
- cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
166
- new_tokenizer_path
167
- )
168
- else:
169
- cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
170
- cehrgpt_tokenizer=cehrgpt_tokenizer,
171
- dataset=final_splits["train"],
172
- data_args=data_args,
173
- concept_name_mapping={},
174
- )
175
- cehrgpt_tokenizer.save_pretrained(
176
- os.path.expanduser(training_args.output_dir)
177
- )
178
-
179
- # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
160
+ # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
180
161
  if not data_args.streaming:
181
162
  all_columns = final_splits["train"].column_names
182
163
  if "visit_concept_ids" in all_columns:
@@ -238,10 +219,6 @@ def main():
238
219
  len(processed_dataset["test"]),
239
220
  )
240
221
 
241
- LOG.info(f"cehrgpt_model.config.vocab_size: {cehrgpt_model.config.vocab_size}")
242
- LOG.info(f"cehrgpt_tokenizer.vocab_size: {cehrgpt_tokenizer.vocab_size}")
243
- if cehrgpt_model.config.vocab_size < cehrgpt_tokenizer.vocab_size:
244
- cehrgpt_model.resize_token_embeddings(cehrgpt_tokenizer.vocab_size)
245
222
  if (
246
223
  cehrgpt_model.config.max_position_embeddings
247
224
  < model_args.max_position_embeddings
@@ -339,10 +316,12 @@ def main():
339
316
  for data_dir in [data_args.data_folder, data_args.test_data_folder]
340
317
  ]
341
318
  )
342
- # This is a pre-caution in case the index_date is not a datetime type
343
- demographics_df["index_date"] = pd.to_datetime(
344
- demographics_df["index_date"]
345
- ).dt.date
319
+
320
+ demographics_df["index_date"] = (
321
+ demographics_df["index_date"].dt.tz_localize("UTC")
322
+ - datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
323
+ ).dt.total_seconds()
324
+
346
325
  demographics_dict = {
347
326
  (row["person_id"], row["index_date"]): {
348
327
  "gender_concept_id": row["gender_concept_id"],
@@ -379,9 +358,16 @@ def main():
379
358
  prediction_time_posix = batch.pop("index_date").numpy().squeeze()
380
359
  if prediction_time_posix.ndim == 0:
381
360
  prediction_time_posix = np.asarray([prediction_time_posix])
361
+
382
362
  prediction_time = list(
383
- map(datetime.fromtimestamp, prediction_time_posix)
363
+ map(
364
+ lambda posix_time: datetime.datetime.utcfromtimestamp(
365
+ posix_time
366
+ ).replace(tzinfo=None),
367
+ prediction_time_posix,
368
+ )
384
369
  )
370
+
385
371
  labels = (
386
372
  batch.pop("classifier_label")
387
373
  .float()
@@ -393,6 +379,10 @@ def main():
393
379
  if labels.ndim == 0:
394
380
  labels = np.asarray([labels])
395
381
 
382
+ # Right now the model does not support this column, we need to pop it
383
+ if "epoch_times" in batch:
384
+ batch.pop("epoch_times")
385
+
396
386
  batch = {k: v.to(device) for k, v in batch.items()}
397
387
  # Forward pass
398
388
  cehrgpt_output = cehrgpt_model(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cehrgpt
3
- Version: 0.1.1
3
+ Version: 0.1.2
4
4
  Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
5
5
  Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
6
6
  License: MIT License
@@ -105,6 +105,100 @@ sh scripts/omop_pipeline.sh \
105
105
  $OMOP_VOCAB_DIR
106
106
  ```
107
107
 
108
+ # MEDS Support
109
+
110
+ This section demonstrates how to pretrain CEHR-GPT using MIMIC-IV data in the MEDS (Medical Event Data Standard) format.
111
+
112
+ ## Prerequisites
113
+
114
+ Set up the required environment variables before beginning:
115
+
116
+ ```bash
117
+ export CEHR_GPT_MODEL_DIR="" # Path to CEHR-GPT model directory
118
+ export MEDS_DIR="" # Path to MEDS data directory
119
+ export MEDS_READER_DIR="" # Path to MEDS reader output directory
120
+ ```
121
+
122
+ ## Step 1: Create MIMIC MEDS Data
123
+
124
+ Transform your MIMIC files into MEDS format by following the instructions in the [MEDS_transforms](https://github.com/mmcdermott/MEDS_transforms/) repository.
125
+
126
+ ## Step 2: Create the MEDS Reader
127
+
128
+ Convert the MEDS data for use with CEHR-GPT:
129
+
130
+ ```bash
131
+ meds_reader_convert $MEDS_DIR $MEDS_READER_DIR --num_threads 10
132
+ ```
133
+
134
+ ## Step 3: Pretrain CEHR-GPT
135
+
136
+ Run the pretraining process using the prepared MEDS data:
137
+
138
+ ```bash
139
+ python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
140
+ --model_name_or_path $CEHR_GPT_MODEL_DIR \
141
+ --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
142
+ --output_dir $CEHR_GPT_MODEL_DIR \
143
+ --data_folder $MEDS_READER_DIR \
144
+ --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
145
+ --do_train true --seed 42 \
146
+ --dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
147
+ --hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 8192 \
148
+ --evaluation_strategy epoch --save_strategy epoch \
149
+ --sample_packing --max_tokens_per_batch 16384 \
150
+ --warmup_steps 500 --weight_decay 0.01 \
151
+ --num_train_epochs 50 --learning_rate 0.0002 \
152
+ --use_early_stopping --early_stopping_threshold 0.001 \
153
+ --is_data_in_meds --inpatient_att_function_type day \
154
+ --att_function_type day --include_inpatient_hour_token \
155
+ --include_auxiliary_token --include_demographic_prompt \
156
+ --meds_to_cehrbert_conversion_type "MedsToBertMimic4"
157
+ ```
158
+
159
+ ## Step 4: Generate MEDS Trajectories
160
+
161
+ ### Environment Setup for Trajectory Generation
162
+
163
+ Configure the additional environment variables required for trajectory generation. Each task label consists of `subject_id`, `prediction_time`, and optionally `boolean_value`:
164
+
165
+ ```bash
166
+ # MEDS_LABEL_COHORT_DIR must contain a set of parquet files
167
+ export MEDS_LABEL_COHORT_DIR="" # Path to cohort labels directory
168
+ export MEDS_TRAJECTORY_DIR="" # Path for trajectory output
169
+ ```
170
+
171
+ ### Generate Trajectories
172
+
173
+ Create synthetic patient trajectories using the trained model:
174
+
175
+ > **Important:** The total sequence length (`generation_input_length` + `generation_max_new_tokens`) cannot exceed the `max_position_embeddings` value (8192) defined during pretraining.
176
+
177
+ ```bash
178
+ python -u -m cehrgpt.generation.cehrgpt_conditional_generation \
179
+ --cohort_folder $MEDS_LABEL_COHORT_DIR \
180
+ --data_folder $MEDS_READER_DIR \
181
+ --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
182
+ --model_name_or_path $CEHR_GPT_MODEL_DIR \
183
+ --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
184
+ --output_dir $MEDS_TRAJECTORY_DIR \
185
+ --per_device_eval_batch_size 16 \
186
+ --num_of_trajectories_per_sample 2 \
187
+ --generation_input_length 4096 \
188
+ --generation_max_new_tokens 4096 \
189
+ --is_data_in_meds \
190
+ --att_function_type day --inpatient_att_function_type day \
191
+ --meds_to_cehrbert_conversion_type MedsToBertMimic4 \
192
+ --include_auxiliary_token --include_demographic_prompt \
193
+ --include_inpatient_hour_token
194
+ ```
195
+
196
+ ### Parameters Explanation
197
+
198
+ - `generation_input_length`: Controls the length of the input context for generation
199
+ - `generation_max_new_tokens`: Maximum number of new tokens to generate
200
+ - `num_of_trajectories_per_sample`: Number of trajectories to generate per patient sample
201
+
108
202
  ## Citation
109
203
  ```
110
204
  @article{cehrgpt2024,
@@ -13,17 +13,18 @@ cehrgpt/analysis/privacy/reid_inference.py,sha256=Pypd3QJXQNY8VljpnIEa5zeAbTZHMj
13
13
  cehrgpt/analysis/privacy/utils.py,sha256=CRA4H9mPLBjMQGKzZ_x_3ro3tMap-NjsMDVqSOjHSVQ,8226
14
14
  cehrgpt/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
15
  cehrgpt/data/hf_cehrgpt_dataset.py,sha256=hwJlGW7XiJIr6cXtmwvReQf9yLZJPD-dvJGvRg5ERqU,3755
16
- cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=ACMXiaYnR3bKD5dRleL0_siEvhL-2HAFcy5eBgvxnH4,44412
17
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=KU0WMjc2vT1zBAl7JJkOc8dgGxsL1uFDy4dDrv-RkII,25668
16
+ cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=juM5HeZScgj8w15Bl1qC83Swld4gY6avh0QkSWLqITA,45465
17
+ cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=_QDX9NXfmQ_S3kOf3yndb3AhoEeFiSzAOv836uYW0AY,26230
18
18
  cehrgpt/data/sample_packing_sampler.py,sha256=vovGMtmhG70DRkSCeiaDEJ_rjKZ38y-YLaI1kkhFEkI,6747
19
19
  cehrgpt/generation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
+ cehrgpt/generation/cehrgpt_conditional_generation.py,sha256=AM76yaPyw1B-bcdei24HO0uspGZWHGKWpYpHywotTIQ,11972
20
21
  cehrgpt/generation/chatgpt_generation.py,sha256=SrnLwHLdNtnAOEg36gNjqfoT9yd12iyPgpZffL2AFJo,4428
21
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256=uSEh8aMmPD61nGewIaPSkIqm-2AxDjCBiu4cBfxHxU4,11503
22
+ cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256=P8al4-zqymqEkCHCCu2sqz_45akcKF2o_AtQIjJdVmQ,11919
22
23
  cehrgpt/generation/omop_converter_batch.py,sha256=LUmCD-t_6ZP1YfNDZCqYewl-XIIaIgRZ_dAxuR_VdCQ,26275
23
24
  cehrgpt/generation/omop_entity.py,sha256=Q5Sr0AlyuPAm1FRPfnJO13q-u1fqRgYVHXruZ9g4xNE,19400
24
25
  cehrgpt/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
26
  cehrgpt/models/config.py,sha256=nOAKgH5420HLCcy7n1hE7MbqR861Iq4DTutKoAd25tg,11090
26
- cehrgpt/models/hf_cehrgpt.py,sha256=77CAkdMPgxD4xSpFU7gYGzRn6_Iv-4q7FnHpnZGsKxw,92450
27
+ cehrgpt/models/hf_cehrgpt.py,sha256=3P7bOLDr7NMSedGszhmlJJN4Mhpd_65-x6uzwvSjigE,92837
27
28
  cehrgpt/models/hf_modeling_outputs.py,sha256=5X4WEYKqT37phv_e5ZAv3A_N0wqdAUJLJRm6TxS6dDQ,10356
28
29
  cehrgpt/models/pretrained_embeddings.py,sha256=vLLVs17TLpXRqCVEWQxGGwPHkUJUO7laNTeBuyBK_yk,3238
29
30
  cehrgpt/models/special_tokens.py,sha256=lrw45B4tea4Dsajn09Cz6w5D2TfHmYXikZkgwnstu_o,521
@@ -38,11 +39,11 @@ cehrgpt/omop/queries/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hS
38
39
  cehrgpt/omop/queries/condition_era.py,sha256=LFB6vBAvshHJxtYIRkl7cfrF0kf7ay0piBKpmHBwrpE,2578
39
40
  cehrgpt/omop/queries/observation_period.py,sha256=fpzr5DMNw-QLoSwp2Iatfch88E3hyhZ75usiIdG3A0U,6410
40
41
  cehrgpt/runners/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
41
- cehrgpt/runners/data_utils.py,sha256=I6k1TkiiZR8ggw3eVO16g2lVPY-Hu3b-nbrIOKlFIO0,15528
42
+ cehrgpt/runners/data_utils.py,sha256=i-krtBx_6rvPYtdLdDoWwOTtJcaovd0wH8gBYmgN2l4,16013
42
43
  cehrgpt/runners/gpt_runner_util.py,sha256=YJQSRW9Mo4TjXSOUOTf6BUFcs1MGFiXU5T4ztKZcYhU,3485
43
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=GVbHHqf5TWGbVWlQG-XurgYH8pKRjTk8ug_ib9L9U7E,28118
44
+ cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=1OgxLm4T7iHv5pKi2QaSdaz9ogWo2n3sSUGp6cHDF9s,28309
44
45
  cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=ERSnvB38fPYVghtKQeNTZ8VfeXnoRcCHB0cWISWaZ84,26523
45
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=ejAFLM9g765p1fyeF5MITsiIeWHKkz9wTeFDeVgxSto,8851
46
+ cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=fJR4RHPqal1YI6_KUH-WlkoQLSZuBT5bKUGfPHDFrWI,9350
46
47
  cehrgpt/runners/hyperparameter_search_util.py,sha256=YWdFQ1igQs-G_wqWUrUzYraGiz8OSpSYyvid-I5nhWA,9262
47
48
  cehrgpt/runners/sample_packing_trainer.py,sha256=Zb7Aqwnk8-VqrjEKUVeg5XzZWmHxXOU2sDn1YURS-FU,7960
48
49
  cehrgpt/simulations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -65,10 +66,10 @@ cehrgpt/tools/generate_pretrained_embeddings.py,sha256=lhFSacGv8bMld6qigKZN8Op8e
65
66
  cehrgpt/tools/merge_synthetic_real_dataasets.py,sha256=O1dbQ32Le0t15fwymwAh9mfNVLEWuFwW53DNvESrWbY,7589
66
67
  cehrgpt/tools/upload_omop_tables.py,sha256=vdBAbkeAsGPA4NsyhNjelPVj3gS8yzmS1sKNM1Qk96g,3791
67
68
  cehrgpt/tools/linear_prob/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
68
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py,sha256=q0rmlBWDDEkjHjwcTouGUhCYa32a1vRicaDOAMsdW0I,20741
69
+ cehrgpt/tools/linear_prob/compute_cehrgpt_features.py,sha256=Hpx7WvAWm2WwPHFfimCADXh019I7bwdzJ4_5_YCxQzU,19817
69
70
  cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py,sha256=w0UvzMKYGenN_KDVnbzutmy8IPLUxW5hPvpKKxDSL5U,5820
70
- cehrgpt-0.1.1.dist-info/licenses/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
71
- cehrgpt-0.1.1.dist-info/METADATA,sha256=VnXH74vJQZaV7VxGiIvJnFhQA0jzJQNx86yHFkygobM,4922
72
- cehrgpt-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
73
- cehrgpt-0.1.1.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
74
- cehrgpt-0.1.1.dist-info/RECORD,,
71
+ cehrgpt-0.1.2.dist-info/licenses/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
72
+ cehrgpt-0.1.2.dist-info/METADATA,sha256=D7gGKrQThiLivViFeNm711NCP8J-wXfkueMGb6RKqV0,8481
73
+ cehrgpt-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
74
+ cehrgpt-0.1.2.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
75
+ cehrgpt-0.1.2.dist-info/RECORD,,